import asyncio
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Union
import requests
from google.protobuf.json_format import ParseDict, ParseError
from google.protobuf.timestamp_pb2 import Timestamp
from q2msg.notifications.push_pb2 import PushRequestV2 as PushRequest
from q2_sdk.core import cache, message_bus
from q2_sdk.core.cache import Q2CacheClient
from q2_sdk.core.configuration import settings
from q2_sdk.core.default_settings import EnvLevel
from q2_sdk.core.q2_logging.logger import Q2LoggerType
from q2_sdk.hq.db.notification import CreateNotificationParams
from q2_sdk.hq.db.notification import Notification as NotificationDBObject
from q2_sdk.hq.db.notification_status import NotificationStatus
from q2_sdk.hq.db.notification_type import NotificationType
from q2_sdk.hq.db.push_notification_targets import PushNotificationTargets
from q2_sdk.hq.hq_api.q2_api import GetVersions
from q2_sdk.models.holocron import (
get_envstack_by_cust_key,
)
from q2_sdk.tools.utils import BaseDataMapper
# Message type name handed to message_bus.push as ``message_type`` when a
# push request is placed on the bus.
REQUEST_PROTO_MSG_TYPE = "q2msg.notifications.push.PushRequestV2"
# Both values are looked up once, at import time, keyed by settings.DEPLOY_ENV,
# and are handed to message_bus.push as ``krayt_url`` and ``topic``.
KRAYT_URL = settings.FIREBASE_KRAYT_BY_ENV[settings.DEPLOY_ENV]
TOPIC = settings.FIREBASE_TOPIC_BY_ENV[settings.DEPLOY_ENV]
[docs]
class AndroidMessagePriority(Enum):
    """Priority choices for the ``priority`` field of :class:`AndroidConfig`."""

    # NOTE(review): names/values presumably mirror the FCM AndroidMessagePriority
    # enum -- confirm against the Firebase reference.
    NORMAL = 0
    HIGH = 1
[docs]
class Visibility(Enum):
    """Choices for the ``visibility`` field of :class:`AndroidNotification`."""

    # NOTE(review): names/values presumably mirror the FCM Visibility enum --
    # confirm against the Firebase reference.
    VISIBILITY_UNSPECIFIED = 0
    PRIVATE = 1
    PUBLIC = 2
    SECRET = 3
[docs]
class NotificationPriority(Enum):
    """Choices for the ``notificationPriority`` field of :class:`AndroidNotification`."""

    # NOTE(review): names/values presumably mirror the FCM NotificationPriority
    # enum -- confirm against the Firebase reference.
    PRIORITY_UNSPECIFIED = 0
    PRIORITY_MIN = 1
    PRIORITY_LOW = 2
    PRIORITY_DEFAULT = 3
    PRIORITY_HIGH = 4
    PRIORITY_MAX = 5
[docs]
@dataclass
class Notification(BaseDataMapper):
    """Basic notification content (title, body, image); used as ``Message.notification``."""

    # Every field is optional and defaults to None.
    title: str = None
    body: str = None
    image: str = None
[docs]
@dataclass
class Color(BaseDataMapper):
    """Red/green/blue/alpha color components; used as ``LightSettings.color``."""

    # NOTE(review): the expected value range of each component is not visible
    # here -- confirm against the Firebase reference before documenting it.
    red: float = None
    green: float = None
    blue: float = None
    alpha: float = None
[docs]
@dataclass
class LightSettings(BaseDataMapper):
    """Light color plus on/off durations; used as ``AndroidNotification.lightSettings``."""

    color: Color = None
    # Durations are in milliseconds, going by the field names.
    lightOnDurationMillis: int = None
    lightOffDurationMillis: int = None
