from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import jwt
from lxml import objectify # type: ignore
from q2_sdk.tools.utils import to_bool
from q2_sdk.core import contexts, q2_requests
from .base import BaseAuthRequest, BaseAuthResponse
@dataclass
class Oauth2Config:
    """OAuth2 authorization-code settings plus the rules for verifying the returned token.

    If ``jwks_url`` is set, the access token is verified with RS256 using the
    signing key published at that JWKS endpoint; otherwise it is verified with
    HS256 keyed by ``client_secret``.
    """

    client_id: str  # sent to ``token_url`` in the code-exchange request
    client_secret: str  # sent to ``token_url``; also the HS256 verification key
    token_url: str  # endpoint POSTed to when exchanging the authorization code
    authorize_url: str  # not read by this class; presumably used by callers
    redirect_uri: str  # echoed to ``token_url`` in the code-exchange request
    well_known_url: Optional[str] = None  # deprecated alias of ``jwks_url``
    jwks_url: Optional[str] = None  # JWKS endpoint; enables RS256 verification
    audience: Optional[str] = None  # expected ``aud`` claim, passed to ``jwt.decode``

    async def get_access_token(self, code: str) -> dict:
        """Exchange an authorization ``code`` for an access token and return its claims.

        :param code: Authorization code received on the OAuth2 redirect.
        :returns: The verified, decoded claims of the access token.
        :raises KeyError: If the token endpoint's JSON body has no ``access_token``.
        :raises jwt.PyJWTError: If the signing key cannot be fetched or the token
            fails signature/audience validation.
        """
        context = contexts.get_current_request()
        logger = context.request_handler.logger
        jwks_url = self._resolve_jwks_url(logger)
        jwt_resp = await q2_requests.post(
            logger,
            self.token_url,
            data={
                "grant_type": "authorization_code",
                "client_id": self.client_id,
                "client_secret": self.client_secret,
                "code": code,
                "redirect_uri": self.redirect_uri,
            },
        )
        access_token = jwt_resp.json()["access_token"]
        return self._decode_access_token(access_token, jwks_url, logger)

    def _resolve_jwks_url(self, logger) -> Optional[str]:
        """Return the JWKS URL to use, honouring the deprecated ``well_known_url``.

        An explicitly configured ``jwks_url`` always wins. ``well_known_url`` is
        only adopted (copied onto ``jwks_url``) when ``jwks_url`` is unset. A
        deprecation warning is logged whenever ``well_known_url`` is set.
        """
        if self.well_known_url:
            logger.warning(
                "The `well_known_url` parameter has been deprecated. "
                "Please use `jwks_url` instead in Oauth2Config"
            )
            if not self.jwks_url:
                self.jwks_url = self.well_known_url
        return self.jwks_url

    def _decode_access_token(
        self, access_token: str, jwks_url: Optional[str], logger
    ) -> dict:
        """Verify ``access_token`` and return its claims.

        Uses RS256 with the key from ``jwks_url`` when one is given, otherwise
        HS256 keyed by ``client_secret``. ``audience`` is checked in both cases.
        """
        if jwks_url:
            logger.debug(
                "Oauth config provides a jwks_url. Getting signing key and "
                "verifying with the RS256 algorithm."
            )
            # NOTE(review): PyJWKClient fetches the JWKS with a blocking HTTP call
            # inside this async flow, and it is rebuilt on every request, so its
            # key cache is never reused. Consider caching a client per URL.
            jwks_client = jwt.PyJWKClient(jwks_url)
            signing_key = jwks_client.get_signing_key_from_jwt(access_token)
            return jwt.decode(
                access_token,
                key=signing_key.key,
                algorithms=["RS256"],
                audience=self.audience,
            )
        logger.debug(
            "Oauth config has no jwks_url. Verifying with the HS256 algorithm "
            "using the client secret."
        )
        # Without an explicit key PyJWT verifies against "", which rejects any
        # token actually signed by the provider. HS256 tokens are keyed with the
        # client secret (OpenID Connect Core 1.0, section 10.1).
        return jwt.decode(
            access_token,
            key=self.client_secret,
            algorithms=["HS256"],
            audience=self.audience,
        )
@dataclass
class Request(BaseAuthRequest):
    """
    Parsed ``LogonUserExternal`` request.

    ``code`` is read from ``Token``, ``state`` from ``Token2``,
    ``is_prelogon_session`` from ``IsPrelogonSession`` (via ``to_bool``) and
    ``session_id`` from ``SessionId``; ``raw`` keeps the original element.

    .. code-block:: xml

        <HQ request="LogonUserExternal" messageID="{0}">
            <Token>m5HTKm6iYTNBtTMvhRpA9C3uWFknXRbBkFOOPJxM3Vo</Token>
            <Token2>state</Token2>
            <IsPrelogonSession>False</IsPrelogonSession>
            <SessionId>shb3cnukdl32c54nodvyxcsj</SessionId>
        </HQ>
    """

    raw: objectify.Element
    code: str
    state: str
    is_prelogon_session: bool
    session_id: str

    @staticmethod
    def from_xml(xml: objectify.Element) -> Request:
        """Build a ``Request`` from an objectified ``LogonUserExternal`` element."""
        return Request(
            raw=xml,
            code=xml.Token.text,
            state=xml.Token2.text,
            is_prelogon_session=to_bool(xml.IsPrelogonSession.text),
            session_id=xml.SessionId.text,
        )
@dataclass
class Response(BaseAuthResponse):
    """
    Response to a ``LogonUserExternal`` request.

    Requires the ``user_identifier``, which corresponds to the
    ``Q2_SSOUserLogon.SSOIdentifier`` field in the database, and creates a
    user session. This is typically the ``"sub"`` claim when dealing with a
    JWT access_token.

    .. code-block:: xml

        <Q2Bridge request="LogonUserExternal" messageId="messageID">
            <Status>Success</Status>
            <UserIdentifier>Q2_SSOUserLogon.SsoIdentifier</UserIdentifier>
        </Q2Bridge>
    """

    @classmethod
    def get_success(
        cls,
        user_identifier: str,
    ) -> Response:
        """Build a successful response carrying ``user_identifier``.

        Starts from the inherited standard auth-success fields and adds a
        ``user_identifier`` response field.
        """
        success_fields = cls._get_standard_auth_success_fields()
        response = cls(success_fields)
        response.add_response_field("user_identifier", user_identifier)
        return response