"""Source code for ``q2_sdk.models.push_notification``: push notification payload models and the PushNotification sender."""

import asyncio
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Union

import requests
from google.protobuf.json_format import ParseDict, ParseError
from google.protobuf.timestamp_pb2 import Timestamp
from q2msg.notifications.push_pb2 import PushRequestV2 as PushRequest

from q2_sdk.core import cache, message_bus
from q2_sdk.core.cache import Q2CacheClient
from q2_sdk.core.configuration import settings
from q2_sdk.core.default_settings import EnvLevel
from q2_sdk.core.q2_logging.logger import Q2LoggerType
from q2_sdk.hq.db.notification import CreateNotificationParams
from q2_sdk.hq.db.notification import Notification as NotificationDBObject
from q2_sdk.hq.db.notification_status import NotificationStatus
from q2_sdk.hq.db.notification_type import NotificationType
from q2_sdk.hq.db.push_notification_targets import PushNotificationTargets
from q2_sdk.hq.hq_api.q2_api import GetVersions
from q2_sdk.models.holocron import (
    get_envstack_by_cust_key,
)
from q2_sdk.tools.utils import BaseDataMapper

# Protobuf type name passed as ``message_type`` when pushing onto the message bus.
REQUEST_PROTO_MSG_TYPE = "q2msg.notifications.push.PushRequestV2"

# Message-bus endpoint and topic for the current deployment environment,
# resolved once at import time by indexing the per-environment settings maps
# (so an unmapped DEPLOY_ENV fails on import rather than at send time).
KRAYT_URL, TOPIC = (
    settings.FIREBASE_KRAYT_BY_ENV[settings.DEPLOY_ENV],
    settings.FIREBASE_TOPIC_BY_ENV[settings.DEPLOY_ENV],
)


class AndroidMessagePriority(Enum):
    """Delivery priority for Android messages (``AndroidConfig.priority``)."""

    NORMAL = 0
    HIGH = 1
class Visibility(Enum):
    """Visibility level for ``AndroidNotification.visibility``."""

    VISIBILITY_UNSPECIFIED = 0
    PRIVATE = 1
    PUBLIC = 2
    SECRET = 3
class NotificationPriority(Enum):
    """
    Priority for ``AndroidNotification.notificationPriority``.

    Values increase from ``PRIORITY_MIN`` (1) to ``PRIORITY_MAX`` (5);
    ``PRIORITY_UNSPECIFIED`` is 0.
    """

    PRIORITY_UNSPECIFIED = 0
    PRIORITY_MIN = 1
    PRIORITY_LOW = 2
    PRIORITY_DEFAULT = 3
    PRIORITY_HIGH = 4
    PRIORITY_MAX = 5
@dataclass
class Notification(BaseDataMapper):
    """
    Basic notification content for ``Message.notification``.

    Sits alongside the platform-specific ``android``/``webpush``/``apns``
    configs on :class:`Message`. Every field is optional (``None``).
    """

    title: str = None
    body: str = None
    image: str = None
@dataclass
class Color(BaseDataMapper):
    """RGBA color components; used by :class:`LightSettings`. All optional."""

    red: float = None
    green: float = None
    blue: float = None
    alpha: float = None
@dataclass
class LightSettings(BaseDataMapper):
    """
    Notification light settings for ``AndroidNotification.lightSettings``.

    The on/off durations are expressed in milliseconds, per the field names.
    """

    color: Color = None
    lightOnDurationMillis: int = None
    lightOffDurationMillis: int = None