[docs]
@dataclass
class AndroidNotification(BaseDataMapper):
    """Android-specific notification fields; used as ``AndroidConfig.notification``."""

    # All scalar fields are optional (default None); list fields default to a
    # fresh empty list per instance.
    title: str = None
    body: str = None
    icon: str = None
    color: str = None
    sound: str = None
    tag: str = None
    clickAction: str = None
    bodyLocKey: str = None
    bodyLocArgs: list[str] = field(default_factory=list)
    titleLocKey: str = None
    titleLocArgs: list[str] = field(default_factory=list)
    channelId: str = None
    ticker: str = None
    sticky: bool = None
    # NOTE(review): the expected string format of eventTime is not visible in
    # this module -- confirm against the Firebase reference.
    eventTime: str = None
    localOnly: bool = None
    notificationPriority: NotificationPriority = None
    defaultSound: bool = None
    defaultVibrateTimings: bool = None
    defaultLightSettings: bool = None
    vibrateTimings: list[str] = field(default_factory=list)
    visibility: Visibility = None
    notificationCount: int = None
    lightSettings: LightSettings = None
    image: str = None
[docs]
@dataclass
class AndroidFCMOptions(BaseDataMapper):
    """Android-level FCM options; used as ``AndroidConfig.fcmOptions``."""

    analyticsLabel: str = None
[docs]
@dataclass
class AndroidConfig(BaseDataMapper):
    """Android-specific delivery options; used as ``Message.android``."""

    collapseKey: str = None
    priority: AndroidMessagePriority = None
    # NOTE(review): the expected string format of ttl is not visible in this
    # module -- confirm against the Firebase reference.
    ttl: str = None
    restrictedPackageName: str = None
    # Defaults to a fresh empty dict per instance.
    data: dict[str, str] = field(default_factory=dict)
    notification: AndroidNotification = None
    fcmOptions: AndroidFCMOptions = None
    directBootOkay: bool = None
[docs]
@dataclass
class WebpushNotificationAction(BaseDataMapper):
    """One entry of ``WebpushNotification.actions`` (action, title, icon)."""

    action: str = None
    title: str = None
    icon: str = None
[docs]
@dataclass
class WebpushNotification(BaseDataMapper):
    """Web push notification fields; used as ``WebpushConfig.notification``."""

    # Container fields (actions, vibrate, customData) default to fresh empty
    # containers per instance; everything else defaults to None.
    actions: list[WebpushNotificationAction] = field(default_factory=list)
    title: str = None
    body: str = None
    icon: str = None
    badge: str = None
    direction: str = None
    data: str = None
    image: str = None
    language: str = None
    renotify: bool = None
    requireInteraction: bool = None
    silent: bool = None
    tag: str = None
    timestampMillis: int = None
    vibrate: list[int] = field(default_factory=list)
    customData: dict[str, str] = field(default_factory=dict)
[docs]
@dataclass
class WebpushFCMOptions(BaseDataMapper):
    """Web push FCM options; used as ``WebpushConfig.fcmOptions``."""

    link: str = None
[docs]
@dataclass
class WebpushConfig(BaseDataMapper):
    """Web push delivery options; used as ``Message.webpush``."""

    # Both dicts default to a fresh empty dict per instance.
    headers: dict[str, str] = field(default_factory=dict)
    data: dict[str, str] = field(default_factory=dict)
    notification: WebpushNotification = None
    fcmOptions: WebpushFCMOptions = None
[docs]
@dataclass
class ApsAlert(BaseDataMapper):
    """Structured alert content; used as ``Aps.alert``."""

    # The *Args list fields default to a fresh empty list per instance.
    title: str = None
    subTitle: str = None
    body: str = None
    locKey: str = None
    locArgs: list[str] = field(default_factory=list)
    titleLoc: str = None
    titleLocArgs: list[str] = field(default_factory=list)
    subTitleLoc: str = None
    subTitleLocArgs: list[str] = field(default_factory=list)
    actionLocKey: str = None
    launchImage: str = None
[docs]
@dataclass
class CriticalSound(BaseDataMapper):
    """Critical sound settings (critical flag, name, volume); used as ``Aps.criticalSound``."""

    critical: bool = None
    name: str = None
    # NOTE(review): the valid volume range is not visible in this module --
    # confirm against the Firebase/APNs reference.
    volume: float = None
