# Source code for q2_sdk.models.pinion

from __future__ import annotations

import asyncio
import json
import logging
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Optional

import requests
from google.protobuf.json_format import ParseDict, ParseError
from q2msg.olb.notification_pb2 import NotificationEnvelope as NotificationEnvelopeProto

from q2_sdk.core import cache, contexts, message_bus
from q2_sdk.core.cache import Q2CacheClient
from q2_sdk.core.configuration import settings
from q2_sdk.core.default_settings import EnvLevel
from q2_sdk.core.http_handlers.base_handler import get_trace_id
from q2_sdk.core.q2_logging.logger import Q2LoggerType
from q2_sdk.hq.db.generated_transactions import GeneratedTransactions
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.models.holocron import get_envstack_by_cust_key
from q2_sdk.tools.utils import divide_list, is_json

# Fully qualified protobuf type name passed with every message pushed onto the bus.
REQUEST_PROTO_MSG_TYPE = "q2msg.olb.notification.NotificationEnvelope"
# Krayt URL for the current deploy environment. The lookup runs at import time, so a
# DEPLOY_ENV with no entry in PINION_KRAYT_BY_ENV fails on import (presumably KeyError).
KRAYT_URL = settings.PINION_KRAYT_BY_ENV[settings.DEPLOY_ENV]
# Message bus topic handed to message_bus.push for Pinion notifications.
TOPIC = settings.PINION_TOPIC
# Fallback logger used when no logger is supplied and no request context is active.
LOGGER = logging.getLogger(__name__)


@dataclass
class Result:
    """
    Outcome of a Pinion operation.

    Field Details:
    :param: success: True when the operation completed
    :param: error_message: Failure details; may also be set on a successful result
        to report partial problems (e.g. unresolved transaction ids)
    :param: data: Named values produced by the operation
    """

    success: bool = field(default=False)
    error_message: Optional[str] = field(default=None)
    data: dict = field(default_factory=dict)


class TemplatedEvents(Enum):
    """
    Event names that have a dedicated Pinion template.

    Field Details:
    :param: ACCOUNT_REFRESH: Triggers an account reload
    :param: ALERT: Sends an alert notification
    """

    ACCOUNT_REFRESH = "AccountUpdateNeeded"
    ALERT = "SendApplicationAlert"


class AlertType(Enum):
    """
    Severity / styling level of an alert notification.

    Field Details:
    :param: INFO: Denotes an info level alert
    :param: ERROR: Denotes an error level alert
    :param: WARNING: Denotes a warning level alert
    :param: SUCCESS: Denotes a success level alert
    """

    INFO = "info"
    ERROR = "error"
    WARNING = "warning"
    SUCCESS = "success"


class NotificationRoutingType(Enum):
    """
    How a notification is routed to its recipients.

    Field Details:
    :param: MSG_TYPE_GENERAL: Delivered only to the user named in the userId field
    :param: MSG_TYPE_FI: Broadcast to every user within the FI
    """

    MSG_TYPE_GENERAL = 0
    MSG_TYPE_FI = 2


@dataclass
class NotificationRouting:
    """
    Routing section of a Pinion notification.

    Field Details:
    :param: notification_type: Routing type for the frontend (a
        NotificationRoutingType member or its integer value)
    :param: userId: Recipient of the notification message
    """

    notification_type: NotificationRoutingType
    # camelCase on purpose: field names are fed straight into the protobuf dict
    userId: Optional[str] = field(default=None)


@dataclass
class NotificationData:
    """
    Content section of a Pinion notification.

    Field Details:
    :param: message: Message details
    :param: event: Event associated with the request (TemplatedEvents member or a
        raw event name)
    :param: payload: Extra instructions for the frontend
    """

    message: Optional[str] = field(default=None)
    event: Optional[TemplatedEvents | str] = field(default=None)
    payload: dict = field(default_factory=dict)


