from dataclasses import asdict, dataclass
from enum import Enum
from q2_sdk.core.exceptions import DatabaseDataError
from q2_sdk.hq.api_helpers import build_policy_data_from_hq_response
from q2_sdk.hq.db.sso_user_logon import SsoUserLogon
from q2_sdk.hq.db.user_logon import UserLogon
from q2_sdk.hq.hq_api.q2_api import (
GetPolicyDataByPolicyIdentifier,
SetPolicyDataByPolicyIdentifier,
)
from q2_sdk.hq.models.policy_data import (
Account,
Data,
Features,
GeneratedTransactionRights,
Subsidiaries,
)
from q2_sdk.hq.models.policy_data import (
PolicyData as PolicyDataModel,
)
from .db_object import DbObject
class Entity(Enum):
    """Kinds of entity that can own policy data.

    Each value is the single-letter code used as the prefix of an HQ policy
    identifier: ``Entity.User`` with id 42 becomes ``"U-42"``.
    """

    User = "U"
    Customer = "C"
    Group = "G"
    UserRole = "R"
    Company = "O"
    # PolicyData.get / PolicyData.set resolve this to the owning User before
    # building the identifier sent to HQ.
    SSOIdentifier = "S"
[docs]
@dataclass
class PolicyDataValues:
entity: Entity
entity_id: int
policy_data: list[
GeneratedTransactionRights | Subsidiaries | Account | Features
] = None
def __post_init__(self):
if not self.policy_data:
self.policy_data = []
[docs]
def construct_set_policy_shape(self):
sorted_policies = {}
for policy in self.policy_data:
if not policy.PolicyIdentifier:
policy.PolicyIdentifier = f"{self.entity.value}-{self.entity_id}"
key = None
match policy:
case GeneratedTransactionRights():
key = "Q2_PolicyGeneratedTransactionRights"
case Subsidiaries():
key = "Q2_PolicySubsidiaries"
case Account():
key = "Q2_PolicyAccounts"
case Features():
key = "Q2_PolicyFeatures"
case Data():
key = "Q2_PolicyData"
case _:
raise ValueError(
"Unsupported policy type passed in. Policy must be a GeneratedTransactionRights,"
" Q2_PolicySubsidiaries, Q2_PolicyAccounts, Q2_PolicyFeatures"
)
policy_details = asdict(
policy, dict_factory=lambda policy: {k: v for (k, v) in policy}
)
if key not in sorted_policies.keys():
column_fields = list(policy_details.keys())
sorted_policies[key] = {"columns": column_fields, "rows": []}
policy_row_values = {"U": list(policy_details.values())}
sorted_policies[key]["rows"].append(policy_row_values)
return sorted_policies
class PolicyData(DbObject):
    """Read and write HQ policy data for a single entity.

    The policy identifier sent to HQ is ``"<Entity code>-<entity id>"``,
    e.g. ``"U-42"``. SSO identifiers are first resolved to the user that
    owns them.
    """

    async def get(
        self, entity: Entity, entity_id: int, include_parent_policy_data: bool = False
    ) -> PolicyDataModel:
        """
        A helper function for getting policy data for an entity. Example function calls::

            await self.db.policy_data.get(Entity.User, self.online_user.user_id)
            await self.db.policy_data.get(Entity.Customer, self.online_user.customer_id)
            await self.db.policy_data.get(Entity.Group, self.online_user.group_id)
            await self.db.policy_data.get(Entity.UserRole, user[0].UserRoleID, include_parent_policy_data=True)
            await self.db.policy_data.get(Entity.SSOIdentifier, sso_user_logon[0].SSOUserLogonID, include_parent_policy_data=True)

        :param entity: One of User, Customer, Group, UserRole, Company or
            SSOIdentifier of the Entity enum. An SSOIdentifier is resolved to
            the user that owns it, and that user's policy data is returned.
        :param entity_id: The id of the entity. If user, provide the user id. If
            group, provide the group id. For SSOIdentifier, provide the SSO user
            logon id.
        :param include_parent_policy_data: If true, then all policies in the user
            hierarchy will also be provided
        :return: PolicyData object; an empty one if the HQ call fails
        :raises DatabaseDataError: If an SSO identifier cannot be resolved to a user
        """
        if entity is Entity.SSOIdentifier:
            entity, entity_id = await self.get_user_details_by_sso_id(entity, entity_id)
        entity_id = int(entity_id)
        parameters = GetPolicyDataByPolicyIdentifier.ParamsObj(
            self.logger,
            f"{entity.value}-{entity_id}",
            include_parent_policy_data,
            hq_credentials=self.hq_credentials,
        )
        hq_response = await GetPolicyDataByPolicyIdentifier.execute(parameters)
        if not hq_response.success:
            # HQ failures are logged and surfaced as an empty model, not raised.
            self.logger.error("HQ call failed: %s", hq_response.error_message)
            return PolicyDataModel()
        return build_policy_data_from_hq_response(hq_response.result_node)

    async def set(self, policy_data: PolicyDataValues) -> str:
        """
        A helper function for setting policy data for an entity. Example function calls::

            await self.db.policy_data.set(
                PolicyDataValues(Entity.User, self.online_user.user_id, policies)
            )
            await self.db.policy_data.set(
                PolicyDataValues(Entity.Group, self.online_user.group_id, policies)
            )
            await self.db.policy_data.set(
                PolicyDataValues(Entity.SSOIdentifier, sso_user_logon[0].SSOUserLogonID, user_policies)
            )

        :param policy_data: The entity (User, Customer, Group, UserRole, Company
            or SSOIdentifier), its id, and the policies to write. Policies must
            be GeneratedTransactionRights, Subsidiaries, Account, Features or
            Data instances. For an SSOIdentifier the entity is resolved to the
            owning user, so provide user-level policy details; in that case
            ``policy_data.entity`` and ``policy_data.entity_id`` are updated in
            place to the resolved user.
        :return: ``"Success"``, or ``"Failure: <hq error message>"`` if the HQ call fails
        :raises DatabaseDataError: If an SSO identifier cannot be resolved to a user
        :raises ValueError: If a policy is of an unsupported type
        """
        if policy_data.entity is Entity.SSOIdentifier:
            resolved_entity, user_id = await self.get_user_details_by_sso_id(
                policy_data.entity, policy_data.entity_id
            )
            # Updated in place so construct_set_policy_shape stamps the
            # resolved user's identifier onto policies lacking one.
            policy_data.entity = resolved_entity
            policy_data.entity_id = int(user_id)
        policy_data_shape = policy_data.construct_set_policy_shape()
        params_obj = SetPolicyDataByPolicyIdentifier.ParamsObj(
            logger=self.logger,
            policy_identifier=f"{policy_data.entity.value}-{policy_data.entity_id}",
            policy_data=policy_data_shape,
            hq_credentials=self.hq_credentials,
        )
        hq_response = await SetPolicyDataByPolicyIdentifier.execute(
            params_obj, use_json=True
        )
        if not hq_response.success:
            self.logger.error("HQ call failed: %s", hq_response.error_message)
            return f"Failure: {hq_response.error_message}"
        return "Success"

    async def get_user_details_by_sso_id(
        self, entity: Entity, entity_id: int
    ) -> tuple[Entity, int]:
        """Resolve an SSO user logon id to the user that owns it.

        :param entity: Kept for backward compatibility; the returned entity is
            always ``Entity.User`` regardless of this value.
        :param entity_id: The SSO user logon id to look up.
        :return: ``(Entity.User, user_id)``
        :raises DatabaseDataError: If the SSO identifier or its user logon
            cannot be found.
        """
        sso_logons = await SsoUserLogon(
            self.logger, hq_credentials=self.hq_credentials
        ).get(entity_id)
        if not sso_logons:
            raise DatabaseDataError("SSO identifier not found")
        user_logon_id = int(sso_logons[0].UserLogonID.text)
        self.logger.debug("SSO Identifier user logon id found")
        user_logons = await UserLogon(
            self.logger, hq_credentials=self.hq_credentials
        ).get_login_by_logon_id(user_logonid=user_logon_id)
        if not user_logons:
            raise DatabaseDataError("No user found under provided SSO identifier")
        user_id = int(user_logons[0].UserID.text)
        self.logger.debug(
            "User ID %s found from sso identifier: %s", user_id, entity_id
        )
        # Access the member via the class, not via another member.
        return Entity.User, user_id