[docs]
@dataclass
class Aps(BaseDataMapper):
    """Contents of the ``aps`` entry of :class:`APNSPayload`."""

    # Two alert fields exist: a plain string (alertString) and a structured
    # ApsAlert (alert).
    # NOTE(review): how the two interact when both are set is not visible here.
    alertString: str = None
    alert: ApsAlert = None
    badge: int = None
    # Likewise there is a plain sound name and a structured CriticalSound.
    sound: str = None
    criticalSound: CriticalSound = None
    contentAvailable: bool = None
    mutableContent: bool = None
    category: str = None
    threadId: str = None
    # Defaults to a fresh empty dict per instance.
    customData: dict[str, str] = field(default_factory=dict)
[docs]
@dataclass
class APNSPayload(BaseDataMapper):
    """APNs payload (aps section plus custom data); used as ``ApnsConfig.payload``."""

    aps: Aps = None
    # Defaults to a fresh empty dict per instance.
    customData: dict[str, str] = field(default_factory=dict)
    authenticationCode: str = None
[docs]
@dataclass
class APNSFCMOptions(BaseDataMapper):
    """APNs-level FCM options; used as ``ApnsConfig.fcmOptions``."""

    analyticsLabel: str = None
    imageUrl: str = None
[docs]
@dataclass
class ApnsConfig(BaseDataMapper):
    """Apple Push Notification service options; used as ``Message.apns``."""

    # Defaults to a fresh empty dict per instance.
    headers: dict[str, str] = field(default_factory=dict)
    payload: APNSPayload = None
    fcmOptions: APNSFCMOptions = None
[docs]
@dataclass
class FcmOptions(BaseDataMapper):
    """Platform-independent FCM options; used as ``Message.fcmOptions``."""

    analyticsLabel: str = None
[docs]
@dataclass
class Message(BaseDataMapper):
    """The message to deliver; used as ``SendRequest.message``.

    There is no token field here: ``PushNotification._serialize_to_protobuf``
    adds a ``token`` key to the dict form of this message before converting
    it to protobuf.
    """

    name: str = None
    # Defaults to a fresh empty dict per instance.
    data: dict[str, str] = field(default_factory=dict)
    notification: Notification = None
    # Per-platform overrides.
    android: AndroidConfig = None
    webpush: WebpushConfig = None
    apns: ApnsConfig = None
    fcmOptions: FcmOptions = None
[docs]
@dataclass
class SendRequest(BaseDataMapper):
    """Request supplied by callers of ``PushNotification.send_by_*``.

    Callers may pass a plain dict instead; ``PushNotification._handle_send``
    converts it with ``SendRequest.from_dict``.
    """

    validateOnly: bool = None
    message: Message = None
[docs]
@dataclass
class NotificationRequest(BaseDataMapper):
    """Notification metadata; used as ``PushRequestV2.notification``.

    ``PushNotification._send_notification`` fills ``id`` with the generated
    notification row id, ``create_date`` with ``Timestamp.ToJsonString()`` and
    ``country_id`` with the configured country code.
    """

    id: int
    create_date: str
    country_id: str = "USA"
    # NOTE(review): retry_count and id_only are never set in this module; their
    # meaning is defined by the consumer of the message -- confirm before relying on them.
    retry_count: int = 0
    id_only: bool = False
[docs]
@dataclass
class PushRequestV2(BaseDataMapper):
    """Dataclass form of a push request.

    ``PushNotification._serialize_to_protobuf`` turns an instance into a dict
    with ``dataclasses.asdict`` and parses that dict into the q2msg protobuf
    message (imported in this module as ``PushRequest``).
    """

    notification: NotificationRequest
    request: SendRequest
[docs]
@dataclass
class SendResponseV2:
    """Result object used throughout :class:`PushNotification`.

    A default-constructed instance represents a failure (``success=False``).
    """

    success: bool = False
    # Human-readable reason; populated on failure paths.
    error_message: str = ""
    # Extra payload that depends on the producing step (for example
    # ``{"NotificationID": ...}`` after a successful send). Defaults to a fresh
    # empty dict per instance.
    data: dict = field(default_factory=dict)