@dataclass
class BaseTemplate:
    """
    Base Pinion template for sending messages.

    Dictionary format::

        {
            routing : {
                notification_type: ...
                userId: ...
            }
            data : {
                message: ...
                event: ...
                payload: ...
            }
        }

    On construction the routing type is replaced by its integer value (when it is a
    NotificationRoutingType member). A dict payload gets a ``traceId`` entry, and
    the payload is passed through ``json.dumps`` unless ``is_json`` already accepts it.
    """

    routing: NotificationRouting
    data: NotificationData

    def __post_init__(self):
        self._validate_notification_type()
        self._prepare_payload()

    def _validate_notification_type(self):
        """
        Stores the routing type as its raw enum value.
        """
        routing = self.routing
        if not isinstance(routing.notification_type, NotificationRoutingType):
            return
        routing.notification_type = routing.notification_type.value

    def _prepare_payload(self):
        """
        Adds the trace id to a dict payload and JSON-encodes the payload if needed.
        """
        trace_id = get_trace_id(generate=True)
        payload = self.data.payload
        if isinstance(payload, dict):
            # Mutates the caller's dict in place, same object as self.data.payload
            payload["traceId"] = trace_id
        if is_json(payload):
            return
        self.data.payload = json.dumps(payload)


@dataclass
class BaseTemplateFactory(ABC):
    """
    A factory class responsible for converting specific template instances into a
    BaseTemplate.

    Field Details:
    :param: notification_routing_type: Routing type (NotificationRoutingType member
        or its integer value); normalized to the integer value after construction
    """

    notification_routing_type: NotificationRoutingType

    def __post_init__(self):
        self.notification_routing_type = self._get_enum_value()

    def _get_enum_value(self) -> int:
        """
        Returns the integer representation of the routing type.
        """
        return (
            self.notification_routing_type.value
            if isinstance(self.notification_routing_type, NotificationRoutingType)
            else self.notification_routing_type
        )

    async def to_base_template(
        self,
        logger: Q2LoggerType,
        hq_credentials: HqCredentials,
    ) -> Result:
        """
        Converts a specific template (AccountRefreshTemplate or AlertTemplate)
        instance into a BaseTemplate.

        :param: logger: Logger used for reporting conversion failures
        :param: hq_credentials: Credentials forwarded to the concrete conversion
        :return: The Result from ``convert_to_base_template``; on failure its
            error_message is prefixed with a generic conversion-failure note
        """
        get_template = await self.convert_to_base_template(logger, hq_credentials)
        if not get_template.success:
            prefix = "Failed to convert to base template."
            detail = get_template.error_message
            # Guard against subclasses that fail without a message ("..." + None
            # would raise TypeError).
            get_template.error_message = f"{prefix} {detail}" if detail else prefix
            # Not inside an ``except`` block: logger.exception would attach a
            # meaningless "NoneType: None" traceback, so log a plain error instead.
            logger.error(get_template.error_message)
        return get_template

    def _requires_user_id(self) -> bool:
        """
        Checks if a user ID is required (only for MSG_TYPE_GENERAL routing).
        """
        return (
            self.notification_routing_type
            == NotificationRoutingType.MSG_TYPE_GENERAL.value
        )

    @abstractmethod
    async def convert_to_base_template(
        self, logger: Q2LoggerType, hq_credentials: HqCredentials
    ) -> Result:
        """
        Builds a Result whose ``data["template"]`` holds the BaseTemplate.

        :param: logger: Logger available to the conversion
        :param: hq_credentials: Credentials for any HQ lookups the conversion needs
        """
        pass  # pragma: no cover


