import asyncio
import base64
import os
from inspect import signature, Parameter
from collections import Counter
from functools import cached_property
import json
import traceback
from typing import List, Optional
import uuid
from xml.sax import saxutils
from lxml import etree, objectify
from q2_sdk.ardent.file_upload import File
from q2_sdk.core import exceptions
from q2_sdk.core.cache import Q2CacheClient, StorageLevel, CacheConfigError
from q2_sdk.core.http_handlers.base_handler import Q2BaseRequestHandler
from q2_sdk.hq.db.user import User
from q2_sdk.hq.models.sso_response import Q2SSOResponse
from q2_sdk.core.request_handler.templating import LocalAssetLoader
from q2_sdk.core.configuration import settings
from q2_sdk.hq import http
from q2_sdk.hq.db.wedge_address import WedgeAddress
from q2_sdk.hq.models.db_config.db_config_list import DbConfigList
from q2_sdk.hq.models.form_info import FormInfo
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.hq.models.hq_commands import HqCommands
from q2_sdk.hq.models.online_session import OnlineSession
from q2_sdk.ui import forms
from q2_sdk.ui.script_appenders import StyleFixAppender
from q2_sdk.core.configuration import get_settings
class Q2HqRequestHandler(Q2BaseRequestHandler):
    """
    RequestHandler meant to be used for requests incoming from HQ.

    - Parses apart XML posts into inspectable objects using lxml
    - Puts form fields into a dictionary
    - If WEDGE_ADDRESS_CONFIGS is set, ensures the entries are set in the
      database on installation of the extension.
      WEDGE_ADDRESS_CONFIGS is an instance of DbConfigList filled with
      DbConfig objects.
    - hq_response_attrs is a dict that will be appended as attributes to the Data node HQ is expecting
    """

    WEDGE_ADDRESS_CONFIGS = DbConfigList([])
    OPTIONAL_WEDGE_ADDRESS_CONFIGS = {}
    # Routing key used when none is supplied, or when the supplied one is
    # ambiguous or not present in self.router.
    DEFAULT_ROUTE = "default"
    FRIENDLY_NAME = None
    PROPERTY_LONG_NAME = None
    USER_PROPERTY_DATA_ELEMENT = None

    def __init__(self, application, request, **kwargs):
        super().__init__(application, request, **kwargs)
        # Routing keys for which _handle_soap_request does not build HQ models
        # (credentials, form info, online session). The router registers the
        # upload-url route as "default/get_upload_url", which the previous single
        # entry "default/_get_upload_url" never matched; both spellings are kept
        # so either key skips the HQ-model parsing.
        self._unparsed_routes = ["default/get_upload_url", "default/_get_upload_url"]
        self._form_fields = None
        self._request_as_obj = None
        self._is_xml = None
        self._is_json = None
        self._db_config = None
        self.request_name = ""
        self.request_name_elem: Optional[etree.Element] = None
        self.hq_response_attrs = {}
        self.hq_commands: Optional[HqCommands] = None
        self.active_route = None
        self.return_as_html = True
        self.online_session: Optional[OnlineSession] = None
        self.form_info: Optional[FormInfo] = None
        self.config_cache_key = self._calculate_cache_config_key()
        self.is_authenticated = None
        # this will append UUX style overrides to all returned HTML if true
        self.use_style_fixes = False
        # this will convert script and link tags to jinja imports if true
        self.expand_local_sources = True
        self._session_cache = None
        # Backwards compatibility for dict wedge_address_configs. Remove in 3.0
        if isinstance(self.WEDGE_ADDRESS_CONFIGS, dict):
            self.WEDGE_ADDRESS_CONFIGS = DbConfigList.from_dict(
                self.WEDGE_ADDRESS_CONFIGS, self.OPTIONAL_WEDGE_ADDRESS_CONFIGS
            )
        self.default_cache_level = StorageLevel.Stack

    def _get_db_config_from_request(self):
        """
        Build a wedge config dict from data carried on the current request.

        - If ``request_name_elem`` has ``<Config>`` text, it is XML-unescaped
          and parsed as JSON.
        - Otherwise the JSON in the ``Wedgeconfig`` header (default ``"{}"``)
          is used.
        - ``<ZoneConfig>`` text, when present, is unescaped, parsed and merged
          on top.

        :return: a (possibly empty) dict
        """
        _db_config = {}
        config_node = None
        zoned_config_node = None
        # If a wedge config was sent in this request, parse it and update what we have.
        if self.request_name_elem is not None:
            config_node = self.request_name_elem.findtext(".//Config")
            zoned_config_node = self.request_name_elem.findtext(".//ZoneConfig")
        if config_node:
            config_node = saxutils.unescape(config_node)
            _db_config.update(json.loads(config_node))
        else:
            ardent_header_data = json.loads(
                self.request.headers.get("Wedgeconfig", "{}")
            )
            if ardent_header_data:
                _db_config.update(ardent_header_data)
        if zoned_config_node:
            zoned_config_node = saxutils.unescape(zoned_config_node)
            _db_config.update(json.loads(zoned_config_node))
        return _db_config

    def _calculate_cache_config_key(self):
        """
        Return the service-cache key under which this extension's wedge config
        is stored.

        The key is suffixed with a customer key when one is available. The
        ``customerKey`` header is checked first, then the ``customerKey`` field
        of the JSON in the ``Wedgejsondata`` header.
        """
        config_cache_key = f"{self.extension_name}_wedge_config"
        if customer_key := self.request.headers.get("customerKey", None):
            config_cache_key = f"{self.extension_name}_wedge_config_{customer_key}"
        elif wedge_json_data := self.request.headers.get("Wedgejsondata", None):
            if isinstance(wedge_json_data, str):
                wedge_json_data = json.loads(wedge_json_data)
            if customer_key := wedge_json_data.get("customerKey", None):
                config_cache_key = f"{self.extension_name}_wedge_config_{customer_key}"
        return config_cache_key

    async def get_user_dob(self):
        """
        Date of Birth is sometimes populated in the Q2 database, but rarely passed in the initial payload from HQ.
        Calling self.get_user_dob() will efficiently look it up if it exists.
        After that, self.online_user.dob will have it stored.

        :return: the DOB text ("" if the database row has none)
        """
        if not self.online_user.dob:
            user = (
                await User(self.logger, hq_credentials=self.hq_credentials).get(
                    user_id=self.online_user.user_id
                )
            )[0]
            dob = user.findtext("DOB", "")
            self.online_user.dob = dob
            if dob:
                # Keep the demographic view in sync with the looked-up value.
                self.online_user.as_demographic_info().dob = dob
        else:
            self.logger.debug("DOB is already populated, skipping DB call")
            dob = self.online_user.dob
        return dob

    async def get_wedge_address_configs(self, force=False) -> dict:
        """
        Gather wedge_address_configs either from incoming data or from database
        Has the side effect of seeding data for subsequent self.db_config calls

        :param force: If True, will query the database even if self.WEDGE_ADDRESS_CONFIGS is blank
        :return: the resolved config dict (also stored on ``self._db_config``)
        :raises exceptions.ConfigurationError: if a required config key is missing
        """
        if not self._db_config:
            _db_config = self._get_db_config_from_request()
            # Only hit the cache/database when forced, or when the request carried
            # no config but this extension declares configs it expects.
            if force or (
                not _db_config and len(self.WEDGE_ADDRESS_CONFIGS.db_configs) != 0
            ):
                cache_config = self.service_cache.get(self.config_cache_key)
                if not cache_config:
                    wedge_object = WedgeAddress(self.logger, self.hq_credentials)
                    config_node = await wedge_object.get(short_name=self.extension_name)
                    if config_node:
                        _db_config.update(json.loads(config_node[0].Config.text))
                        self.service_cache.set(self.config_cache_key, _db_config, expire=60)
                else:
                    self.logger.debug(
                        "wedge address config found in cache with key: %s",
                        self.config_cache_key,
                    )
                    _db_config = cache_config
            # An "_overrides.customer_key" entry swaps in that customer's HQ credentials.
            override_dict = _db_config.get("_overrides", {})
            customer_key = override_dict.get("customer_key")
            if customer_key:
                self.hq_credentials = self._get_hq_from_key(customer_key)
            self._validate_db_configs(_db_config)
            self._db_config = _db_config
            self.logger.debug("WedgeAddressConfigs: %s", self._db_config)
        return self._db_config

    def _validate_db_configs(self, db_config_list: dict):
        """
        Raises exception if required WEDGE_ADDRESS_CONFIGS are not set

        :param db_config_list: the resolved config dict. The parameter name is
            kept for compatibility; every caller in this class passes a dict.
        :raises exceptions.ConfigurationError: if any required key is missing
        """
        bad_config = False
        # A DB_CONFIGS attribute, when defined, takes precedence over WEDGE_ADDRESS_CONFIGS.
        configs = getattr(self, "DB_CONFIGS", self.WEDGE_ADDRESS_CONFIGS)
        # Sanity-check the config to make sure all expected keys are now set.
        for key in [x.name for x in configs.required]:
            if key not in db_config_list:
                bad_config = True
                self.logger.error("Q2_WedgeAddress.Config missing key: %s.", key)
        # If anything went wrong, fail.
        if bad_config:
            self.logger.error(
                "DbConfigs Misconfigured. Try running `q2 update_installed -e %s`",
                self.extension_name,
            )
            raise exceptions.ConfigurationError("Q2_WedgeAddress.Config Misconfigured")

    @staticmethod
    def get_dc_cookie():
        """
        Q2's production environments run in two data centers. New sessions (in production) are distributed evenly
        across those two data centers. In order to make sure the Ardent API calls return to the same data center that
        this token was generated in, we need to pass a cookie on the Ardent API requests. To do that we will inspect
        an environment variable and generate the cookie string accordingly.

        :return: ``{"cookie": "AA_DC=NN"}``; ``AA_DC=00`` when ``NOMAD_DC`` is
            unset or matches none of the known data center names
        """
        environment_name = os.environ.get("NOMAD_DC")
        cookie = "AA_DC=00"
        # First matching substring wins, so the order of these checks matters.
        if environment_name is not None:
            if "austin" in environment_name:
                cookie = "AA_DC=01"
            elif "dallas" in environment_name:
                cookie = "AA_DC=02"
            elif "aws-use1-" in environment_name:
                cookie = "AA_DC=03"
            elif "aws-usw2-" in environment_name:
                cookie = "AA_DC=04"
            elif "aws-hou-" in environment_name:
                cookie = "AA_DC=05"
            elif "aws-den-" in environment_name:
                cookie = "AA_DC=06"
        return {"cookie": cookie}

    @cached_property
    def session_cache(self) -> Q2CacheClient:
        """Same as self.cache but limited to online user session"""
        return self.get_cache(prefix=self.online_session.session_id)

    @cached_property
    def stack_cache(self) -> Q2CacheClient:
        """
        Q2CacheClient scoped to the current financial institution stack.

        :raises CacheConfigError: if ``hq_credentials.customer_key`` is unset
        """
        if not self.hq_credentials.customer_key:
            self.logger.error(
                "Customer Key is unset, but cache requested at stack level. Refusing to cache."
            )
            raise CacheConfigError
        return self.get_cache(prefix=self.hq_credentials.customer_key)

    @property
    def wedge_address_configs(self) -> dict:
        """Alias to self.db_config"""
        return self.db_config

    @property
    def db_config(self) -> dict:
        """
        Dictionary representation of data passed from the Q2_WedgeAddress table's Config column

        If get_wedge_address_configs has not populated it yet, only the
        request-carried config is used (no database lookup happens here).
        """
        if self._db_config is None:
            _db_config = self._get_db_config_from_request()
            self._validate_db_configs(_db_config)
            self._db_config = _db_config
        return self._db_config

    def clean_html_tags(self, custom_response):
        """
        Remove ``<html>`` / ``</html>`` tags from the response, except for
        central report requests.

        :param custom_response: response body; assumed to be bytes, because
            the replacements use bytes literals
        """
        if self._is_central_report():
            return custom_response
        return custom_response.replace(b"<html>", b"").replace(b"</html>", b"")

    def _is_central_report(self):
        """True when the incoming request is a ``ReportParameterObject``."""
        return self.request_name == "ReportParameterObject"

    @property
    def is_xml(self):
        """
        Whether the request body should be treated as XML.

        The body counts as XML when Content-Type mentions "xml". Otherwise it
        counts as XML if it parses, and not if parsing raises XMLSyntaxError.
        The result is cached.
        """
        if self._is_xml is None:
            content_type = self.request.headers.get("Content-Type", "Unknown")
            is_xml = True
            if "xml" not in content_type:
                try:
                    self.request_as_obj
                except etree.XMLSyntaxError:
                    is_xml = False
            self._is_xml = is_xml
        return self._is_xml

    @property
    def is_json(self):
        """Whether the Content-Type header mentions "json" (cached)."""
        if self._is_json is None:
            self._is_json = "json" in self.request.headers.get(
                "Content-Type", "Unknown"
            )
        return self._is_json

    async def post(self, *args, **kwargs):
        """
        Parses apart POST data based on content-type and format,
        creates a few handy objects, then farms out the work to the extension
        specific function defined in `q2_post`
        """
        message = self.request.body
        # errors="replace": a body that is not valid UTF-8 (e.g. binary upload
        # data) must not abort the request just because it is being logged.
        self.logger.debug(
            "POST RequestBody: %s",
            message.decode("utf8", errors="replace"),
            add_to_buffer=False,
        )
        if self.is_xml:
            response = await self._handle_soap_request(*args, **kwargs)
            self.set_header("Content-Type", "text/xml")  # XML In, XML Out
            self.set_cookie("Q2API-Compatibility", "universal")
        else:
            response = await self._handle_form_request(*args, **kwargs)
        # When q2_post raises, both handlers set a 500 status and finish the
        # response themselves. Nothing more is written in that case, and
        # q2_on_finish is skipped.
        if self.get_status() != 500:
            if settings.TEMPORARY_LOG_RESPONSE_ENABLE:
                self.logger.info("Response: %s", response)
            elif settings.LOG_RESPONSE_IN_DEBUG:
                self.logger.debug("Response: %s", response)
            self.write(response)
            self.finish()
            if self._is_xml and self._tag_name == "GetUnsignedToken":
                return
            await self.q2_on_finish()

    async def _handle_form_request(self, *args, **kwargs):
        """
        If the incoming request is not wrapped in a soap envelope (less common)
        don't try to massage the data first

        :return: the q2_post result (Q2Form, or Q2SSOResponse on JSON requests,
            is serialized), or None after a 500 has been sent
        """
        await self.get_wedge_address_configs()
        try:
            response = await self.q2_post(*args, **kwargs)
            if isinstance(response, forms.Q2Form) or (
                self.is_json and (isinstance(response, Q2SSOResponse))
            ):
                response = response.serialize()
        except Exception:
            # If there is an unknown exception, log it tagged with this
            # request's GUID
            self.logger.error(traceback.format_exc(), add_to_buffer=False)
            self.set_status(500)
            self.finish()
            return
        return response

    async def q2_post(self, *args, **kwargs):
        """
        Overridable method. Should be defined in the extension inheriting this class.
        Called from the housekeeping Q2 specific POST method after the incoming
        message is parsed and broken into objects

        The default implementation dispatches through route_request and maps a
        None route result to "".
        """
        route_response = await self.route_request()
        return route_response if route_response is not None else ""

    async def q2_on_finish(self) -> None:
        """Called after the end of `q2_post`.
        This will not send data back to the end user,
        but is meant for background processing after
        the http connection has been closed.
        This will NOT be called if the response is a 500
        """

    @property
    def request_as_obj(self) -> objectify.ObjectifiedElement:
        """
        Request body parsed with lxml.objectify (lazily, then cached).

        If the root is a SOAP ``Envelope``, the first child of its ``Body`` is
        returned instead of the envelope.

        :raises etree.XMLSyntaxError: if the body is not well-formed XML
        """
        if self._request_as_obj is None:
            element = objectify.fromstring(self.request.body)
            if element.tag == "{http://schemas.xmlsoap.org/soap/envelope/}Envelope":
                # getchildren() on purpose: indexing/iterating an objectify
                # element walks same-tag siblings, not children.
                element = element.Body.getchildren()[0]
            self._request_as_obj = element
        return self._request_as_obj

    @request_as_obj.setter
    def request_as_obj(self, value) -> None:
        self._request_as_obj = value

    @property
    def _tag_name(self) -> str:
        """Namespace-stripped tag name of request_as_obj."""
        return etree.QName(self.request_as_obj).localname

    @property
    def form_fields(self) -> dict:
        """
        Form fields as a dict, built once and cached (see parse_form_fields).

        For XML requests, all ``FormFields`` elements in the body are used.
        Otherwise the request arguments are used: bytes values are decoded and
        each value becomes a Name/Value element.
        """
        if self._form_fields is None:
            if self.is_xml:
                form_fields = self.request_as_obj.findall(".//FormFields")
                self._form_fields = self.parse_form_fields(form_fields)
            else:
                objectified_form_fields = []
                for key, val in self.request.arguments.items():
                    if isinstance(val, (str, bytes)):
                        val = [val]
                    for item in val:
                        if isinstance(item, bytes):
                            item = item.decode()
                        objectified_form_fields.append(
                            objectify.E.root(
                                objectify.E.Name(key), objectify.E.Value(item)
                            )
                        )
                self._form_fields = self.parse_form_fields(objectified_form_fields)
        return self._form_fields

    @form_fields.setter
    def form_fields(self, value):
        self._form_fields = value

    async def _handle_soap_request(self, *args, **kwargs):
        """
        Handles the more common HQ requests wrapped in SOAP envelopes

        :return: the SOAP-wrapped response, the raw q2_post result for
            unrecognized request tags, or None after a 500 has been sent
        """
        if self._tag_name == "GetUnsignedToken":
            response = http.get_unhandled_token_response(uuid.uuid4())
            return response
        elif self._tag_name in (
            "ExecuteRequestAsXml",
            "ExecuteRequestAsString",
            "LoginToAdapter",
        ):
            if self._tag_name == "ExecuteRequestAsXml":
                self.request_name = self.request_as_obj.xmlDoc.getchildren()[0].tag
            elif self._tag_name == "ExecuteRequestAsString":
                # The inner request arrives as XML text and must be parsed again.
                self.request_name = objectify.fromstring(
                    self.request_as_obj.xml.text
                ).tag
            routing_key = self.form_fields.get("routing_key")
            self.request_name_elem = self.request_as_obj.find(
                ".//{}".format(self.request_name)
            )
            await self.get_wedge_address_configs()
            if self.request_name_elem is None:
                self.logger.debug(
                    "This does not appear to be an HQ request. Skipping the HQ-specific logic."
                )
            elif routing_key not in self._unparsed_routes:
                await self._build_models_from_hq(self.request_as_obj)
            try:
                custom_response = await self.q2_post(*args, **kwargs)
            except Exception:
                # If there is an unknown exception, log it tagged with this
                # request's GUID
                self.logger.error(traceback.format_exc(), add_to_buffer=False)
                self.set_status(500)
                self.finish()
                return
            if isinstance(custom_response, forms.Q2Form):
                custom_response = custom_response.serialize()
            if self.return_as_html:
                if custom_response is not None and self.use_style_fixes:
                    style_fix_loader = StyleFixAppender(custom_response)
                    custom_response = style_fix_loader.append_style_fixes()
                asset_loader = LocalAssetLoader(self.templater, custom_response)
                if self.expand_local_sources:
                    custom_response = asset_loader.expand_local_sources()
                else:
                    custom_response = asset_loader.raw_html
                custom_response = self.clean_html_tags(custom_response)
            if custom_response is None:
                self.logger.warning(
                    'Empty response detected. Are you sure you called "return" in your method?'
                )
            response = self.wrap_soap_response(custom_response)
        else:
            self.request_name = self.request_as_obj.tag
            response = await self.q2_post(*args, **kwargs)
        return response

    async def _build_models_from_hq(self, element):
        """
        Populate hq_credentials, form_info and online_session from the HQ request.

        Values come from attributes on ``request_name_elem``; the current
        ``hq_credentials`` supply fallbacks.

        :param element: the parsed request (request_as_obj), passed to FormInfo
            and OnlineSession
        """
        # Truthy (the HqAuthToken string) only when IsAuthenticated == "True"
        # and a token is present.
        self.is_authenticated = self.request_name_elem.get(
            "IsAuthenticated"
        ) == "True" and self.request_name_elem.get("HqAuthToken")
        hq_auth_token = (
            self.request_name_elem.get("HqAuthToken") if self.is_authenticated else None
        )
        aba = self.request_name_elem.get("ABA", self.hq_credentials.aba)
        reported_hq_url = self.request_name_elem.get(
            "HqBaseUrl", self.hq_credentials.hq_url
        )
        env_stack = self.request_name_elem.get(
            "EnvStack", self.hq_credentials.env_stack
        )
        db_schema_name = self.request_name_elem.get(
            "DbSchemaName", self.hq_credentials.db_schema_name
        )
        db_name = self.request_name_elem.get(
            "DbName", self.hq_credentials.database_name
        )
        # An already-configured customer key wins over the one HQ reports.
        cust_key = self.hq_credentials.customer_key or self.request_name_elem.get(
            "InsightCustomerKey"
        )
        if not cust_key and settings.LOCAL_DEV:
            cust_key = settings.VAULT_KEY
        if not settings.USE_INCOMING_HQ_URL:
            reported_hq_url = self.hq_credentials.hq_url
        configured_hq = HqCredentials(
            self.hq_credentials.hq_url,
            self.hq_credentials.csr_user,
            self.hq_credentials.csr_pwd,
            aba,
        )
        if cust_key and not configured_hq.hq_url:
            # Get it from vault if you don't have it locally
            configured_hq = self._get_hq_from_key(cust_key)
        configured_hq.customer_key = cust_key
        configured_hq.database_name = db_name
        configured_hq.db_schema_name = db_schema_name
        configured_hq.env_stack = env_stack
        try:
            self.form_info = FormInfo(element)
        except AttributeError:
            self.logger.debug("No form info could be found in request")
        self.hq_credentials = configured_hq
        await self.get_wedge_address_configs()
        self.hq_credentials.auth_token = hq_auth_token
        self.hq_credentials.reported_hq_url = (
            reported_hq_url if reported_hq_url else self.hq_credentials.hq_url
        )
        debug_mode = get_settings().DEBUG
        if not self.hq_credentials.customer_key and debug_mode:
            self.hq_credentials.customer_key = "CHANGEME"
        try:
            self.online_session = OnlineSession(element)
        except AttributeError:
            self.logger.debug("No online session could be found in request")

    def _get_hq_from_key(self, customer_key) -> HqCredentials:
        """
        Resolve HqCredentials for ``customer_key``.

        A service-cache entry is used when present. Otherwise the credentials
        come from ``self.vault`` and are cached for 60 seconds.
        """
        cache_key = f"hq_for_ck_{customer_key}"
        cached_creds = self.service_cache.get(cache_key)
        if cached_creds:
            hq_creds = HqCredentials(
                cached_creds["HqUrl"],
                cached_creds["CsrUser"],
                cached_creds["CsrPwd"],
                cached_creds["ABA"],
                customer_key=customer_key,
                database_name=cached_creds.get("DB_NAME"),
                db_server_name=cached_creds.get("DBServerName"),
            )
            self.logger.debug("Getting HQ from cache for key: %s", customer_key)
        else:
            hq_creds = self.vault.get_hq_creds(customer_key)
            self.service_cache.set(cache_key, hq_creds.serialize_as_dict(), expire=60)
        return hq_creds

    @staticmethod
    def parse_form_fields(form_fields: List[objectify.ObjectifiedElement]) -> dict:
        """
        Convert Name/Value form field elements into a dict.

        URL-encoded ``%5B%5D`` in a name is first decoded to ``[]``. A field
        becomes a list when its name occurs more than once or contains ``[]``.
        Each of its values is split on commas and appended, and a trailing
        ``[]`` is removed from the name. Any other field maps to its text, or
        ``""`` when the value is empty.

        :param form_fields: elements that each have ``Name`` and ``Value`` children
        :return: dict mapping field name to ``str`` or ``list`` of ``str``
        """
        parsed_form_fields = {}
        all_keys = Counter([x.Name.text for x in form_fields])
        for field in form_fields:
            name = field.Name.text.replace("%5B%5D", "[]")
            value = field.Value.text if field.Value.text else ""
            if all_keys[name] > 1 or "[]" in name:
                # Remove only a literal "[]" suffix. str.rstrip("[]") strips any
                # trailing "[" / "]" characters and mangled repeated names such
                # as "user[name]" into "user[name".
                if name.endswith("[]"):
                    name = name[:-2]
                parsed_form_fields.setdefault(name, [])
                parsed_form_fields[name].extend(value.split(","))
            else:
                parsed_form_fields[name] = value
        return parsed_form_fields

    def wrap_soap_response(self, custom_response):
        """
        If the incoming request is wrapped in a soap envelope, this prepares
        it for return back to HQ. There are two formats we know about, but
        since it is possible for HQ instances to be customized and expect a
        different format, this allows for overriding the response at an
        extension level.

        Central report requests use the backoffice report format. Any tag other
        than ExecuteRequestAsXml / ExecuteRequestAsString raises KeyError here
        unless this method is overridden.

        :param custom_response: Response to wrap in an appropriate soap envelope
        """
        if self._is_central_report():
            return http.get_backoffice_report_response(
                self.request_name, custom_response
            )
        wrap_mapping = {
            "ExecuteRequestAsXml": http.get_execute_request_as_xml_response,
            "ExecuteRequestAsString": http.get_execute_request_as_string_response,
        }
        response = wrap_mapping[self._tag_name](
            self.request_name, custom_response, self.hq_response_attrs, self.hq_commands
        )
        return response

    async def route_request(self, routing_key: Optional[str] = None):
        """
        Dispatch to the route registered in ``self.router`` under ``routing_key``.

        :param routing_key: explicit key. When falsy, the ``routing_key`` form
            field is used, then ``DEFAULT_ROUTE``. A list with more than one
            entry, or a key missing from the router, falls back to
            ``DEFAULT_ROUTE`` and logs an error.
        :return: the route's return value (awaited when the route is async)
        """
        if not routing_key:
            routing_key = self.form_fields.get("routing_key")
        if not routing_key:
            routing_key = self.DEFAULT_ROUTE
        if isinstance(routing_key, list):
            if len(routing_key) > 1:
                self.logger.error("Multiple routing keys present: '%s'.", routing_key)
                routing_key = self.DEFAULT_ROUTE
            else:
                routing_key = routing_key[0]
                # parse_form_fields yields str values; only raw bytes need decoding
                # (calling .decode() on a str raised AttributeError here).
                if isinstance(routing_key, bytes):
                    routing_key = routing_key.decode("utf-8")
        if routing_key not in self.router:
            self.logger.error(
                "Routing key '%s' not present in router dictionary.", routing_key
            )
            routing_key = self.DEFAULT_ROUTE
        route = self.router[routing_key]
        self.logger.info("Transition: routing to '%s'", routing_key)
        self.active_route = routing_key
        return await self.call_route(route)

    async def call_route(self, route):
        """
        Call ``route``, binding form fields to its parameters by name.

        - A field that matches a positional-only parameter, or the ``*args``
          name, is passed positionally; any other match is passed by keyword.
        - Leftover fields are passed as extra positionals if the route accepts
          ``*args``, otherwise as keywords if it accepts ``**kwargs``, otherwise
          dropped.

        :raises TypeError: if the number of positionally bound fields differs
            from the number of positional-only parameters
        """
        is_async_callable_class = callable(route) and asyncio.iscoroutinefunction(
            route.__call__
        )
        is_async_callable_function = asyncio.iscoroutinefunction(route)
        arg_types = {"positional": [], "keyword": {}}
        # Work on a copy so bound fields can be popped without touching self.form_fields.
        form_fields = self.form_fields.copy()
        params = signature(route).parameters
        var_pos = False
        var_keyword = False
        positional_only_args = [
            x for x in params.values() if x.kind == Parameter.POSITIONAL_ONLY
        ]
        for name, param in params.items():
            if name in form_fields:
                if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.VAR_POSITIONAL):
                    arg_types["positional"].append(form_fields.pop(name))
                else:
                    arg_types["keyword"][name] = form_fields.pop(name)
            if param.kind == Parameter.VAR_POSITIONAL:
                var_pos = True
            if param.kind == Parameter.VAR_KEYWORD:
                var_keyword = True
        if len(arg_types["positional"]) != len(positional_only_args):
            raise TypeError("Not enough positional arguments provided")
        if var_pos:
            for value in form_fields.values():
                arg_types["positional"].append(value)
        elif var_keyword:
            for key, value in form_fields.items():
                arg_types["keyword"][key] = value
        if is_async_callable_class or is_async_callable_function:
            return await route(*arg_types["positional"], **arg_types["keyword"])
        return route(*arg_types["positional"], **arg_types["keyword"])

    @property
    def router(self):
        """Mapping of routing key to route callable, consulted by route_request."""
        return {"default": self.default, "default/get_upload_url": self._get_upload_url}

    async def default(self):
        """Default route; returns None, which q2_post turns into ""."""
        pass

    async def _get_upload_url(self):
        """
        Route: fetch upload URLs for each name in the request's ``files`` list.

        The ``data`` form field is base64-decoded and parsed as JSON. One upload
        URL is requested per entry in its ``files`` list, and the raw results are
        returned as a JSON string.
        """
        data_json = base64.b64decode(self.form_fields.get("data"))
        data = json.loads(data_json)
        self.logger.debug(data)
        aws_files = []
        for file_name in data["files"]:
            aws_file = await File(self.logger).get_upload_url(name=file_name)
            aws_files.append(aws_file.raw)
        return json.dumps(aws_files)

    async def download_file(self, input_name: str) -> bytes:
        """
        Downloads file_key from AWS using the ardent.file_upload.File class

        :param input_name: name of the file input whose upload should be fetched
        :raises KeyError: if no uploaded file key was POSTed for ``input_name``
        """
        file_keys = self.form_fields.get("q2_uploadedFileKeys")
        if not file_keys:
            # Without this check a missing field iterated None and raised
            # TypeError instead of the documented KeyError.
            raise KeyError(f"No file input POSTed with name: {input_name}")
        if isinstance(file_keys, str):
            file_keys = file_keys.split(",")
        # Each entry has the form "<input_name>|<file_key>".
        for key in file_keys:
            if key.startswith(f"{input_name}|"):
                file_key = key.split("|", maxsplit=1)[1]
                return await File(self.logger).download(file_key)
        raise KeyError(f"No file input POSTed with name: {input_name}")

    def set_hq_commands(self, hq_commands: HqCommands):
        """
        HQ has the ability to reload accounts from the DB/Core or skip
        a disclaimer for the remainder of a session. To do so, call this method
        with a valid HqCommands instance.

        Note: This does NOT immediately call HQ to do the operation,
        but rather will send the command to HQ along with the
        response shape from your route.

        For instance:

        .. code::

            async def default(self):
                #do stuff
                account_id = 12345
                hq_commands_obj = HqCommands(
                    HqCommands.AccountReloadType.FROM_HOST,
                    [account_id],
                )
                self.set_hq_commands(hq_commands_obj) # Nothing happens here
                html = self.get_template('template_name.html', {})
                return html # This is where the magic happens
        """
        self.hq_commands = hq_commands