from argparse import _SubParsersAction, ArgumentParser
from enum import Enum
from datetime import datetime
from dateutil import parser as datetime_parser
from functools import partial
from typing import List, Optional
from lxml.objectify import IntElement, StringElement
from q2_sdk.core.dynamic_imports import (
api_ExecuteStoredProcedure as ExecuteStoredProcedure,
)
from q2_sdk.hq.models.online_session import OnlineSession
from q2_sdk.hq.models.online_user import OnlineUser
from .db_object import DbObject
from .audit_record import AuditRecord
from .representation_row_base import RepresentationRowBase
[docs]
class TokenType(Enum):
    """Kinds of OAuth token handled by this module.

    Each member's value equals its name. The member names are offered as the
    ``-t/--type`` command-line choices, whose help text describes them as
    ``SDK_OauthTokenType.ShortName``.
    """

    Access = "Access"
    Authorization = "Authorization"
    Refresh = "Refresh"

    def __str__(self) -> str:
        """Return the bare short name rather than ``TokenType.<name>``.

        Code in this module builds the ``TokenType`` SQL parameter with
        ``str(token_type)``; the default ``Enum.__str__`` would send
        ``"TokenType.Access"`` instead of ``"Access"``.
        """
        return self.value
[docs]
class OAuthLookupDataRow(RepresentationRowBase):
    """Typed description of one row returned for an OAuth lookup.

    Each attribute is annotated with an ``lxml.objectify`` element type and is
    assigned a string identical to its own name.
    NOTE(review): presumably ``RepresentationRowBase`` uses these strings as
    the column names to read from the HQ response -- confirm against the base
    class.
    """

    # Integer-typed columns.
    OAuthLookupKeyID: IntElement = "OAuthLookupKeyID"
    OAuthLookupKeyName: StringElement = "OAuthLookupKeyName"
    UserID_: IntElement = "UserID_"
    CustomerID_: IntElement = "CustomerID_"
    # NOTE(review): lower-camel-case unlike its siblings; verify the casing
    # matches what the stored procedure actually returns.
    accessToken: StringElement = "accessToken"
    # Declared as a string element, not a datetime.
    Expiry: StringElement = "Expiry"