[docs]
@dataclass
class PushNotificationParams:
    """Caller-supplied settings consumed by :class:`PushNotification`."""

    # Stored as the NotificationSubject of the notification row created for each token.
    subject: str
    # The attribute docstring below belongs to ``target_address``. Documentation
    # tools attach such a string to the attribute that immediately precedes it,
    # so it must come after the field rather than before it.
    target_address: str
    """Note: the target address value is no longer in use. Passing in an empty string should suffice."""
    # User whose registered devices are searched for GCM tokens.
    target_user_id: int
    # When set, only the newest device with this nickname is targeted.
    device_nickname: str = ""
    # Sent as ``country_id`` of the notification request.
    country_code: str = "USA"
    # Optional overrides; when left empty PushNotification falls back to the
    # HQ credentials' env stack and the configured deploy environment level.
    envstack: str = ""
    environment: Union[EnvLevel, str] = ""
[docs]
class PushNotification:
    """
    Provides ability to send rich (enhanced) push notifications to end users through the Firebase Cloud Messaging (FCM) Service.
    Push notifications are supported on HQ version 4.5.0.6065 and above.
    For additional information on the payload, please visit the `Firebase API Documentation <https://firebase.google.com/docs/reference/fcm/rest/v1/projects.messages>`_.
    Note, sending rich data such as images, sounds, deep linking, would require MSDK module integration. Please visit the
    `MSDK Documentation <https://docs.q2developer.com/native/index.html>`_ for more details.
    """

    # Build number (last dotted component) of the minimum supported HQ version
    # 4.5.0.6065; _get_hq_info compares the build reported by HQ against it.
    MIN_HQ_BUILD = "6065"
def __init__(
    self,
    logger: Q2LoggerType,
    hq_credentials,
    params: PushNotificationParams,
    cache: Q2CacheClient = None,
):
    """
    Store the collaborators and copy the caller's parameters onto the instance.

    :param logger: Logger used for all messages emitted by this object
    :param hq_credentials: HQ credentials used for HQ/DB calls; its
        ``env_stack`` is the fallback when ``params.envstack`` is empty
    :param params: Subject, target user/device and optional overrides
    :param cache: Cache client; when falsy one is obtained from ``_get_cache``
    """
    # The logger must be set first: _get_cache() reads self.logger.
    self.logger = logger
    self.hq_credentials = hq_credentials

    # Values copied straight from the caller's params.
    self.country_code = params.country_code
    self.subject = params.subject
    self.target_address = params.target_address
    self.target_user_id = params.target_user_id
    self.device_nickname = params.device_nickname

    self.cache = cache if cache else self._get_cache()

    if params.envstack:
        self.envstack = params.envstack
    else:
        self.envstack = self.hq_credentials.env_stack

    # The environment may be given as an EnvLevel or as its string name;
    # strings are normalized to EnvLevel.
    env_level = params.environment if params.environment else settings.ENV_LEVEL_DEPLOY_ENV
    if isinstance(env_level, str):
        env_level = EnvLevel.from_string(env_level)
    self.environment = env_level
[docs]
async def send_by_device_nickname(
    self, send_request: Union[SendRequest, dict]
) -> SendResponseV2:
    """
    Send a notification to the device identified by the configured nickname.

    The request is converted into the q2msg PushRequestV2 protobuf message
    and pushed onto the message bus by the internal send pipeline.

    :param send_request: The request as a ``SendRequest`` or an equivalent dict
    :return: The single response for the targeted device, or a failed
        response when no device nickname was configured
    """
    if self.device_nickname:
        # The pipeline always answers with a list; a nickname lookup yields
        # exactly one token, so only the first entry is relevant.
        responses = await self._handle_send(send_request)
        return responses[0]
    return SendResponseV2(error_message="Device nickname not provided")