@dataclass
class AndroidNotification(BaseDataMapper):
    """
    Android-specific notification options (``AndroidConfig.notification``).

    Scalar fields default to ``None``; list fields default to a fresh empty
    list per instance.
    """

    # Content
    title: str = None
    body: str = None
    icon: str = None
    color: str = None
    sound: str = None
    tag: str = None
    clickAction: str = None
    # Localization keys and their format arguments
    bodyLocKey: str = None
    bodyLocArgs: list[str] = field(default_factory=list)
    titleLocKey: str = None
    titleLocArgs: list[str] = field(default_factory=list)
    # Presentation / behavior
    channelId: str = None
    ticker: str = None
    sticky: bool = None
    eventTime: str = None
    localOnly: bool = None
    notificationPriority: NotificationPriority = None
    defaultSound: bool = None
    defaultVibrateTimings: bool = None
    defaultLightSettings: bool = None
    vibrateTimings: list[str] = field(default_factory=list)
    visibility: Visibility = None
    notificationCount: int = None
    lightSettings: LightSettings = None
    image: str = None
@dataclass
class AndroidFCMOptions(BaseDataMapper):
    """FCM options for Android messages (``AndroidConfig.fcmOptions``)."""

    analyticsLabel: str = None
@dataclass
class AndroidConfig(BaseDataMapper):
    """
    Android-specific message configuration (``Message.android``).

    ``data`` defaults to a fresh empty dict; every other field is ``None``.
    """

    collapseKey: str = None
    priority: AndroidMessagePriority = None
    ttl: str = None
    restrictedPackageName: str = None
    data: dict[str, str] = field(default_factory=dict)
    notification: AndroidNotification = None
    fcmOptions: AndroidFCMOptions = None
    directBootOkay: bool = None
@dataclass
class WebpushNotificationAction(BaseDataMapper):
    """One action entry in ``WebpushNotification.actions``."""

    action: str = None
    title: str = None
    icon: str = None
@dataclass
class WebpushNotification(BaseDataMapper):
    """
    Web push notification options (``WebpushConfig.notification``).

    Collection fields (``actions``, ``vibrate``, ``customData``) default to
    fresh empty containers; everything else defaults to ``None``.
    """

    actions: list[WebpushNotificationAction] = field(default_factory=list)
    title: str = None
    body: str = None
    icon: str = None
    badge: str = None
    direction: str = None
    data: str = None
    image: str = None
    language: str = None
    renotify: bool = None
    requireInteraction: bool = None
    silent: bool = None
    tag: str = None
    timestampMillis: int = None
    vibrate: list[int] = field(default_factory=list)
    customData: dict[str, str] = field(default_factory=dict)
@dataclass
class WebpushFCMOptions(BaseDataMapper):
    """FCM options for web push messages (``WebpushConfig.fcmOptions``)."""

    link: str = None
@dataclass
class WebpushConfig(BaseDataMapper):
    """
    Web push message configuration (``Message.webpush``).

    ``headers`` and ``data`` default to fresh empty dicts.
    """

    headers: dict[str, str] = field(default_factory=dict)
    data: dict[str, str] = field(default_factory=dict)
    notification: WebpushNotification = None
    fcmOptions: WebpushFCMOptions = None
@dataclass
class ApsAlert(BaseDataMapper):
    """
    Structured APNs alert content (``Aps.alert``).

    The ``*Loc``/``*LocKey`` fields pair with their ``*LocArgs`` lists, which
    default to fresh empty lists.
    """

    title: str = None
    subTitle: str = None
    body: str = None
    locKey: str = None
    locArgs: list[str] = field(default_factory=list)
    titleLoc: str = None
    titleLocArgs: list[str] = field(default_factory=list)
    subTitleLoc: str = None
    subTitleLocArgs: list[str] = field(default_factory=list)
    actionLocKey: str = None
    launchImage: str = None
@dataclass
class CriticalSound(BaseDataMapper):
    """APNs critical-sound settings (``Aps.criticalSound``). All optional."""

    critical: bool = None
    name: str = None
    volume: float = None
@dataclass
class Aps(BaseDataMapper):
    """
    APNs ``aps`` dictionary content (``APNSPayload.aps``).

    Provides both a plain ``alertString`` and a structured ``alert``;
    ``customData`` defaults to a fresh empty dict.
    """

    alertString: str = None
    alert: ApsAlert = None
    badge: int = None
    sound: str = None
    criticalSound: CriticalSound = None
    contentAvailable: bool = None
    mutableContent: bool = None
    category: str = None
    threadId: str = None
    customData: dict[str, str] = field(default_factory=dict)
