Source code for q2_sdk.models.oauth_client

from dataclasses import dataclass
from typing import Optional
from q2_sdk.core.cache import Q2CacheClient


@dataclass
class OAuthToken:
    """
    An OAuth token as handed back by an identity provider.

    Instances are stored in cache as plain dicts and rebuilt with `from_dict`.
    """

    # The credential placed in the Authorization header.
    access_token: str
    # Token lifetime; OAuthClient uses it as the cache ``expire`` value
    # (presumably seconds -- confirm against your IDP).
    expires_in: int
    # Optional credential that can be exchanged for a new access token.
    refresh_token: Optional[str] = None
    # Scheme prefix used when building the Authorization header.
    token_type: str = "Bearer"

    @classmethod
    def from_dict(cls, cached_dict: dict):
        """
        Rebuild a token from a dict whose keys are this class's field names.

        :param cached_dict: Mapping of field name to value, e.g. a token
            previously serialized to cache.
        :return: A new token instance.
        """
        field_values = cached_dict
        return cls(**field_values)

    def serialize_as_header(self):
        """
        Render this token as an HTTP header mapping.

        :return: ``{"Authorization": "<token_type> <access_token>"}``
        """
        header_value = f"{self.token_type} {self.access_token}"
        return {"Authorization": header_value}
class OAuthClient:
    """
    This class takes care of some of the grunt work of authenticating with
    OAuth, such as

    - Retrieving an access token
    - Refreshing an access token using a refresh token
    - Passing in an Authorization header
    - Caching the token until it expires

    To use, create a child class that inherits from `OAuthClient`, define the
    `get_token` and (optionally) `refresh_token` functions, then pass an
    instance of your child class to a q2_requests function using the
    ``oauth_client`` keyword argument.
    """

    def __init__(self, cache: Q2CacheClient):
        # Backing store for the cached access token (stored as a dict) and the
        # refresh token (stored as-is).
        self.cache = cache

    @property
    def name(self) -> str:
        """
        A unique string to be used for storing tokens in cache.

        By default, this will be the class name, upper-cased. If you plan to
        use this class to store tokens for multiple domains or endpoints, you
        should override this.
        """
        return self.__class__.__name__.upper()

    @property
    def access_key(self) -> str:
        """
        Key for storing OAuth access token in cache.

        You can override this for advanced use cases, but consider overriding
        `name` instead.
        """
        return f"{self.name}_ACCESS_TOKEN"

    @property
    def refresh_key(self) -> str:
        """
        Key for storing OAuth refresh token in cache.

        You can override this for advanced use cases, but consider overriding
        `name` instead.
        """
        return f"{self.name}_REFRESH_TOKEN"

    @property
    def unauthorized_status_codes(self) -> list[int]:
        """
        A list of HTTP status codes that indicate that the cached token is
        invalid.

        If a request returns one of these statuses, a new token will be
        acquired and the request will retry exactly once.
        """
        return [400, 401, 403]

    async def get_token(self) -> OAuthToken:
        """
        Define behavior for getting an access token from your IDP.

        Child classes must override this; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    async def refresh_token(self, refresh_token: Optional[str]) -> OAuthToken:
        """
        Define behavior for refreshing an access token from your IDP.

        :param refresh_token: The refresh token previously cached under
            `refresh_key`.
        :return: A new token.

        The base implementation ignores ``refresh_token`` and simply calls
        `get_token`; override this if your IDP supports a refresh grant.
        """
        return await self.get_token()

    async def get_token_obj(self) -> OAuthToken:
        """
        Return a usable token, preferring the one already in cache.

        On a cache miss, a cached refresh token (if any) is exchanged via
        `refresh_token`; otherwise `get_token` is called. The newly obtained
        token is then written back to cache: its refresh token (when present)
        under `refresh_key`, and the token itself (as a dict) under
        `access_key` with ``expire=expires_in``.

        :return: The cached or newly obtained token.
        """
        cached_token = self.cache.get(self.access_key)
        if cached_token:
            return OAuthToken.from_dict(cached_token)

        cached_refresh_token = self.cache.get(self.refresh_key)
        if cached_refresh_token:
            token_obj = await self.refresh_token(cached_refresh_token)
        else:
            token_obj = await self.get_token()

        # Stored with no explicit ``expire`` so it can outlive the access
        # token and be used for the next refresh.
        if token_obj.refresh_token:
            self.cache.set(self.refresh_key, token_obj.refresh_token)
        # Stored as a plain dict so it can be rebuilt with OAuthToken.from_dict.
        # NOTE(review): if an IDP can return expires_in == 0, confirm how
        # Q2CacheClient treats expire=0 (some memcache clients read it as
        # "never expire").
        self.cache.set(
            self.access_key, vars(token_obj), expire=token_obj.expires_in
        )
        return token_obj

    def clear_token(self):
        """
        Remove the cached access token so the next `get_token_obj` call obtains
        a new one.

        Any cached refresh token is left in place. As a result, that next call
        goes through `refresh_token` rather than `get_token`.
        """
        self.cache.delete(self.access_key)