[docs]
async def send_by_id(
    self, send_request: Union[SendRequest, dict]
) -> list[SendResponseV2]:
    """
    Send notification by user ID.

    Token lookup is delegated to ``_get_gcm_tokens``: when no device nickname
    was configured, every device of the user that is flagged for alerts is
    targeted, and one response is produced per device token.
    Converts provided request into the q2msg PushRequestV2 protobuf message
    and pushes it onto the message bus.

    :param send_request: The request as a ``SendRequest`` or an equivalent dict
    :return: A list of responses. Failures that happen before any message is
        sent (including a missing target user ID) are reported as a
        single-element list holding the failed response.
    """
    if not self.target_user_id:
        # Wrapped in a list so this failure has the same shape as every other
        # return path; callers iterating the declared list type would
        # otherwise fail on a bare SendResponseV2.
        return [SendResponseV2(error_message="Target user ID not provided")]
    return await self._handle_send(send_request)
[docs]
async def send(self, send_request: Union[SendRequest, dict]) -> SendResponseV2:
    """
    Deprecated alias of `send_by_device_nickname`.

    Logs a deprecation warning and forwards the request unchanged. New code
    should call `send_by_device_nickname` (or `send_by_id`) directly.

    :param send_request: The request as a ``SendRequest`` or an equivalent dict
    :return: Whatever `send_by_device_nickname` returns
    """
    deprecation_notice = "The push notification object's send method has been deprecated. Use `send_by_device_nickname` instead"
    self.logger.warning(deprecation_notice)
    response = await self.send_by_device_nickname(send_request)
    return response
async def _handle_send(
    self, send_request: Union[SendRequest, dict]
) -> list[SendResponseV2]:
    """
    Run the full send pipeline and return one response per device token.

    :param send_request: The request as a ``SendRequest`` or an equivalent dict
    :return: Always a list. When gathering the notification details fails,
        the list holds that single failed response; otherwise it holds one
        response per token that was targeted.
    """
    # Collect everything needed before sending: GCM tokens plus the
    # notification type/status ids.
    notification_details = await self._get_notification_details()
    if not notification_details.success:
        return [notification_details]

    # Callers may pass a plain dict; normalize it to the SendRequest dataclass
    # so it can be serialized to protobuf later on.
    if isinstance(send_request, dict):
        send_request = SendRequest.from_dict(send_request)

    return await self._handle_notification_call(
        notification_details,
        send_request,
    )
async def _get_notification_details(self) -> SendResponseV2:
    """
    Gather everything needed to send: stack check, GCM tokens, type/status ids.

    The steps run in order and the first failing step's response is returned
    as-is. On success the returned response's ``data`` holds the notification
    type/status ids merged with the token lookup data.
    """
    # Step 1: HQ version check, plus envstack lookup when it is still unknown.
    stack_check = await self._fetch_stack_details()
    if not stack_check.success:
        return stack_check

    # Step 2: device tokens, by nickname or for all of the user's alert devices.
    token_lookup = await self._get_gcm_tokens()
    if not token_lookup.success:
        self.logger.error(token_lookup.error_message)
        return token_lookup

    # Step 3: notification type and status ids; token data is merged in on success.
    details = await self._get_notification_type_and_status()
    if details.success:
        details.data.update(token_lookup.data)
    return details
async def _handle_notification_call(
    self, notification_details, send_request
) -> list[SendResponseV2]:
    """
    Send the request to every token concurrently.

    :param notification_details: Successful response whose ``data`` holds
        ``token_info`` (list of tokens), ``NotificationTypeID`` and
        ``NotificationStatusID``
    :param send_request: The ``SendRequest`` to deliver to each token
    :return: One response per token, in the same order as ``token_info``
    """
    details = notification_details.data
    sends = [
        self._send_notification(
            details["NotificationTypeID"],
            details["NotificationStatusID"],
            token,
            send_request,
        )
        for token in details["token_info"]
    ]
    # gather() runs the sends concurrently and returns their results in the
    # order the awaitables were passed, i.e. token order.
    return list(await asyncio.gather(*sends))