[docs]
class OAuthLookup(DbObject):
    """Command-line and programmatic access to the ``sdk_OAuthToken*`` procedures.

    Every public coroutine assembles a list of SQL parameters and passes it,
    together with a stored procedure name, to ``self.call_hq``.
    """

    REPRESENTATION_ROW_CLASS = OAuthLookupDataRow

    @staticmethod
    def _token_type_name(token_type) -> str:
        """Return the text sent as the ``TokenType`` SQL parameter.

        ``str()`` of a plain ``Enum`` member is ``"TokenType.Access"``, so
        :class:`TokenType` members are unwrapped to their value; anything else
        is passed through ``str`` unchanged.
        """
        if isinstance(token_type, TokenType):
            return token_type.value
        return str(token_type)

    @staticmethod
    def _show_disabled(enabled_only) -> bool:
        """Translate ``enabled_only`` into the ``showDisabled`` bit.

        ``"true"`` (any casing) and ``"1"`` mean "enabled rows only", i.e.
        ``showDisabled`` is ``False``; every other value yields ``True``.
        The value is stringified first so booleans and integers are accepted
        as well as the strings produced by the command line.
        """
        return str(enabled_only).upper() not in ("TRUE", "1")

    def _add_token_type_arg(self, subparser: ArgumentParser):
        """Register the ``-t/--type`` option (stored as ``token_type``)."""
        subparser.add_argument(
            "-t",
            "--type",
            dest="token_type",
            default="Access",
            choices=[member.name for member in TokenType],
            help="SDK_OauthTokenType.ShortName",
        )

    def _add_standard_subparser_args(
        self, subparser: ArgumentParser, add_user_and_customer=True
    ):
        """Register the arguments shared by the get/add/update commands.

        :param subparser: parser to extend
        :param add_user_and_customer: when true, also require exactly one of
            ``-u/--user_id`` or ``-c/--customer_id``
        """
        subparser.add_argument("clientid")
        self._add_token_type_arg(subparser)
        if add_user_and_customer:
            owner = subparser.add_mutually_exclusive_group(required=True)
            owner.add_argument("-u", "--user_id", help="Q2_User.UserID")
            owner.add_argument("-c", "--customer_id", help="Q2_User.CustomerID")

    def _serialize_rows(self, response, columns, no_trunc):
        """Pass ``response`` through ``self.serialize_for_cli``.

        :param response: value returned by ``self.call_hq``
        :param columns: column names handed to ``serialize_for_cli``
        :param no_trunc: when true, no fields are listed for truncation
        """
        # NOTE(review): the column lists name "Token" while truncation names
        # "AccessToken" -- confirm which one the result set really contains.
        fields_to_truncate = [] if no_trunc else ["AccessToken", "Scope"]
        return self.serialize_for_cli(
            response,
            columns,
            fields_to_truncate=fields_to_truncate,
        )

    def add_arguments(self, parser: _SubParsersAction):
        """Register the OAuth token sub-commands on ``parser``.

        Each sub-command stores its own name under ``parser`` and the
        coroutine that serves it under ``func``.
        """
        subparser = parser.add_parser("get_oauth_token")
        subparser.set_defaults(parser="get_oauth_token")
        self._add_standard_subparser_args(subparser)
        subparser.add_argument("-e", "--enabled_only", default="True")
        subparser.set_defaults(func=partial(self.get, serialize_for_cli=True))

        subparser = parser.add_parser("get_oauth_token_by_access_token")
        subparser.set_defaults(parser="get_oauth_token_by_access_token")
        self._add_standard_subparser_args(subparser, add_user_and_customer=False)
        # get_by_value() has a required ``access_token`` parameter; without
        # this positional the command had no way to supply it.
        subparser.add_argument("access_token", type=str)
        subparser.add_argument("-e", "--enabled_only", default="True")
        subparser.set_defaults(func=partial(self.get_by_value, serialize_for_cli=True))

        # Adding and updating take identical arguments and share one handler.
        for command in ("add_oauth_token", "update_oauth_token"):
            subparser = parser.add_parser(command)
            subparser.set_defaults(parser=command)
            self._add_standard_subparser_args(subparser)
            subparser.add_argument("scope", type=str)
            subparser.add_argument("access_token", type=str)
            subparser.add_argument(
                "--expiry",
                type=datetime_parser.parse,
                help="Datetime at which the token will no longer be valid",
            )
            subparser.set_defaults(func=partial(self.add_update))

        toggles = (
            ("expire_oauth_token", self.disable),
            ("enable_oauth_token", self.enable),
        )
        for command, handler in toggles:
            subparser = parser.add_parser(command)
            subparser.set_defaults(parser=command)
            subparser.add_argument("customer_id", type=int)
            subparser.add_argument("user_id", type=int)
            subparser.add_argument("clientid", type=str)
            subparser.add_argument("access_token", type=str)
            # Parsed as int so that passing -1 matches the ``!= -1`` sentinel
            # test in enable()/disable(); a string never compared equal.
            subparser.add_argument("oauthtokenid", type=int)
            self._add_token_type_arg(subparser)
            subparser.set_defaults(func=partial(handler))

    async def get(
        self,
        clientid: str,
        user_id: Optional[int] = None,
        customer_id: Optional[int] = None,
        token_type: TokenType = "Access",
        enabled_only="True",
        no_trunc=False,
        serialize_for_cli=False,
    ) -> List[OAuthLookupDataRow]:
        """Call ``sdk_OAuthTokenGet`` for a client and a user or customer.

        :param clientid: sent as ``clientID`` after ``str()`` conversion
        :param user_id: sent as ``userID``
        :param customer_id: sent as ``customerID``
        :param token_type: :class:`TokenType` member or its short name
        :param enabled_only: ``"true"``/``"1"`` (or ``True``/``1``) sends
            ``showDisabled`` as ``False``; anything else sends ``True``
        :param no_trunc: skip field truncation when serializing for the CLI
        :param serialize_for_cli: return ``self.serialize_for_cli`` output
            instead of the raw response
        :raises AssertionError: when both ``user_id`` and ``customer_id``
            are ``None``
        """
        assert user_id is not None or customer_id is not None, (
            "Must specify either customerId or UserId"
        )
        sp = ExecuteStoredProcedure
        response = await self.call_hq(
            "sdk_OAuthTokenGet",
            sp.SqlParameters([
                sp.SqlParam(sp.DataType.VarChar, "clientID", str(clientid)),
                sp.SqlParam(sp.DataType.Int, "userID", user_id),
                sp.SqlParam(sp.DataType.Int, "customerID", customer_id),
                sp.SqlParam(
                    sp.DataType.Bit,
                    "showDisabled",
                    self._show_disabled(enabled_only),
                ),
                sp.SqlParam(
                    sp.DataType.VarChar,
                    "TokenType",
                    self._token_type_name(token_type),
                ),
            ]),
        )
        if serialize_for_cli:
            response = self._serialize_rows(
                response,
                [
                    "UserID_",
                    "CustomerID_",
                    "Token",
                    "Scope",
                    "ClientID",
                    "Enabled",
                    "TokenTypeID",
                    "TokenTypeShortname",
                    "Expiry",
                ],
                no_trunc,
            )
        return response

    async def get_by_user(
        self,
        user_id: Optional[int] = None,
        customer_id: Optional[int] = None,
        no_trunc=False,
        serialize_for_cli=False,
    ):
        """Call ``sdk_OAuthTokenGetByUserOrCustomer``.

        :param user_id: sent as ``userID``
        :param customer_id: sent as ``customerID``
        :param no_trunc: skip field truncation when serializing for the CLI
        :param serialize_for_cli: return ``self.serialize_for_cli`` output
            instead of the raw response
        :raises AssertionError: when neither id is truthy
        """
        assert user_id or customer_id, "user_id or customer_id needs to be specified"
        sp = ExecuteStoredProcedure
        response = await self.call_hq(
            "sdk_OAuthTokenGetByUserOrCustomer",
            sp.SqlParameters([
                sp.SqlParam(sp.DataType.Int, "userID", user_id),
                sp.SqlParam(sp.DataType.Int, "customerID", customer_id),
            ]),
        )
        if serialize_for_cli:
            response = self._serialize_rows(
                response,
                [
                    "OAuthTokenID",
                    "UserID_",
                    "CustomerID_",
                    "Token",
                    "Scope",
                    "Enabled",
                    "ClientID",
                    "CreateDate",
                    "LastUpdated",
                    "Expiry",
                    "TokenTypeID",
                    "TokenTypeShortname",
                ],
                no_trunc,
            )
        return response

    async def get_by_id(self, oauth_lookup_id, no_trunc=False, serialize_for_cli=False):
        """Call ``sdk_OAuthTokenGetByID`` and return element ``0`` of the result.

        :param oauth_lookup_id: sent as ``OAuthTokenID``
        :param no_trunc: skip field truncation when serializing for the CLI
        :param serialize_for_cli: run the response through
            ``self.serialize_for_cli`` before indexing
        :raises IndexError: when the (possibly serialized) response is empty
        """
        sp = ExecuteStoredProcedure
        response = await self.call_hq(
            "sdk_OAuthTokenGetByID",
            sp.SqlParameters([
                sp.SqlParam(sp.DataType.Int, "OAuthTokenID", oauth_lookup_id),
            ]),
        )
        if serialize_for_cli:
            # NOTE(review): "LastUpdate" here vs "LastUpdated" in
            # get_by_user() -- confirm which spelling the procedure returns.
            response = self._serialize_rows(
                response,
                [
                    "OAuthTokenID",
                    "UserID_",
                    "CustomerID_",
                    "Token",
                    "Scope",
                    "ClientID",
                    "CreateDate",
                    "LastUpdate",
                    "Expiry",
                    "TokenTypeID",
                    "TokenTypeShortname",
                    "Enabled",
                ],
                no_trunc,
            )
        # NOTE(review): this indexes the serialized value too; the sibling
        # getters return the serialized value whole -- confirm intent.
        return response[0]

    async def get_by_value(
        self,
        clientid: str,
        access_token: str,
        enabled_only="True",
        token_type: TokenType = "Access",
        no_trunc=False,
        serialize_for_cli=False,
    ) -> List[OAuthLookupDataRow]:
        """Call ``sdk_OAuthTokenGet`` for a client and an access token value.

        :param clientid: sent as ``clientID``
        :param access_token: sent as ``accessToken``
        :param enabled_only: ``"true"``/``"1"`` (or ``True``/``1``) sends
            ``showDisabled`` as ``False``; anything else sends ``True``
        :param token_type: :class:`TokenType` member or its short name
        :param no_trunc: skip field truncation when serializing for the CLI
        :param serialize_for_cli: return ``self.serialize_for_cli`` output
            instead of the raw response
        """
        sp = ExecuteStoredProcedure
        response = await self.call_hq(
            "sdk_OAuthTokenGet",
            sp.SqlParameters([
                sp.SqlParam(sp.DataType.VarChar, "clientID", clientid),
                sp.SqlParam(sp.DataType.VarChar, "accessToken", access_token),
                sp.SqlParam(
                    sp.DataType.Bit,
                    "showDisabled",
                    self._show_disabled(enabled_only),
                ),
                sp.SqlParam(
                    sp.DataType.VarChar,
                    "TokenType",
                    self._token_type_name(token_type),
                ),
            ]),
        )
        if serialize_for_cli:
            response = self._serialize_rows(
                response,
                [
                    "UserID_",
                    "CustomerID_",
                    "Token",
                    "Scope",
                    "ClientID",
                    "TokenTypeID",
                    "TokenTypeShortname",
                    "Expiry",
                    "Enabled",
                ],
                no_trunc,
            )
        return response

    async def add_update(
        self,
        customer_id: int,
        user_id: int,
        clientid: str,
        access_token: str,
        token_type: TokenType = "Access",
        scope: Optional[str] = None,
        expiry: Optional[datetime] = None,
    ):
        """Call ``sdk_OAuthTokenAddUpdate`` and return element ``0`` of the result.

        :param customer_id: sent as ``customerID``
        :param user_id: sent as ``userID``
        :param clientid: sent as ``clientID``
        :param access_token: sent as ``accessToken``
        :param token_type: :class:`TokenType` member or its short name
        :param scope: sent as ``scope``
        :param expiry: when truthy, sent as its ``isoformat()`` string
        :raises IndexError: when the procedure result is empty
        """
        if expiry:
            expiry = expiry.isoformat()
        sp = ExecuteStoredProcedure
        varchar = sp.DataType.VarChar
        sql_params = [
            sp.SqlParam(varchar, "clientID", clientid),
            sp.SqlParam(sp.DataType.Int, "customerID", customer_id),
            sp.SqlParam(sp.DataType.Int, "userID", user_id),
            sp.SqlParam(varchar, "accessToken", access_token),
            sp.SqlParam(varchar, "scope", scope),
            sp.SqlParam(varchar, "TokenType", self._token_type_name(token_type)),
            sp.SqlParam(varchar, "expiry", expiry),
        ]
        result = await self.call_hq(
            "sdk_OAuthTokenAddUpdate", sp.SqlParameters(sql_params)
        )
        return result[0]

    async def enable(
        self,
        customer_id: int = -1,
        user_id: int = -1,
        clientid: Optional[str] = None,
        access_token: Optional[str] = None,
        token_type: TokenType = "Access",
        oauthtokenid: int = -1,
    ):
        """Enable a token, addressed either by id or by its owning keys.

        When ``oauthtokenid`` is not ``-1`` only the id is sent (the other
        identifiers are replaced by ``-1``/``None``); otherwise the customer,
        user, client and access token are sent and the id is ``None``.
        """
        if oauthtokenid != -1:
            customer_id, user_id, clientid, access_token = -1, -1, None, None
        else:
            oauthtokenid = None
        return await self._enable_expire(
            customer_id,
            user_id,
            clientid,
            access_token,
            oauthtokenid,
            enabled=True,
            token_type=token_type,
        )

    async def disable(
        self,
        customer_id: int = -1,
        user_id: int = -1,
        clientid: Optional[str] = None,
        access_token: Optional[str] = None,
        token_type: TokenType = "Access",
        oauthtokenid: int = -1,
    ):
        """Disable a token, addressed either by id or by its owning keys.

        When ``oauthtokenid`` is not ``-1`` only the id is sent (the other
        identifiers are replaced by ``-1``/``None``); otherwise the customer,
        user, client and access token are sent and the id is ``None``.
        """
        if oauthtokenid != -1:
            customer_id, user_id, clientid, access_token = -1, -1, None, None
        else:
            oauthtokenid = None
        return await self._enable_expire(
            customer_id=customer_id,
            user_id=user_id,
            clientid=clientid,
            access_token=access_token,
            oauthtokenid=oauthtokenid,
            enabled=False,
            token_type=token_type,
        )

    async def _enable_expire(
        self,
        customer_id: int,
        user_id: int,
        clientid: Optional[str],
        access_token: Optional[str],
        oauthtokenid: Optional[int],
        enabled: bool,
        token_type: TokenType = "Access",
    ):
        """Call ``sdk_OAuthTokenEnableExpire`` with the given identifiers.

        :param enabled: sent as the ``enabled`` bit
        :param token_type: :class:`TokenType` member or its short name
        :return: whatever ``self.call_hq`` returns
        """
        sp = ExecuteStoredProcedure
        varchar = sp.DataType.VarChar
        sql_params = [
            sp.SqlParam(sp.DataType.BigInt, "oauthTokenID", oauthtokenid),
            sp.SqlParam(varchar, "clientID", clientid),
            sp.SqlParam(sp.DataType.Int, "customerID", customer_id),
            sp.SqlParam(sp.DataType.Int, "userID", user_id),
            sp.SqlParam(varchar, "accessToken", access_token),
            sp.SqlParam(sp.DataType.Bit, "enabled", enabled),
            sp.SqlParam(varchar, "TokenType", self._token_type_name(token_type)),
        ]
        return await self.call_hq(
            "sdk_OAuthTokenEnableExpire",
            sp.SqlParameters(sql_params),
        )

    async def delete(
        self,
        clientid: str,
        customer_id: int,
        user_id: int,
        access_token: str,
        online_session: OnlineSession,
        online_user: OnlineUser,
        token_type: TokenType = "Access",
    ):
        """Call ``sdk_OAuthTokenDelete`` and audit the deletion on success.

        An audit record is created only when ``self.hq_response.success`` is
        truthy after the call; it carries the session id and workstation from
        ``online_session`` and the customer, user and logon ids from
        ``online_user``.

        :return: whatever ``self.call_hq`` returns
        """
        sp = ExecuteStoredProcedure
        varchar = sp.DataType.VarChar
        result = await self.call_hq(
            "sdk_OAuthTokenDelete",
            sp.SqlParameters([
                sp.SqlParam(varchar, "clientID", clientid),
                sp.SqlParam(sp.DataType.Int, "customerID", customer_id),
                sp.SqlParam(sp.DataType.Int, "userID", user_id),
                sp.SqlParam(varchar, "accessToken", access_token),
                sp.SqlParam(varchar, "TokenType", self._token_type_name(token_type)),
            ]),
        )
        if self.hq_response.success:
            audit = AuditRecord(self.logger, self.hq_credentials)
            await audit.create(
                f"OAuth_client_lookup deleted. customer_id: {customer_id} clientid: {clientid}",
                online_session.session_id,
                workstation_id=online_session.workstation,
                customer_id=online_user.customer_id,
                user_id=online_user.user_id,
                user_logon_id=online_user.user_logon_id,
            )
        return result