@dataclass
class AccountRefreshTemplate(BaseTemplateFactory):
    """
    Template designed for account reloads. This operation should be performed only
    when a user is logged in.

    Field Details:
    :param: user_id: Recipient user id; required for MSG_TYPE_GENERAL routing
    :param: transaction_ids: Transaction ids whose host account ids are looked up
        and sent as the payload
    :param: message: Optional message sent with the notification
    """

    user_id: Optional[int] = None
    transaction_ids: list[int] = field(default_factory=list)
    message: Optional[str] = None

    async def convert_to_base_template(
        self, logger: Q2LoggerType, hq_credentials: HqCredentials
    ) -> Result:
        """
        Converts the AccountRefreshTemplate to a BaseTemplate.

        :return: Result with ``data["template"]`` on success. error_message is set
            when the user id is missing, when no host account could be resolved,
            or (alongside success) to list transaction ids that could not be resolved.
        """
        result = Result()

        # Validate user ID for MSG_TYPE_GENERAL routing
        if self._requires_user_id() and not self.user_id:
            result.error_message = (
                "User ID must be provided for notification type: MSG_TYPE_GENERAL"
            )
            return result

        # Prepare payload and lookup accounts
        payload, invalid_transaction_ids = await self._construct_payload(
            logger, hq_credentials
        )

        # Transaction ids were supplied but none of them resolved to a host account
        if self.transaction_ids and not payload:
            result.error_message = (
                "Could not find host account ids tied to provided transaction ids"
            )
            return result

        # Create base template
        result.success = True
        result.data["template"] = self._create_base_template(payload)
        # Partial success: report the ids that could not be resolved
        if invalid_transaction_ids:
            result.error_message = (
                f"Failed to process transaction IDs: {invalid_transaction_ids}"
            )
        return result

    async def _construct_payload(
        self, logger: Q2LoggerType, hq_credentials: HqCredentials
    ) -> tuple[list[int], list[int]]:
        """
        Constructs the payload from the transaction IDs.

        :return: (host account ids, transaction ids that could not be resolved);
            both empty when no transaction ids were given
        """
        if not self.transaction_ids:
            return [], []
        return await self._lookup_accounts_by_transaction_id(
            logger, hq_credentials, self.transaction_ids
        )

    async def _lookup_accounts_by_transaction_id(
        self,
        logger: Q2LoggerType,
        hq_credentials: HqCredentials,
        transaction_ids: list[int],
        chunk_size: int = 25,
    ) -> tuple[list[int], list[int]]:
        """
        Looks up host account IDs for each transaction ID, in concurrent batches.

        :param: transaction_ids: Transaction ids to resolve
        :param: chunk_size: Batch argument handed to ``divide_list`` (same value
            the original batching used); only one batch is in flight at a time
        :return: (unique host account ids, unresolved transaction ids in input order)
        """
        logger.info("Looking up transaction ids for host account ids...")
        gt_obj = GeneratedTransactions(logger, hq_credentials=hq_credentials)

        haids = set()
        invalid_ids = []
        for chunk in divide_list(list(transaction_ids), chunk_size):
            # Coroutines are only scheduled for the current chunk. Creating every
            # task up front (as before) starts them all immediately, which defeats
            # the batching and orphans later tasks if a lookup raises.
            results = await asyncio.gather(
                *(gt_obj.get_by_id(transaction_id) for transaction_id in chunk)
            )
            for transaction_id, rows in zip(chunk, results):
                if not rows:
                    invalid_ids.append(int(transaction_id))
                    continue
                haids.add(int(rows[0].HostAccountID.text))

        logger.debug(
            f"Host account ids: {haids} Invalid transaction ids: {invalid_ids}"
        )
        return list(haids), invalid_ids

    def _create_base_template(self, payload: list[int]) -> BaseTemplate:
        """
        Creates the BaseTemplate from the AccountRefreshTemplate data.
        """
        return BaseTemplate(
            routing=NotificationRouting(
                notification_type=self.notification_routing_type,
                userId=str(self.user_id) if self.user_id else "",
            ),
            data=NotificationData(
                message=self.message,
                event=TemplatedEvents.ACCOUNT_REFRESH.value,
                payload=json.dumps(payload),
            ),
        )