@dataclass
class APNSPayload(BaseDataMapper):
    """APNs payload (``ApnsConfig.payload``); ``customData`` defaults to ``{}``."""

    aps: Aps = None
    customData: dict[str, str] = field(default_factory=dict)
    authenticationCode: str = None
@dataclass
class APNSFCMOptions(BaseDataMapper):
    """FCM options for APNs messages (``ApnsConfig.fcmOptions``)."""

    analyticsLabel: str = None
    imageUrl: str = None
@dataclass
class ApnsConfig(BaseDataMapper):
    """APNs message configuration (``Message.apns``); ``headers`` defaults to ``{}``."""

    headers: dict[str, str] = field(default_factory=dict)
    payload: APNSPayload = None
    fcmOptions: APNSFCMOptions = None
@dataclass
class FcmOptions(BaseDataMapper):
    """Platform-independent FCM options (``Message.fcmOptions``)."""

    analyticsLabel: str = None
@dataclass
class Message(BaseDataMapper):
    """
    The message to deliver (``SendRequest.message``).

    There is deliberately no target field here: ``PushNotification`` injects
    the per-device ``token`` into the serialized dict just before protobuf
    conversion. ``data`` defaults to a fresh empty dict.
    """

    name: str = None
    data: dict[str, str] = field(default_factory=dict)
    notification: Notification = None
    android: AndroidConfig = None
    webpush: WebpushConfig = None
    apns: ApnsConfig = None
    fcmOptions: FcmOptions = None
@dataclass
class SendRequest(BaseDataMapper):
    """
    Top-level send request supplied by callers of ``PushNotification``.

    Callers may pass this object or an equivalent dict, which is converted
    with ``SendRequest.from_dict``.
    """

    validateOnly: bool = None
    message: Message = None
@dataclass
class NotificationRequest(BaseDataMapper):
    """
    Notification-record metadata carried alongside the FCM request.

    ``PushNotification`` fills ``id`` with the generated notification row ID,
    ``create_date`` with a protobuf-Timestamp JSON string and ``country_id``
    with the configured country code. ``retry_count``/``id_only`` are passed
    through untouched here (semantics defined by the consumer — TODO confirm).
    """

    # Required fields (no defaults) must precede the defaulted ones.
    id: int
    create_date: str
    country_id: str = "USA"
    retry_count: int = 0
    id_only: bool = False
@dataclass
class PushRequestV2(BaseDataMapper):
    """
    Dataclass mirror of the q2msg ``PushRequestV2`` protobuf message.

    Converted with ``dataclasses.asdict`` + ``ParseDict`` before being pushed
    onto the message bus.
    """

    notification: NotificationRequest
    request: SendRequest
@dataclass
class SendResponseV2:
    """
    Outcome of one step of the push-notification flow.

    Defaults describe a failure with no message; ``data`` defaults to a fresh
    empty dict per instance.
    """

    success: bool = False
    error_message: str = ""
    data: dict = field(default_factory=dict)
@dataclass
class PushNotificationParams:
    """
    Targeting and routing details used to construct a ``PushNotification``.

    :param subject: Stored as ``NotificationSubject`` on each notification row.
    :param target_address: No longer used; an empty string is sufficient.
    :param target_user_id: User whose registered push targets are looked up;
        also stored as the notification row's ``UserID``.
    :param device_nickname: When set, only the most recently created device
        with this nickname is targeted.
    :param country_code: Sent as ``country_id`` on the outgoing request.
    :param envstack: Message-bus envstack. When empty, ``PushNotification``
        falls back to the HQ credentials, then cache/HQ, then holocron.
    :param environment: Environment level (or its string name) for the
        holocron envstack lookup; empty means ``settings.ENV_LEVEL_DEPLOY_ENV``.
    """

    subject: str
    target_address: str
    # The attribute docstring below used to precede ``target_address``, which
    # made Sphinx attach it to ``subject``; it now follows the field it describes.
    """Note: the target address value is no longer in use. Passing in an empty string should suffice."""
    target_user_id: int
    device_nickname: str = ""
    country_code: str = "USA"
    envstack: str = ""
    environment: Union[EnvLevel, str] = ""
