# Source code for q2_sdk.core.http_handlers.hq_handler

import asyncio
import base64
import os
from inspect import signature, Parameter
from collections import Counter
from functools import cached_property
import json
import traceback
from typing import List, Optional
import uuid
from xml.sax import saxutils

from lxml import etree, objectify

from q2_sdk.ardent.file_upload import File
from q2_sdk.core import exceptions
from q2_sdk.core.cache import Q2CacheClient, StorageLevel, CacheConfigError
from q2_sdk.core.http_handlers.base_handler import Q2BaseRequestHandler
from q2_sdk.hq.db.user import User
from q2_sdk.hq.models.sso_response import Q2SSOResponse
from q2_sdk.core.request_handler.templating import LocalAssetLoader
from q2_sdk.core.configuration import settings
from q2_sdk.hq import http
from q2_sdk.hq.db.wedge_address import WedgeAddress
from q2_sdk.hq.models.db_config.db_config_list import DbConfigList
from q2_sdk.hq.models.form_info import FormInfo
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.hq.models.hq_commands import HqCommands
from q2_sdk.hq.models.online_session import OnlineSession
from q2_sdk.ui import forms
from q2_sdk.ui.script_appenders import StyleFixAppender
from q2_sdk.core.configuration import get_settings


class Q2HqRequestHandler(Q2BaseRequestHandler):
    """
    RequestHandler meant to be used for requests incoming from HQ.

    - Parses apart XML posts into inspectable objects using lxml
    - Puts form fields into a dictionary
    - If WEDGE_ADDRESS_CONFIGS is set, ensures the entries are set in the
      database on installation of the extension. WEDGE_ADDRESS_CONFIGS is an
      instance of DbConfigList filled with DbConfig objects.
    - hq_response_attrs is a dict that will be appended as attributes to the
      Data node HQ is expecting
    """

    WEDGE_ADDRESS_CONFIGS = DbConfigList([])
    OPTIONAL_WEDGE_ADDRESS_CONFIGS = {}
    DEFAULT_ROUTE = "default"
    FRIENDLY_NAME = None
    PROPERTY_LONG_NAME = None
    USER_PROPERTY_DATA_ELEMENT = None

    def __init__(self, application, request, **kwargs):
        super().__init__(application, request, **kwargs)
        # Routing keys for which _handle_soap_request skips _build_models_from_hq.
        # "default/get_upload_url" is the key actually registered in ``router``;
        # the underscored spelling is kept so anything relying on it still works.
        self._unparsed_routes = [
            "default/_get_upload_url",
            "default/get_upload_url",
        ]
        # Lazily-populated caches for the properties of the same name.
        self._form_fields = None
        self._request_as_obj = None
        self._is_xml = None
        self._is_json = None
        self._db_config = None
        self.request_name = ""
        self.request_name_elem: Optional[etree.Element] = None
        self.hq_response_attrs = {}
        self.hq_commands: Optional[HqCommands] = None
        self.active_route = None
        self.return_as_html = True
        self.online_session: Optional[OnlineSession] = None
        self.form_info: Optional[FormInfo] = None
        self.config_cache_key = self._calculate_cache_config_key()
        self.is_authenticated = None
        # this will append UUX style overrides to all returned HTML if true
        self.use_style_fixes = False
        # this will convert script and link tags to jinja imports if true
        self.expand_local_sources = True
        self._session_cache = None

        # Backwards compatibility for dict wedge_address_configs. Remove in 3.0
        if isinstance(self.WEDGE_ADDRESS_CONFIGS, dict):
            self.WEDGE_ADDRESS_CONFIGS = DbConfigList.from_dict(
                self.WEDGE_ADDRESS_CONFIGS, self.OPTIONAL_WEDGE_ADDRESS_CONFIGS
            )
        self.default_cache_level = StorageLevel.Stack

    def _get_db_config_from_request(self):
        """
        Build a config dict from data carried on the incoming request.

        Sources, in order:
        - the ``Config`` text under ``request_name_elem`` (XML-unescaped, JSON);
          if absent, the JSON ``Wedgeconfig`` header instead
        - the ``ZoneConfig`` text under ``request_name_elem``, merged on top

        :return: dict of config values (possibly empty)
        """
        _db_config = {}
        config_node = None
        zoned_config_node = None

        # If a wedge config was sent in this request, parse it and update what we have.
        if self.request_name_elem is not None:
            config_node = self.request_name_elem.findtext(".//Config")
            zoned_config_node = self.request_name_elem.findtext(".//ZoneConfig")

        if config_node:
            config_node = saxutils.unescape(config_node)
            _db_config.update(json.loads(config_node))
        else:
            ardent_header_data = json.loads(
                self.request.headers.get("Wedgeconfig", "{}")
            )
            if ardent_header_data:
                _db_config.update(ardent_header_data)

        # Zone-level values are applied last so they win over the base config.
        if zoned_config_node:
            zoned_config_node = saxutils.unescape(zoned_config_node)
            _db_config.update(json.loads(zoned_config_node))
        return _db_config

    def _calculate_cache_config_key(self):
        """
        Compute the service-cache key under which wedge config is stored.

        The key is suffixed with the customer key taken from the ``customerKey``
        header or, failing that, from the ``customerKey`` entry of the
        ``Wedgejsondata`` header (decoded from JSON when it is a string).
        """
        config_cache_key = f"{self.extension_name}_wedge_config"
        if customer_key := self.request.headers.get("customerKey", None):
            config_cache_key = f"{self.extension_name}_wedge_config_{customer_key}"
        elif wedge_json_data := self.request.headers.get("Wedgejsondata", None):
            if isinstance(wedge_json_data, str):
                wedge_json_data = json.loads(wedge_json_data)
            if customer_key := wedge_json_data.get("customerKey", None):
                config_cache_key = (
                    f"{self.extension_name}_wedge_config_{customer_key}"
                )
        return config_cache_key

    async def get_user_dob(self):
        """
        Date of Birth is sometimes populated in the Q2 database, but rarely
        passed in the initial payload from HQ. Calling self.get_user_dob() will
        efficiently look it up if it exists. After that, self.online_user.dob
        will have it stored.

        :return: the DOB text ("" when the database row has none)
        """
        if not self.online_user.dob:
            user = (
                await User(self.logger, hq_credentials=self.hq_credentials).get(
                    user_id=self.online_user.user_id
                )
            )[0]
            dob = user.findtext("DOB", "")
            self.online_user.dob = dob
            if dob:
                self.online_user.as_demographic_info().dob = dob
        else:
            self.logger.debug("DOB is already populated, skipping DB call")
            dob = self.online_user.dob
        return dob

    async def get_wedge_address_configs(self, force=False) -> dict:
        """
        Gather wedge_address_configs either from incoming data or from database

        Has the side effect of seeding data for subsequent self.db_config calls.
        If the resulting config contains ``_overrides.customer_key``,
        self.hq_credentials is replaced with credentials for that key.

        :param force: If True, will query the database even if
            self.WEDGE_ADDRESS_CONFIGS is blank
        :return: the config dict (also stored on self._db_config)
        :raises exceptions.ConfigurationError: if required configs are missing
        """
        # Note: falsy (not just None) so an empty config seeded by the
        # db_config property is still re-resolved here.
        if not self._db_config:
            config_node = None
            _db_config = self._get_db_config_from_request()
            if force or (
                not _db_config and len(self.WEDGE_ADDRESS_CONFIGS.db_configs) != 0
            ):
                cache_config = self.service_cache.get(self.config_cache_key)
                if not cache_config:
                    wedge_object = WedgeAddress(self.logger, self.hq_credentials)
                    config_node = await wedge_object.get(
                        short_name=self.extension_name
                    )
                    if config_node:
                        _db_config.update(json.loads(config_node[0].Config.text))
                        self.service_cache.set(
                            self.config_cache_key, _db_config, expire=60
                        )
                else:
                    self.logger.debug(
                        "wedge address config found in cache with key: %s",
                        self.config_cache_key,
                    )
                    _db_config = cache_config

            override_dict = _db_config.get("_overrides", {})
            customer_key = override_dict.get("customer_key")
            if customer_key:
                override_hq = self._get_hq_from_key(customer_key)
                self.hq_credentials = override_hq

            self._validate_db_configs(_db_config)
            self._db_config = _db_config
            self.logger.debug("WedgeAddressConfigs: %s", self._db_config)
        return self._db_config

    def _validate_db_configs(self, db_config_list: dict):
        """
        Raises exception if required WEDGE_ADDRESS_CONFIGS are not set

        :param db_config_list: config mapping whose keys are checked against the
            ``required`` entries of ``DB_CONFIGS`` (if defined) or
            ``WEDGE_ADDRESS_CONFIGS``
        :raises exceptions.ConfigurationError: if any required key is missing
        """
        bad_config = False
        configs = getattr(self, "DB_CONFIGS", self.WEDGE_ADDRESS_CONFIGS)

        # Sanity-check the config to make sure all expected keys are now set.
        for key in [x.name for x in configs.required]:
            if key not in db_config_list:
                bad_config = True
                self.logger.error("Q2_WedgeAddress.Config missing key: %s.", key)

        # If anything went wrong, fail.
        if bad_config:
            self.logger.error(
                "DbConfigs Misconfigured. Try running `q2 update_installed -e %s`",
                self.extension_name,
            )
            raise exceptions.ConfigurationError("Q2_WedgeAddress.Config Misconfigured")

    @cached_property
    def session_cache(self) -> Q2CacheClient:
        """Same as self.cache but limited to online user session"""
        return self.get_cache(prefix=self.online_session.session_id)

    @cached_property
    def stack_cache(self) -> Q2CacheClient:
        """
        Q2CacheClient scoped to the current financial institution stack.

        :raises CacheConfigError: if hq_credentials has no customer_key
        """
        if not self.hq_credentials.customer_key:
            self.logger.error(
                "Customer Key is unset, but cache requested at stack level. Refusing to cache."
            )
            raise CacheConfigError
        return self.get_cache(prefix=self.hq_credentials.customer_key)

    @property
    def wedge_address_configs(self) -> dict:
        """Alias to self.db_config"""
        return self.db_config

    @property
    def db_config(self) -> dict:
        """
        Dictionary representation of data passed from the Q2_WedgeAddress
        table's Config column
        """
        if self._db_config is None:
            _db_config = self._get_db_config_from_request()
            self._validate_db_configs(_db_config)
            self._db_config = _db_config
        return self._db_config

    def clean_html_tags(self, custom_response):
        """
        Strip ``<html>``/``</html>`` byte tags from a response.

        Central (back-office) report responses and ``None`` are returned
        unchanged; ``None`` is passed through so the caller's empty-response
        warning can fire instead of failing on ``.replace``.
        """
        if custom_response is None or self._is_central_report():
            return custom_response
        return custom_response.replace(b"<html>", b"").replace(b"</html>", b"")

    def _is_central_report(self):
        """True when the HQ request is a ``ReportParameterObject`` request."""
        return self.request_name == "ReportParameterObject"

    @property
    def is_xml(self):
        """
        Whether the request body is XML (cached after first access).

        True if the Content-Type mentions "xml"; otherwise True only if the body
        parses as XML via request_as_obj.
        """
        if self._is_xml is None:
            content_type = self.request.headers.get("Content-Type", "Unknown")
            is_xml = True
            if "xml" not in content_type:
                try:
                    self.request_as_obj
                except etree.XMLSyntaxError:
                    is_xml = False
            self._is_xml = is_xml
        return self._is_xml

    @property
    def is_json(self):
        """Whether the Content-Type header mentions "json" (cached)."""
        if self._is_json is None:
            self._is_json = "json" in self.request.headers.get(
                "Content-Type", "Unknown"
            )
        return self._is_json

    async def post(self, *args, **kwargs):
        """
        Parses apart POST data based on content-type and format,
        creates a few handy objects, then farms out the work to the
        extension specific function defined in `q2_post`
        """
        message = self.request.body
        # Logging must never break the request: undecodable bytes are replaced.
        self.logger.debug(
            "POST RequestBody: %s",
            message.decode("utf8", errors="replace"),
            add_to_buffer=False,
        )
        if self.is_xml:
            response = await self._handle_soap_request(*args, **kwargs)
            self.set_header("Content-Type", "text/xml")  # XML In, XML Out
            self.set_cookie("Q2API-Compatibility", "universal")
        else:
            response = await self._handle_form_request(*args, **kwargs)

        # On a 500 the request handlers have already finished the response.
        if self.get_status() != 500:
            if settings.TEMPORARY_LOG_RESPONSE_ENABLE:
                self.logger.info("Response: %s", response)
            elif settings.LOG_RESPONSE_IN_DEBUG:
                self.logger.debug("Response: %s", response)
            self.write(response)
            self.finish()
            # Token handshakes carry no extension work to post-process.
            if self._is_xml and self._tag_name == "GetUnsignedToken":
                return
            # Documented contract: q2_on_finish is NOT called on a 500.
            await self.q2_on_finish()

    async def _handle_form_request(self, *args, **kwargs):
        """
        If the incoming request is not wrapped in a soap envelope (less common)
        don't try to massage the data first
        """
        await self.get_wedge_address_configs()
        try:
            response = await self.q2_post(*args, **kwargs)
            if isinstance(response, forms.Q2Form) or (
                self.is_json and (isinstance(response, Q2SSOResponse))
            ):
                response = response.serialize()
        except Exception:
            # If there is an unknown exception, log it tagged with this
            # request's GUID
            self.logger.error(traceback.format_exc(), add_to_buffer=False)
            self.set_status(500)
            self.finish()
            return
        return response

    async def q2_post(self, *args, **kwargs):
        """
        Overridable method. Should be defined in the extension inheriting this
        class. Called from the housekeeping Q2 specific POST method after the
        incoming message is parsed and broken into objects

        The default implementation dispatches via route_request and converts a
        ``None`` route result to "".
        """
        route_response = await self.route_request()
        return route_response if route_response is not None else ""

    async def q2_on_finish(self) -> None:
        """Called after the end of `q2_post`.

        This will not send data back to the end user, but is meant for
        background processing after the http connection has been closed.

        This will NOT be called if the response is a 500
        """

    @property
    def request_as_obj(self) -> objectify.ObjectifiedElement:
        """
        The request body parsed with lxml.objectify (cached).

        If the root is a SOAP Envelope, the first child of its Body is returned.
        Raises etree.XMLSyntaxError if the body is not XML.
        """
        if self._request_as_obj is None:
            element = objectify.fromstring(self.request.body)
            if element.tag == "{http://schemas.xmlsoap.org/soap/envelope/}Envelope":
                element = element.Body.getchildren()[0]
            self._request_as_obj = element
        return self._request_as_obj

    @request_as_obj.setter
    def request_as_obj(self, value) -> None:
        self._request_as_obj = value

    @property
    def _tag_name(self) -> str:
        """Local (namespace-stripped) tag name of request_as_obj."""
        return etree.QName(self.request_as_obj).localname

    @property
    def form_fields(self) -> dict:
        """
        Form fields of the request as a dict (cached).

        XML requests read every ``FormFields`` element; other requests wrap each
        ``request.arguments`` value in a Name/Value element so both paths share
        parse_form_fields.
        """
        if self._form_fields is None:
            if self.is_xml:
                form_fields = self.request_as_obj.findall(".//FormFields")
                self._form_fields = self.parse_form_fields(form_fields)
            else:
                objectified_form_fields = []
                for key, val in self.request.arguments.items():
                    if isinstance(val, (str, bytes)):
                        val = [val]
                    for item in val:
                        if isinstance(item, bytes):
                            item = item.decode()
                        objectified_form_fields.append(
                            objectify.E.root(
                                objectify.E.Name(key), objectify.E.Value(item)
                            )
                        )
                self._form_fields = self.parse_form_fields(objectified_form_fields)
        return self._form_fields

    @form_fields.setter
    def form_fields(self, value):
        self._form_fields = value

    async def _handle_soap_request(self, *args, **kwargs):
        """
        Handles the more common HQ requests wrapped in SOAP envelopes
        """
        if self._tag_name == "GetUnsignedToken":
            response = http.get_unhandled_token_response(uuid.uuid4())
            return response
        elif self._tag_name in (
            "ExecuteRequestAsXml",
            "ExecuteRequestAsString",
            "LoginToAdapter",
        ):
            if self._tag_name == "ExecuteRequestAsXml":
                self.request_name = self.request_as_obj.xmlDoc.getchildren()[0].tag
            elif self._tag_name == "ExecuteRequestAsString":
                # The inner request arrives as XML text and must be parsed again.
                self.request_name = objectify.fromstring(
                    self.request_as_obj.xml.text
                ).tag

            routing_key = self.form_fields.get("routing_key")
            self.request_name_elem = self.request_as_obj.find(
                ".//{}".format(self.request_name)
            )
            await self.get_wedge_address_configs()
            if self.request_name_elem is None:
                self.logger.debug(
                    "This does not appear to be an HQ request. Skipping the HQ-specific logic."
                )
            elif routing_key not in self._unparsed_routes:
                await self._build_models_from_hq(self.request_as_obj)

            try:
                custom_response = await self.q2_post(*args, **kwargs)
            except Exception:
                # If there is an unknown exception, log it tagged with this
                # request's GUID
                self.logger.error(traceback.format_exc(), add_to_buffer=False)
                self.set_status(500)
                self.finish()
                return

            if isinstance(custom_response, forms.Q2Form):
                custom_response = custom_response.serialize()

            if self.return_as_html:
                if custom_response is not None and self.use_style_fixes:
                    style_fix_loader = StyleFixAppender(custom_response)
                    custom_response = style_fix_loader.append_style_fixes()
                asset_loader = LocalAssetLoader(self.templater, custom_response)
                if self.expand_local_sources:
                    custom_response = asset_loader.expand_local_sources()
                else:
                    custom_response = asset_loader.raw_html
                custom_response = self.clean_html_tags(custom_response)

            if custom_response is None:
                self.logger.warning(
                    'Empty response detected. Are you sure you called "return" in your method?'
                )
            response = self.wrap_soap_response(custom_response)
        else:
            self.request_name = self.request_as_obj.tag
            response = await self.q2_post(*args, **kwargs)
        return response

    async def _build_models_from_hq(self, element):
        """
        Populate credentials and session models from the HQ request element.

        Reads attributes of self.request_name_elem (falling back to the current
        self.hq_credentials values), builds a new HqCredentials (fetched by
        customer key when no local HQ URL is configured), then sets form_info,
        hq_credentials, wedge configs and online_session on the handler.
        """
        # Authenticated only when HQ says so AND supplies a token; note the
        # value stored is the token string (truthy) or False.
        self.is_authenticated = self.request_name_elem.get(
            "IsAuthenticated"
        ) == "True" and self.request_name_elem.get("HqAuthToken")
        hq_auth_token = (
            self.request_name_elem.get("HqAuthToken")
            if self.is_authenticated
            else None
        )
        aba = self.request_name_elem.get("ABA", self.hq_credentials.aba)
        reported_hq_url = self.request_name_elem.get(
            "HqBaseUrl", self.hq_credentials.hq_url
        )
        env_stack = self.request_name_elem.get(
            "EnvStack", self.hq_credentials.env_stack
        )
        db_schema_name = self.request_name_elem.get(
            "DbSchemaName", self.hq_credentials.db_schema_name
        )
        db_name = self.request_name_elem.get(
            "DbName", self.hq_credentials.database_name
        )
        cust_key = self.hq_credentials.customer_key or self.request_name_elem.get(
            "InsightCustomerKey"
        )
        if not cust_key and settings.LOCAL_DEV:
            cust_key = settings.VAULT_KEY
        if not settings.USE_INCOMING_HQ_URL:
            reported_hq_url = self.hq_credentials.hq_url

        configured_hq = HqCredentials(
            self.hq_credentials.hq_url,
            self.hq_credentials.csr_user,
            self.hq_credentials.csr_pwd,
            aba,
        )
        if cust_key and not configured_hq.hq_url:
            # Get it from vault if you don't have it locally
            configured_hq = self._get_hq_from_key(cust_key)
        configured_hq.customer_key = cust_key
        configured_hq.database_name = db_name
        configured_hq.db_schema_name = db_schema_name
        configured_hq.env_stack = env_stack

        try:
            self.form_info = FormInfo(element)
        except AttributeError:
            self.logger.debug("No form info could be found in request")

        self.hq_credentials = configured_hq
        await self.get_wedge_address_configs()
        self.hq_credentials.auth_token = hq_auth_token
        self.hq_credentials.reported_hq_url = (
            reported_hq_url if reported_hq_url else self.hq_credentials.hq_url
        )
        debug_mode = get_settings().DEBUG
        if not self.hq_credentials.customer_key and debug_mode:
            self.hq_credentials.customer_key = "CHANGEME"

        try:
            self.online_session = OnlineSession(element)
        except AttributeError:
            self.logger.debug("No online session could be found in request")

    def _get_hq_from_key(self, customer_key) -> HqCredentials:
        """
        Return HqCredentials for ``customer_key``.

        Served from the service cache when present; otherwise fetched from vault
        and cached (serialized) for 60 seconds.
        """
        hq_creds = None
        cache_key = f"hq_for_ck_{customer_key}"
        cached_creds = self.service_cache.get(cache_key)
        if cached_creds:
            hq_creds = HqCredentials(
                cached_creds["HqUrl"],
                cached_creds["CsrUser"],
                cached_creds["CsrPwd"],
                cached_creds["ABA"],
                customer_key=customer_key,
                database_name=cached_creds.get("DB_NAME"),
                db_server_name=cached_creds.get("DBServerName"),
            )
            self.logger.debug("Getting HQ from cache for key: %s", customer_key)
        else:
            hq_creds = self.vault.get_hq_creds(customer_key)
            self.service_cache.set(cache_key, hq_creds.serialize_as_dict(), expire=60)
        return hq_creds

    @staticmethod
    def parse_form_fields(form_fields: List[objectify.ObjectifiedElement]) -> dict:
        """
        Convert Name/Value elements into a dict.

        A name that is repeated, or that contains "[]" (URL-encoded "%5B%5D" is
        normalized first), becomes a list: trailing "[]" characters are stripped
        from the key and each value is split on "," into the list. Otherwise the
        value text (or "" when empty) is stored directly.
        """
        parsed_form_fields = {}
        all_keys = Counter([x.Name.text for x in form_fields])
        for field in form_fields:
            name = field.Name.text.replace("%5B%5D", "[]")
            value = field.Value.text if field.Value.text else ""
            if all_keys[name] > 1 or "[]" in name:
                name = name.rstrip("[]")
                parsed_form_fields.setdefault(name, [])
                parsed_form_fields[name].extend(value.split(","))
            else:
                parsed_form_fields[name] = value
        return parsed_form_fields

    def wrap_soap_response(self, custom_response):
        """
        If the incoming request is wrapped in a soap envelope, this prepares it
        for return back to HQ. There are two formats we know about, but since it
        is possible for HQ instances to be customized and expect a different
        format, this allows for overriding the response at an extension level.

        :param custom_response: Response to wrap in an appropriate soap envelope
        """
        if self._is_central_report():
            return http.get_backoffice_report_response(
                self.request_name, custom_response
            )

        # NOTE(review): only these two tags are mapped; any other tag
        # (e.g. "LoginToAdapter") raises KeyError unless this is overridden.
        wrap_mapping = {
            "ExecuteRequestAsXml": http.get_execute_request_as_xml_response,
            "ExecuteRequestAsString": http.get_execute_request_as_string_response,
        }
        response = wrap_mapping[self._tag_name](
            self.request_name, custom_response, self.hq_response_attrs, self.hq_commands
        )
        return response

    async def route_request(self, routing_key: Optional[str] = None):
        """
        Dispatch to the router entry selected by ``routing_key``.

        :param routing_key: explicit key; defaults to the "routing_key" form
            field, then DEFAULT_ROUTE. Multiple keys, or a key missing from the
            router, log an error and fall back to DEFAULT_ROUTE.
        :return: whatever the selected route returns
        """
        if not routing_key:
            routing_key = self.form_fields.get("routing_key")
        if not routing_key:
            routing_key = self.DEFAULT_ROUTE

        if isinstance(routing_key, list):
            if len(routing_key) > 1:
                self.logger.error("Multiple routing keys present: '%s'.", routing_key)
                routing_key = self.DEFAULT_ROUTE
            else:
                routing_key = routing_key[0]
                # parse_form_fields yields str items; only raw bytes need decoding
                # (calling .decode on a str raised AttributeError before).
                if isinstance(routing_key, bytes):
                    routing_key = routing_key.decode("utf-8")

        if routing_key not in self.router:
            self.logger.error(
                "Routing key '%s' not present in router dictionary.", routing_key
            )
            routing_key = self.DEFAULT_ROUTE

        route = self.router[routing_key]
        self.logger.info("Transition: routing to '%s'", routing_key)
        self.active_route = routing_key
        return await self.call_route(route)

    async def call_route(self, route):
        """
        Call ``route`` with form fields bound to its parameters by name.

        Fields matching a positional-only or ``*args`` parameter name are passed
        positionally, other matches by keyword. Leftover fields are appended
        positionally if the route takes ``*args``, else passed as keywords if it
        takes ``**kwargs``, else dropped. Coroutine functions and objects with
        an async ``__call__`` are awaited.

        :raises TypeError: if the number of positional values differs from the
            number of positional-only parameters
        """
        is_async_callable_class = callable(route) and asyncio.iscoroutinefunction(
            route.__call__
        )
        is_async_callable_function = asyncio.iscoroutinefunction(route)
        arg_types = {"positional": [], "keyword": {}}
        # Copy so popping matched fields does not mutate self.form_fields.
        form_fields = self.form_fields.copy()
        params = signature(route).parameters
        var_pos = False
        var_keyword = False
        positional_only_args = [
            x for x in params.values() if x.kind == Parameter.POSITIONAL_ONLY
        ]

        for name, param in params.items():
            if name in form_fields:
                if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.VAR_POSITIONAL):
                    arg_types["positional"].append(form_fields.pop(name))
                else:
                    arg_types["keyword"][name] = form_fields.pop(name)
            if param.kind == Parameter.VAR_POSITIONAL:
                var_pos = True
            if param.kind == Parameter.VAR_KEYWORD:
                var_keyword = True

        if len(arg_types["positional"]) != len(positional_only_args):
            raise TypeError("Not enough positional arguments provided")

        if var_pos:
            for value in form_fields.values():
                arg_types["positional"].append(value)
        elif var_keyword:
            for key, value in form_fields.items():
                arg_types["keyword"][key] = value

        if is_async_callable_class or is_async_callable_function:
            return await route(*arg_types["positional"], **arg_types["keyword"])
        return route(*arg_types["positional"], **arg_types["keyword"])

    @property
    def router(self):
        """Mapping of routing keys to route callables."""
        return {"default": self.default, "default/get_upload_url": self._get_upload_url}

    async def default(self):
        """Default route; returns None (q2_post turns that into "")."""
        pass

    async def _get_upload_url(self):
        """
        Return a JSON list of upload URL payloads.

        Expects the "data" form field to be base64-encoded JSON with a "files"
        list of file names; each name is passed to File.get_upload_url.
        """
        data_json = base64.b64decode(self.form_fields.get("data"))
        data = json.loads(data_json)
        self.logger.debug(data)
        aws_files = []
        for file_name in data["files"]:
            aws_file = await File(self.logger).get_upload_url(name=file_name)
            aws_files.append(aws_file.raw)
        return json.dumps(aws_files)

    async def download_file(self, input_name: str) -> bytes:
        """
        Downloads file_key from AWS using the ardent.file_upload.File class

        :param input_name: name of the file input; matched against
            "<input_name>|<file_key>" entries of the q2_uploadedFileKeys field
        :raises KeyError: if no uploaded file key matches ``input_name``
            (including when q2_uploadedFileKeys was not POSTed at all)
        """
        # Missing field -> empty list so the documented KeyError is raised
        # instead of a TypeError from iterating None.
        file_keys = self.form_fields.get("q2_uploadedFileKeys") or []
        if isinstance(file_keys, str):
            file_keys = file_keys.split(",")
        for key in file_keys:
            if key.startswith(f"{input_name}|"):
                file_key = key.split("|", maxsplit=1)[1]
                return await File(self.logger).download(file_key)
        raise KeyError(f"No file input POSTed with name: {input_name}")

    def set_hq_commands(self, hq_commands: HqCommands):
        """
        HQ has the ability to reload accounts from the DB/Core or skip a
        disclaimer for the remainder of a session. To do so, call this method
        with a valid HqCommands instance.

        Note: This does NOT immediately call HQ to do the operation, but rather
        will send the command to HQ along with the response shape from your
        route. For instance:

        .. code::

            async def default(self):
                # do stuff
                account_id = 12345
                hq_commands_obj = HqCommands(
                    HqCommands.AccountReloadType.FROM_HOST,
                    [account_id],
                )
                self.set_hq_commands(hq_commands_obj)  # Nothing happens here
                html = self.get_template('template_name.html', {})
                return html  # This is where the magic happens
        """
        self.hq_commands = hq_commands