from __future__ import annotations
import base64
from dataclasses import dataclass
from enum import Enum
import os
import inspect
from typing import ClassVar, Optional, Self
from lxml import etree, objectify # type: ignore
from lxml.builder import E
from q2_sdk.core.exceptions import ImproperUseError
from q2_sdk.tools.utils import pascalize
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
except ModuleNotFoundError: # pragma: no cover
pass
[docs]
class AuthorizationStatus(Enum):
    """Authorization status codes.

    Apart from ``PasswordIsBad`` (0), every member is a distinct power of two;
    they are written as ``1 << n`` so the bit position of each is explicit.
    NOTE(review): the comment on ``UserTokenCreated`` refers to "other flags",
    which suggests these values are combined bitwise elsewhere -- confirm.
    """

    PasswordIsBad = 0
    PasswordIsGood = 1 << 0
    Expired = 1 << 1
    InitialPwdExpired = 1 << 2
    LockLogin = 1 << 3
    DisableLogin = 1 << 4
    NextInvalidChangesStatus = 1 << 5
    TacRequired = 1 << 6
    # success: here's the token, see other flags for options allowed
    UserTokenCreated = 1 << 7
    # error: no token created, cannot register or roam
    BrowserRegistrationCountExceeded = 1 << 8
    RoamingIsAllowed = 1 << 9
    BrowserRegistrationIsAllowed = 1 << 10
    IsMultiChannelAuthenticated = 1 << 11
    ExceptionDuringProcessing = 1 << 12
    LockUser = 1 << 13
[docs]
@dataclass
class StandardAuthResponseFields:
success: bool
status_description: Optional[str]
hq_error_return_code: Optional[int]
end_user_message: Optional[str]
@staticmethod
def get_success(*args, **kwargs) -> StandardAuthResponseFields:
return StandardAuthResponseFields(True, None, None, None)
@staticmethod
def get_failure(*args, **kwargs) -> StandardAuthResponseFields:
return StandardAuthResponseFields(
False, "Auth Failure", -1, "There has been a failure with authentication"
)
def serialize_as_xml(self) -> etree.Element:
match self.success:
case True:
xml = E.Root(
E.Status("Success"),
E.HQErrorReturnCode("0"),
)
case False | _:
xml = E.Root(
E.Status("Error"),
E.StatusDescription(self.status_description),
E.HQErrorReturnCode(str(self.hq_error_return_code)),
E.EndUserMessage(self.end_user_message),
)
return xml
[docs]
@dataclass
class BaseAuthResponse:
    """
    All other auth responses have at least this info

    .. code-block:: xml

        <Q2Bridge request="RequestType" messageID="messageID">
            <Status>"Success"/"Error"</Status>
            <HQErrorReturnCode>{0}</HQErrorReturnCode>
            <StatusDescription>{0}</StatusDescription>
            <EndUserMessage>{0}</EndUserMessage>
        </Q2Bridge>

    Instances must be created through one of the factory methods listed in
    ``valid_calling_funcs`` (``get_success`` / ``get_failure``); direct
    construction raises ``ImproperUseError``. Extra child elements can be
    queued with :meth:`add_response_field`.

    NOTE(review): :meth:`serialize_as_xml` in this class only sets the
    ``messageID`` attribute; the ``request`` attribute shown above is not
    added here.
    """

    standard_auth_response_fields: StandardAuthResponseFields
    valid_calling_funcs: ClassVar[list[str]] = ["get_success", "get_failure"]

    def __post_init__(self):
        # Guard against direct construction: some frame on the current call
        # stack must belong to one of the sanctioned factory functions.
        if not self._has_valid_caller():
            raise ImproperUseError(
                f"This Response should not be instantiated directly. Please use one of the following methods: {self.valid_calling_funcs}"
            )
        self.additional_response_fields = {}

    def _has_valid_caller(self) -> bool:
        """Return True if any frame on the call stack is named in ``valid_calling_funcs``.

        Walks ``f_back`` links directly instead of calling ``inspect.stack()``,
        which would snapshot every frame and read source context for each one.
        The check still compares each frame's function name, so it accepts
        exactly the same callers.
        """
        frame = inspect.currentframe()
        try:
            while frame is not None:
                if frame.f_code.co_name in self.valid_calling_funcs:
                    return True
                frame = frame.f_back
            return False
        finally:
            # Drop the frame reference to avoid a frame <-> local reference cycle.
            del frame

    def serialize_as_xml(self, message_id: str) -> etree._Element:
        """Build the ``<Q2Bridge>`` response element.

        :param message_id: value for the root's ``messageID`` attribute.
        :return: the root lxml element (not a serialized string).
        """
        root = etree.Element("Q2Bridge", {"messageID": message_id})
        auth_xml = self.standard_auth_response_fields.serialize_as_xml()
        # Copy only tag and text of each child; the builder's wrapper element
        # itself is discarded.
        for node in auth_xml:
            etree.SubElement(root, node.tag).text = node.text
        for name, value in self.additional_response_fields.items():
            etree.SubElement(root, name).text = str(value)
        return root

    def add_response_field(self, key, value):
        """Queue an extra child element for :meth:`serialize_as_xml`.

        :param key: field name; stored under ``pascalize(key)``.
        :param value: field value; ``Enum`` members are stored by ``.value``
            and every value is stringified at serialization time.
        :raises AssertionError: if ``value`` is ``None``.
        """
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the AssertionError type is kept for compatibility.
        if value is None:
            raise AssertionError(f"Response field {key!r} must not be None")
        if isinstance(value, Enum):
            value = value.value
        self.additional_response_fields[pascalize(key)] = value

    @classmethod
    def _get_standard_auth_success_fields(cls):
        """Hook returning the standard fields used by :meth:`get_success`."""
        return StandardAuthResponseFields.get_success()

    @classmethod
    def _get_standard_auth_failure_fields(cls):
        """Hook returning the standard fields used by :meth:`get_failure`."""
        return StandardAuthResponseFields.get_failure()

    @classmethod
    def get_success(cls, *args, **kwargs) -> Self:
        """Returns OK to HQ"""
        return cls(cls._get_standard_auth_success_fields())

    @classmethod
    def get_failure(cls, *args, **kwargs) -> Self:
        """Returns Failure to HQ"""
        return cls(cls._get_standard_auth_failure_fields())