[docs] class PushNotification: """ Provides ability to send rich (enhanced) push notifications to end users through the Firebase Cloud Messaging (FCM) Service. Push notifications are supported on HQ version 4.5.0.6065 and above. For additional information on the payload, please visit the `Firebase API Documentation <https://firebase.google.com/docs/reference/fcm/rest/v1/projects.messages>`_. Note, sending rich data such as images, sounds, deep linking, would require MSDK module integration. Please visit the `MSDK Documentation <https://docs.q2developer.com/native/index.html>`_ for more details. """ MIN_HQ_BUILD = "6065" def __init__( self, logger: Q2LoggerType, hq_credentials, params: PushNotificationParams, cache: Q2CacheClient = None, ): self.logger = logger self.hq_credentials = hq_credentials self.country_code = params.country_code self.subject = params.subject self.target_address = params.target_address self.target_user_id = params.target_user_id self.device_nickname = params.device_nickname self.cache = cache or self._get_cache() self.envstack = params.envstack or self.hq_credentials.env_stack self.environment = params.environment or settings.ENV_LEVEL_DEPLOY_ENV if isinstance(self.environment, str): self.environment = EnvLevel.from_string(self.environment)
[docs] async def send_by_device_nickname( self, send_request: Union[SendRequest, dict] ) -> SendResponseV2: """ Send notification by device nickname. Converts provided request into the q2msg PushRequestV2 protobuf message. Encrypts the message and sends the request. """ if not self.device_nickname: return SendResponseV2(error_message="Device nickname not provided") response = await self._handle_send(send_request) return response[0]
[docs] async def send_by_id( self, send_request: Union[SendRequest, dict] ) -> list[SendResponseV2]: """ Send notification by user ID. Will send notification to all devices tied to the user. Converts provided request into the q2msg PushRequestV2 protobuf message. Encrypts the message and sends the request. """ if not self.target_user_id: return SendResponseV2(error_message="Target user ID not provided") return await self._handle_send(send_request)
[docs] async def send(self, send_request: Union[SendRequest, dict]) -> SendResponseV2: """ This function has been deprecated. Use `send_by_device_nickname` (or `send_by_id` depending on your needs) instead. """ self.logger.warning( "The push notification object's send method has been deprecated. Use `send_by_device_nickname` instead" ) return await self.send_by_device_nickname(send_request)
async def _handle_send( self, send_request: Union[SendRequest, dict] ) -> SendResponseV2: # get all the notification details such as the gcm tokens and notification type/status information notification_details = await self._get_notification_details() if not notification_details.success: return [notification_details] # populate PushRequestV2 dataclass with send_request input for protobuf message conversion if isinstance(send_request, dict): send_request = SendRequest.from_dict(send_request) # Send the request return await self._handle_notification_call( notification_details, send_request, ) async def _get_notification_details(self) -> SendResponseV2: # verify stack meets minimum required HQ version & fetch envstack from holocron if necessary required_details = await self._fetch_stack_details() if not required_details.success: return required_details # fetch device/gcm token either by device nickname, or if not provided # fetch all valid tokens for the user get_token_response = await self._get_gcm_tokens() if not get_token_response.success: self.logger.error(get_token_response.error_message) return get_token_response # get the notification type and status get_notification_details = await self._get_notification_type_and_status() if not get_notification_details.success: return get_notification_details get_notification_details.data |= get_token_response.data return get_notification_details async def _handle_notification_call( self, notification_details, send_request ) -> SendResponseV2: parallel_calls = [] for token in notification_details.data["token_info"]: task = asyncio.create_task( self._send_notification( notification_details.data["NotificationTypeID"], notification_details.data["NotificationStatusID"], token, send_request, ) ) parallel_calls.append(task) await asyncio.gather(*parallel_calls) responses = [] for task in parallel_calls: responses.append(task.result()) return responses async def _send_notification( self, notification_type_id, notification_status_id, 
token, send_request ) -> SendResponseV2: # generate a notification row for the NDS to update. One will be generated per token. timestamp_obj = Timestamp() current_date = self._get_current_date(timestamp_obj) notification_id = await self._generate_notification_id( current_date, notification_type_id, notification_status_id, token ) push_request_v2 = PushRequestV2( notification=NotificationRequest( id=notification_id, country_id=self.country_code, create_date=timestamp_obj.ToJsonString(), ), request=send_request, ) # Serialize dataclass to protobuf message message_result = self._serialize_to_protobuf(push_request_v2, token) if not message_result.success: return message_result # Encrypt and send request response = await self._send_message( message_result.data["message"], notification_id ) return response async def _fetch_stack_details(self) -> SendResponseV2: hq_version_response = await self._get_hq_info() if not hq_version_response.success: return hq_version_response if not self.envstack: envstack_result = await self._get_envstack() if not envstack_result.success: self.logger.error(envstack_result.error_message) return envstack_result else: self.logger.info("Using provided envstack") return SendResponseV2(success=True) def _serialize_to_protobuf( self, request: PushRequestV2, gcm_token: str ) -> SendResponseV2: """ Serializes push request object into a protobuf message. 
""" # Convert push request object dictonary to a protobuf message try: push_request_dict = asdict(request) push_request_dict["request"]["message"]["token"] = gcm_token message = ParseDict(push_request_dict, PushRequest()) self.logger.info( f"Successfully converted push request to a protobuf message: {message}" ) return SendResponseV2(success=True, data={"message": message}) except (TypeError, ParseError) as err: msg = f"Failed to convert push request to q2msg protobuf message {err}" self.logger.exception(msg) return SendResponseV2(error_message=msg) async def _send_message( self, message: PushRequest, notification_id ) -> SendResponseV2: """ Sends the protobuf message onto the message bus. """ try: self.logger.info("Pushing rich push message onto bus...") response = await message_bus.push( self.logger, message=message, envstack=self.envstack, message_type=REQUEST_PROTO_MSG_TYPE, topic=TOPIC, krayt_url=KRAYT_URL, ) msg_id = response.json()["msgId"] self.logger.info( f"Push notification sent. MsgId: {msg_id} NotificationId: {notification_id}" ) return SendResponseV2( success=True, data={"NotificationID": notification_id} ) except (requests.exceptions.RequestException, AttributeError): err_msg = "Failed to send notification" self.logger.exception(err_msg) return SendResponseV2(error_message=err_msg) async def _get_hq_info(self): """ Attempts to retrieve cached HQ version/build and envstack if available, Otherwise, an call is made to retrieve this information and cached by HQ URL. """ response = SendResponseV2() version_key = f"{self.hq_credentials.customer_key}_HQVersion" envstack_key = f"{self.hq_credentials.customer_key}_EnvStack" cached_hq = await self.cache.get_many_async([version_key, envstack_key]) or {} version = cached_hq.get(version_key) envstack = cached_hq.get(envstack_key) if not version: self.logger.info("HQ version not found in cache. 
Calling HQ...") params_obj = GetVersions.ParamsObj( self.logger, hq_credentials=self.hq_credentials ) result = await GetVersions.execute(params_obj) try: for node in result.result_node.Data.FiList.Detail: if node.Name.text == "HqVersion": version = node.Value.text hq_build = version.split(".")[-1] if hq_build < self.MIN_HQ_BUILD: msg = f"HQ Version {version} does not meet minimum required version 4.5.0.6065" self.logger.error(msg) response.error_message = msg return response if node.Name.text == "EnvStack": self.logger.info("Envstack found in HQ.") envstack = node.Value.text if not version: msg = "HQ version not found" self.logger.error(msg) response.error_message = msg return response except (AttributeError, ValueError) as err: msg = f"HQ GetVersions Failure: {err}" self.logger.exception(msg) response.error_message = msg return response await self.cache.set_many_async( { version_key: version, envstack_key: envstack, }, expire=60 * 60 * 12, ) else: self.logger.info(f"HQ version found in the cache: {version}") response.success = True self.envstack = envstack if not self.envstack else self.envstack return response async def _get_gcm_tokens(self) -> SendResponseV2: pnt_obj = PushNotificationTargets(self.logger, self.hq_credentials) user_devices = await pnt_obj.get(self.target_user_id) if self.device_nickname: self.logger.info("Search by device nickname detected...") targets = [ x for x in user_devices if x.Nickname.text == self.device_nickname ] if not targets: return SendResponseV2( error_message="Could not find device token for provided user and device nickname" ) sorted_targets = sorted( targets, key=lambda target: self._parse_datetime(target.Created.text), reverse=True, ) tokens = [sorted_targets[0].GcmToken.text] else: self.logger.info("Search by target user id detected...") tokens = [x.GcmToken.text for x in user_devices if x.UseForAlerts] if not tokens: return SendResponseV2( error_message="Could not find device token for provided target user" ) 
self.logger.info(f"Gcm tokens found: {tokens}") return SendResponseV2(success=True, data={"token_info": tokens}) def _parse_datetime(self, datetime_str): return datetime.fromisoformat(datetime_str) async def _generate_notification_id( self, current_date: str, notification_type_id, notification_status_id, token ) -> int: """ Inserts a row into the Q2_Notifications table based on provided parameters. """ params = CreateNotificationParams( NotificationSubject=self.subject, NotificationTypeID=notification_type_id, NotificationStatusID=notification_status_id, UserID=str(self.target_user_id), CreateDate=current_date, TargetAddress=token, ) notification_obj = NotificationDBObject( self.logger, hq_credentials=self.hq_credentials ) notification_result = await notification_obj.add(params) self.logger.info(f"Generated NotificationID {notification_result.id}") return notification_result.id async def _get_notification_type_and_status(self) -> SendResponseV2: """ Fetches the notification type and status information """ parallel_calls = [ asyncio.create_task( NotificationType(self.logger, self.hq_credentials).get_by_name("Push"), name="type", ), asyncio.create_task( NotificationStatus(self.logger, self.hq_credentials).get_by_name("New"), name="status", ), ] self.logger.info( f"Executing {len(parallel_calls)} calls to fetch notification type and status information" ) await asyncio.gather(*parallel_calls) for task in parallel_calls: call_type = task.get_name() self.logger.debug(f"fetching result for notification {call_type}") result = task.result() match call_type: case "type": try: notification_type_id = result[0].NotificationTypeID.text except AttributeError: msg = "Failed to retrieve push notification type id" self.logger.exception(msg) return SendResponseV2(error_message=msg) case "status": try: notification_status_id = result[0].NotificationStatusID.text except AttributeError: msg = "Failed to retrieve push notification status id" self.logger.exception(msg) return 
SendResponseV2(error_message=msg) return SendResponseV2( success=True, data={ "NotificationTypeID": notification_type_id, "NotificationStatusID": notification_status_id, }, ) def _get_current_date(self, timestamp_obj: Timestamp) -> str: timestamp_obj.GetCurrentTime() timestamp = timestamp_obj.seconds + timestamp_obj.nanos / 1e9 current_datetime = datetime.fromtimestamp(timestamp, tz=timezone.utc) current_date = datetime.strftime(current_datetime, "%Y-%m-%d %H:%M:%S") return current_date def _get_cache(self) -> Q2CacheClient: return cache.get_cache(logger=self.logger) async def _get_envstack(self): """ Fetches the envstack for message bus communication. """ self.logger.warning("Envstack not found in hq credentials...") result = await get_envstack_by_cust_key( self.logger, self.hq_credentials, self.environment, self.cache ) if not result.success: self.logger.exception(result.error_message) else: self.envstack = result.data["envstack"] return result