async def _send_notification(
    self, notification_type_id, notification_status_id, token, send_request
) -> SendResponseV2:
    """
    Create a notification row for one token, then serialize and send the request.

    :param notification_type_id: Id stored as the row's NotificationTypeID
    :param notification_status_id: Id stored as the row's NotificationStatusID
    :param token: Device token this message is addressed to
    :param send_request: The ``SendRequest`` to deliver
    :return: The serialization failure, or the result of pushing the message
    """
    created_at = Timestamp()
    # _get_current_date sets created_at to the current time as a side effect,
    # so the formatted date and the JSON string below describe the same instant.
    formatted_date = self._get_current_date(created_at)

    # One notification row is generated per token.
    notification_id = await self._generate_notification_id(
        formatted_date, notification_type_id, notification_status_id, token
    )

    notification_request = NotificationRequest(
        id=notification_id,
        country_id=self.country_code,
        create_date=created_at.ToJsonString(),
    )
    push_request = PushRequestV2(
        notification=notification_request, request=send_request
    )

    serialized = self._serialize_to_protobuf(push_request, token)
    if serialized.success:
        return await self._send_message(
            serialized.data["message"], notification_id
        )
    return serialized
async def _fetch_stack_details(self) -> SendResponseV2:
    """
    Check the HQ version and make sure an envstack is known.

    :return: The failed HQ-info or envstack-lookup response, otherwise a
        fresh successful response
    """
    hq_info = await self._get_hq_info()
    if not hq_info.success:
        return hq_info

    # _get_hq_info may already have filled in self.envstack.
    if self.envstack:
        self.logger.info("Using provided envstack")
        return SendResponseV2(success=True)

    envstack_lookup = await self._get_envstack()
    if envstack_lookup.success:
        return SendResponseV2(success=True)

    self.logger.error(envstack_lookup.error_message)
    return envstack_lookup
def _serialize_to_protobuf(
    self, request: PushRequestV2, gcm_token: str
) -> SendResponseV2:
    """
    Convert the push request dataclass into the q2msg protobuf message.

    :param request: The populated push request dataclass
    :param gcm_token: Device token, injected as ``request.message.token``
    :return: On success a response whose ``data["message"]`` is the protobuf
        message; on a ``TypeError``/``ParseError`` a failed response
    """
    try:
        # dataclass -> plain dict -> protobuf. The token is not a dataclass
        # field, so it is added to the dict form of the message.
        payload = asdict(request)
        payload["request"]["message"]["token"] = gcm_token
        message = ParseDict(payload, PushRequest())
        self.logger.info(
            f"Successfully converted push request to a protobuf message: {message}"
        )
        result = SendResponseV2(success=True, data={"message": message})
        return result
    except (TypeError, ParseError) as err:
        msg = f"Failed to convert push request to q2msg protobuf message {err}"
        self.logger.exception(msg)
        return SendResponseV2(error_message=msg)
async def _send_message(
    self, message: PushRequest, notification_id
) -> SendResponseV2:
    """
    Push the protobuf message onto the message bus.

    :param message: The protobuf push request
    :param notification_id: Id of the notification row, echoed in the log
        line and in the successful response's ``data["NotificationID"]``
    :return: Successful response, or a failed one when the push raises a
        ``requests`` exception or an ``AttributeError``
    """
    try:
        self.logger.info("Pushing rich push message onto bus...")
        bus_response = await message_bus.push(
            self.logger,
            message=message,
            envstack=self.envstack,
            message_type=REQUEST_PROTO_MSG_TYPE,
            topic=TOPIC,
            krayt_url=KRAYT_URL,
        )
        msg_id = bus_response.json()["msgId"]
        self.logger.info(
            f"Push notification sent. MsgId: {msg_id} NotificationId: {notification_id}"
        )
    except (requests.exceptions.RequestException, AttributeError):
        err_msg = "Failed to send notification"
        self.logger.exception(err_msg)
        return SendResponseV2(error_message=err_msg)

    return SendResponseV2(success=True, data={"NotificationID": notification_id})