@dataclass
class AlertTemplate(BaseTemplateFactory):
    """
    A template designed for sending alerts with control over whether the alert is
    closable, its duration, and the alert type (info, error, etc.).

    Field Details:
    :param: user_id: Recipient user id; required for MSG_TYPE_GENERAL routing
    :param: message: Alert text
    :param: alert_type: AlertType member or its string value (e.g. "info")
    :param: closable: Whether the alert can be closed; None leaves it unset
    :param: duration_in_ms: How long the alert is shown; falsy leaves it unset
    """

    user_id: Optional[int] = None
    message: Optional[str] = None
    alert_type: Optional[AlertType | str] = None
    closable: Optional[bool] = None
    duration_in_ms: Optional[int] = None

    async def convert_to_base_template(
        self, logger: Q2LoggerType, hq_credentials: HqCredentials
    ) -> Result:
        """
        Converts the AlertTemplate into a BaseTemplate, validating required fields
        and preparing the payload for sending.

        :return: Result with ``data["template"]`` on success, or an error_message
            when a required user id is missing
        """
        # Validate user ID for MSG_TYPE_GENERAL routing
        if self._requires_user_id() and not self.user_id:
            msg = "User ID must be provided for notification type: MSG_TYPE_GENERAL"
            return Result(error_message=msg)

        # Prepare payload
        payload = await self._construct_payload(logger, hq_credentials)

        # Create and return base template
        data = {"template": self._create_base_template(payload)}
        return Result(success=True, data=data)

    async def _construct_payload(
        self, logger: Q2LoggerType, hq_credentials: HqCredentials
    ) -> dict:
        """
        Prepare the payload by populating the possible fields.

        ``logger`` and ``hq_credentials`` are unused; they keep the signature
        parallel to AccountRefreshTemplate.
        """
        payload = {}
        if self.alert_type:
            # AlertType(...) accepts a member or its raw string value; an unknown
            # string raises ValueError.
            payload["alertType"] = AlertType(self.alert_type).value
        # ``is not None`` so that an explicit ``closable=False`` is sent instead of
        # being silently dropped by a truthiness check.
        if self.closable is not None:
            payload["closable"] = self.closable
        if self.duration_in_ms:
            payload["durationInMs"] = self.duration_in_ms
        return payload

    def _create_base_template(self, payload: dict) -> BaseTemplate:
        """
        Creates the BaseTemplate from the AlertTemplate data.
        """
        return BaseTemplate(
            routing=NotificationRouting(
                notification_type=self.notification_routing_type,
                userId=str(self.user_id) if self.user_id else "",
            ),
            data=NotificationData(
                message=self.message,
                event=TemplatedEvents.ALERT.value,
                # NOTE(review): payload is pre-serialized here, so BaseTemplate
                # (which only adds traceId to dict payloads) will not attach a
                # traceId — confirm this is intended.
                payload=json.dumps(payload),
            ),
        )