[docs]
@dataclass
class BaseAuthRequest:
    """Wrapper holding the raw XML of an incoming auth request."""

    # The request document exactly as received; no fields are extracted here.
    raw: objectify.Element

    @staticmethod
    def from_xml(xml: objectify.Element) -> BaseAuthRequest:
        """Wrap ``xml`` in a :class:`BaseAuthRequest` without further parsing.

        Always returns a ``BaseAuthRequest`` (not ``cls``), even when called
        through a subclass.
        """
        request = BaseAuthRequest(raw=xml)
        return request
[docs]
@dataclass
class Password:
    """A password that is either plain text or encrypted.

    When ``encrypted`` is ``False``, ``value`` is returned as-is by
    :meth:`decrypt`; otherwise ``value`` is base64-encoded AES-CBC ciphertext.
    """

    value: str
    encrypted: bool

    @staticmethod
    def _bytes_from_env(name: str) -> bytes:
        """Read environment variable ``name`` as comma-separated byte values.

        For example ``"1,2,255"`` becomes ``bytes([1, 2, 255])``.

        :raises KeyError: if the variable is not set.
        :raises ValueError: if an entry is not an integer in ``range(256)``.
        """
        return bytes(int(part) for part in os.environ[name].split(","))

    def decrypt(self) -> str:
        """Return the plain-text password.

        Encrypted values are base64-decoded and AES-CBC decrypted using the key
        and IV from the ``PASSWORD_ENC_KEY`` / ``PASSWORD_ENC_IV`` environment
        variables.

        :raises KeyError: if either environment variable is missing (previously
            this surfaced as an opaque ``AttributeError`` on ``None.split``).
        """
        if self.encrypted is False:
            return self.value
        # TODO: Make this not environment variables
        key = self._bytes_from_env("PASSWORD_ENC_KEY")
        iv = self._bytes_from_env("PASSWORD_ENC_IV")
        cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
        enc_pass = base64.b64decode(self.value)
        decryptor = cipher.decryptor()
        res = decryptor.update(enc_pass) + decryptor.finalize()
        # NOTE(review): no unpadding is performed; strip() removes only ASCII
        # whitespace bytes (including any at the ends of the real password).
        # Confirm the padding scheme used by the encrypting side.
        return res.strip().decode()