async def _get_hq_info(self):
    """
    Verify the HQ build meets the minimum and pick up the envstack if reported.

    The HQ version and envstack are read from the cache (keyed by customer
    key). On a cache miss HQ's GetVersions is called, the build number is
    validated against ``MIN_HQ_BUILD`` and both values are cached for 12 hours.

    :return: ``SendResponseV2`` with ``success=True`` when a valid version is
        known; otherwise a failed response carrying the reason. As a side
        effect ``self.envstack`` is filled in when it was empty.
    """
    response = SendResponseV2()
    version_key = f"{self.hq_credentials.customer_key}_HQVersion"
    envstack_key = f"{self.hq_credentials.customer_key}_EnvStack"
    cached_hq = await self.cache.get_many_async([version_key, envstack_key]) or {}
    version = cached_hq.get(version_key)
    envstack = cached_hq.get(envstack_key)

    if version:
        # Only versions that passed the minimum-build check are ever cached,
        # so a cache hit needs no re-validation.
        self.logger.info(f"HQ version found in the cache: {version}")
    else:
        self.logger.info("HQ version not found in cache. Calling HQ...")
        params_obj = GetVersions.ParamsObj(
            self.logger, hq_credentials=self.hq_credentials
        )
        result = await GetVersions.execute(params_obj)
        try:
            for node in result.result_node.Data.FiList.Detail:
                if node.Name.text == "HqVersion":
                    version = node.Value.text
                    hq_build = version.split(".")[-1]
                    # Compare numerically. Comparing the strings would order
                    # builds lexicographically ("10000" < "6065", "999" > "6065")
                    # and misjudge any build that is not exactly four digits.
                    # A non-numeric build raises ValueError, handled below.
                    if int(hq_build) < int(self.MIN_HQ_BUILD):
                        msg = f"HQ Version {version} does not meet minimum required version 4.5.0.6065"
                        self.logger.error(msg)
                        response.error_message = msg
                        return response
                if node.Name.text == "EnvStack":
                    self.logger.info("Envstack found in HQ.")
                    envstack = node.Value.text
            if not version:
                msg = "HQ version not found"
                self.logger.error(msg)
                response.error_message = msg
                return response
        except (AttributeError, ValueError) as err:
            msg = f"HQ GetVersions Failure: {err}"
            self.logger.exception(msg)
            response.error_message = msg
            return response
        await self.cache.set_many_async(
            {
                version_key: version,
                envstack_key: envstack,
            },
            expire=60 * 60 * 12,  # 12 hours, in seconds
        )

    response.success = True
    # An envstack that is already set (from params or credentials) wins over
    # the one reported by HQ or found in the cache.
    self.envstack = envstack if not self.envstack else self.envstack
    return response
async def _get_gcm_tokens(self) -> SendResponseV2:
    """
    Look up the device token(s) to send to.

    With a device nickname configured, the most recently created device with
    that nickname is selected. Otherwise every device of the user that is
    flagged ``UseForAlerts`` is selected.

    :return: Successful response with ``data["token_info"]`` (list of
        tokens), or a failed response when nothing matched
    """
    target_lookup = PushNotificationTargets(self.logger, self.hq_credentials)
    devices = await target_lookup.get(self.target_user_id)

    if self.device_nickname:
        self.logger.info("Search by device nickname detected...")
        matches = [
            device
            for device in devices
            if device.Nickname.text == self.device_nickname
        ]
        if not matches:
            return SendResponseV2(
                error_message="Could not find device token for provided user and device nickname"
            )
        # Several devices can share a nickname; take the newest one. On equal
        # creation times the first match in the original order is kept.
        newest = max(
            matches,
            key=lambda device: self._parse_datetime(device.Created.text),
        )
        tokens = [newest.GcmToken.text]
    else:
        self.logger.info("Search by target user id detected...")
        tokens = [
            device.GcmToken.text for device in devices if device.UseForAlerts
        ]

    if not tokens:
        return SendResponseV2(
            error_message="Could not find device token for provided target user"
        )

    self.logger.info(f"Gcm tokens found: {tokens}")
    return SendResponseV2(success=True, data={"token_info": tokens})