class Pinion:
    """
    Pinion (Beta) is a system of services designed to deliver lightweight, distributed
    notifications to end-user applications via their browsers. This class provides the
    ability to send notifications using the SDK, which interacts with the underlying
    services to push messages to end users in real time.

    Note: This integration requires a minimum UUX version of 4.6.1.2.
    """

    def __init__(
        self,
        logger: Q2LoggerType = None,
        hq_credentials: HqCredentials = None,
        environment: EnvLevel = settings.ENV_LEVEL_DEPLOY_ENV,
        cache_obj: Q2CacheClient = None,
    ):
        """
        :param: logger: Logger to use; defaults to the current request's logger, or
            the module logger when there is no request context
        :param: hq_credentials: HQ credentials; defaults to the current request's
        :param: environment: Environment used if the envstack must be looked up
        :param: cache_obj: Cache client; one is created via cache.get_cache if omitted
        :raises ValueError: If no HQ credentials are given or found in context
        """
        self.logger = logger or self._get_logger_from_context()
        self.hq_credentials = hq_credentials or self._get_hq_credentials_from_context()
        if not self.hq_credentials:
            raise ValueError(
                "Hq credentials are required but was not provided and couldn't be retrieved from context"
            )
        self.envstack = self.hq_credentials.env_stack
        self.cache = cache_obj or cache.get_cache(logger=self.logger)
        self.environment = environment

    async def send(
        self, template: BaseTemplate | AccountRefreshTemplate | AlertTemplate
    ) -> Result:
        """
        Sends a notification message based on the provided template.

        :param: template: The template to use for sending the message. Currently,
            supports an account refresh, alert and base template. Base template can
            be used for generic functionality, however, other templates have also
            been defined to facilitate interactions with Pinion.

        Returns:
            Result: Result object returned with success or failure, including invalid
            transaction IDs for account refresh operation if applicable
        """
        # Fetch envstack if necessary
        if not self.envstack:
            envstack_result = await self._fetch_envstack()
            if not envstack_result.success:
                return envstack_result

        # Convert to base template if not already
        template_result = await self._convert_template(template)
        if not template_result.success:
            return template_result

        # Serialize template to protobuf message
        message_result = self._serialize_to_protobuf(template_result.data["template"])
        if not message_result.success:
            return message_result

        # Send the message and capture the result
        result = await self._send_message(message_result.data["message"])

        # Carry the partial-success note (unresolved transaction IDs) into the result
        if isinstance(template, AccountRefreshTemplate) and result.success:
            result.error_message = template_result.error_message

        return result

    def _get_logger_from_context(self):
        """Returns the active request's logger, or the module logger if none."""
        context = contexts.get_current_request(raise_if_none=False)
        return context.request_handler.logger if context else LOGGER

    def _get_hq_credentials_from_context(self):
        """Returns the active request's HQ credentials, or None if there is no request."""
        current_request = contexts.get_current_request(raise_if_none=False)
        return (
            current_request.request_handler.hq_credentials if current_request else None
        )

    async def _fetch_envstack(self) -> Result:
        """
        Fetches the envstack for message bus communication and caches it on self.
        """
        self.logger.warning("Envstack not found in hq credentials...")
        result = await get_envstack_by_cust_key(
            self.logger, self.hq_credentials, self.environment, self.cache
        )
        if not result.success:
            # Not inside an ``except`` block: logger.exception would add a bogus
            # "NoneType: None" traceback, so log a plain error.
            self.logger.error(result.error_message)
        else:
            self.envstack = result.data["envstack"]
        return result

    async def _convert_template(
        self, template: BaseTemplate | AccountRefreshTemplate | AlertTemplate
    ) -> Result:
        """
        Converts a given template to a base template if necessary.
        """
        if isinstance(template, BaseTemplate):
            return Result(success=True, data={"template": template})
        return await template.to_base_template(self.logger, self.hq_credentials)

    def _serialize_to_protobuf(self, template: BaseTemplate) -> Result:
        """
        Serializes base template into a protobuf message.
        """
        try:
            notification_envelope_dict = asdict(template)
            message = ParseDict(notification_envelope_dict, NotificationEnvelopeProto())
            return Result(success=True, data={"message": message})
        except (TypeError, ParseError) as err:
            err_msg = f"Failed to convert notification details to q2msg protobuf message {err}"
            self.logger.exception(err_msg)
            return Result(error_message=err_msg)

    async def _send_message(self, message: NotificationEnvelopeProto) -> Result:
        """
        Sends the protobuf message onto the message bus.

        :return: Result with ``data["msg_id"]`` on success, or a failed Result when
            the push fails or the response does not contain a usable ``msgId``
        """
        try:
            self.logger.info("Pushing pinion message onto bus...")
            response = await message_bus.push(
                self.logger,
                message=message,
                envstack=self.envstack,
                message_type=REQUEST_PROTO_MSG_TYPE,
                topic=TOPIC,
                krayt_url=KRAYT_URL,
            )
            msg_id = response.json()["msgId"]
            self.logger.info(
                f"Message successfully added to bus. MsgId:{msg_id} UserID:{message.routing.userId}"
            )
            return Result(success=True, data={"msg_id": msg_id})
        # KeyError/TypeError/ValueError cover a response body without "msgId",
        # a non-object JSON body, or malformed JSON; previously these escaped
        # instead of producing a failed Result.
        except (
            requests.exceptions.RequestException,
            AttributeError,
            KeyError,
            TypeError,
            ValueError,
        ):
            err_msg = "Failed to send message"
            self.logger.exception(err_msg)
            return Result(error_message=err_msg)