def _parse_datetime(self, datetime_str):
return datetime.fromisoformat(datetime_str)
async def _generate_notification_id(
    self, current_date: str, notification_type_id, notification_status_id, token
) -> int:
    """
    Add a notification row for the given token and return its id.

    :param current_date: Creation date string stored as the row's CreateDate
    :param notification_type_id: Stored as NotificationTypeID
    :param notification_status_id: Stored as NotificationStatusID
    :param token: Device token, stored as the row's TargetAddress
    :return: The ``id`` reported by the notification DB object's ``add`` call
    """
    row_params = CreateNotificationParams(
        NotificationSubject=self.subject,
        NotificationTypeID=notification_type_id,
        NotificationStatusID=notification_status_id,
        UserID=str(self.target_user_id),
        CreateDate=current_date,
        TargetAddress=token,
    )
    db_notification = NotificationDBObject(
        self.logger, hq_credentials=self.hq_credentials
    )
    created_row = await db_notification.add(row_params)
    new_id = created_row.id
    self.logger.info(f"Generated NotificationID {new_id}")
    return new_id
async def _get_notification_type_and_status(self) -> SendResponseV2:
    """
    Fetch the ids of the "Push" notification type and the "New" status.

    Both lookups run concurrently.

    :return: Successful response whose ``data`` holds ``NotificationTypeID``
        and ``NotificationStatusID``; a failed response when either lookup
        returns nothing usable.
    """
    lookups = [
        NotificationType(self.logger, self.hq_credentials).get_by_name("Push"),
        NotificationStatus(self.logger, self.hq_credentials).get_by_name("New"),
    ]
    self.logger.info(
        f"Executing {len(lookups)} calls to fetch notification type and status information"
    )
    # gather() returns results in the order the awaitables were passed.
    type_rows, status_rows = await asyncio.gather(*lookups)

    # Besides a missing attribute, a lookup can come back empty (IndexError)
    # or as None (TypeError); all of these mean the id is unavailable and must
    # be reported as a failed response rather than escape as an exception.
    lookup_errors = (AttributeError, IndexError, TypeError)

    self.logger.debug("fetching result for notification type")
    try:
        notification_type_id = type_rows[0].NotificationTypeID.text
    except lookup_errors:
        msg = "Failed to retrieve push notification type id"
        self.logger.exception(msg)
        return SendResponseV2(error_message=msg)

    self.logger.debug("fetching result for notification status")
    try:
        notification_status_id = status_rows[0].NotificationStatusID.text
    except lookup_errors:
        msg = "Failed to retrieve push notification status id"
        self.logger.exception(msg)
        return SendResponseV2(error_message=msg)

    return SendResponseV2(
        success=True,
        data={
            "NotificationTypeID": notification_type_id,
            "NotificationStatusID": notification_status_id,
        },
    )
def _get_current_date(self, timestamp_obj: Timestamp) -> str:
    """
    Stamp ``timestamp_obj`` with the current time and return it formatted.

    :param timestamp_obj: Protobuf timestamp; mutated in place by
        ``GetCurrentTime()`` so the caller can reuse the same instant
    :return: The instant in UTC, formatted as ``YYYY-MM-DD HH:MM:SS``
    """
    timestamp_obj.GetCurrentTime()
    # Combine whole seconds and nanoseconds into fractional epoch seconds.
    epoch_seconds = timestamp_obj.seconds + timestamp_obj.nanos / 1e9
    utc_now = datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)
    return utc_now.strftime("%Y-%m-%d %H:%M:%S")
def _get_cache(self) -> Q2CacheClient:
    """Return a cache client obtained from ``cache.get_cache`` using this object's logger."""
    cache_client = cache.get_cache(logger=self.logger)
    return cache_client
async def _get_envstack(self):
    """
    Fetches the envstack for message bus communication.

    Used when no envstack is known yet. The lookup is delegated to
    ``get_envstack_by_cust_key``; on success ``self.envstack`` is updated.

    :return: The lookup result object (``success``, ``error_message``, ``data``)
    """
    self.logger.warning("Envstack not found in hq credentials...")
    result = await get_envstack_by_cust_key(
        self.logger, self.hq_credentials, self.environment, self.cache
    )
    if not result.success:
        # No exception is being handled here, so logger.exception() would only
        # append a meaningless "NoneType: None" traceback; log a plain error.
        self.logger.error(result.error_message)
    else:
        self.envstack = result.data["envstack"]
    return result