Source code for q2_sdk.core.http_handlers.base_handler

import asyncio
import hashlib
import ipaddress
import logging
import os
import traceback
import uuid
from dataclasses import dataclass
from functools import cached_property, partial
from importlib import import_module
from socket import gaierror
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
from uuid import uuid4
from urllib.parse import urlparse

from opentelemetry import trace
from tornado.web import Finish, RequestHandler

from q2_sdk.core import cache, configuration, contexts, exceptions, vault
from q2_sdk.core.cache import Q2CacheClient, StorageLevel
from q2_sdk.core.configuration import settings
from q2_sdk.core.install_steps.db_plan import DbPlan
from q2_sdk.core.opentelemetry.filters import BaseEventFilter
from q2_sdk.core.opentelemetry.handler import SpanContext
from q2_sdk.core.opentelemetry.span import Q2SpanAttributes
from q2_sdk.core.prometheus import MetricType, get_metric
from q2_sdk.core.q2_logging.logger import Q2LoggerAdapter
from q2_sdk.core.rate_limiter import RateLimiter
from q2_sdk.core.request_handler.templating import Q2RequestHandlerTemplater
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.models.cores.base_core import BaseCore
from q2_sdk.models.cores.models.core_user import CoreUser
from q2_sdk.hq.models.account_list import AccountList
from q2_sdk.hq.models.online_user import OnlineUser
from q2_sdk.tools import utils

if TYPE_CHECKING:
    from opentelemetry.sdk.trace import Span

    from q2_sdk.models.pinion import Pinion
    from q2_sdk.hq.db.db_object_factory import DbObjectFactory

RATE_LIMIT_ERROR_MSG = "Too many requests"


def is_q2_traffic(remote_ip: str, whitelist: Optional[List[str]] = None) -> bool:
    """
    Returns True if IP belongs to a known Q2 source

    :param remote_ip: IP address as a string
    :param whitelist: If provided, merged with the addresses in the global INBOUND_IP_WHITELIST
    """

    if not whitelist:
        whitelist = []
    inbound_ip_whitelist = configuration.settings.INBOUND_IP_WHITELIST
    _q2_networks = [ipaddress.IPv4Network(x) for x in inbound_ip_whitelist]
    _q2_networks.extend([ipaddress.IPv4Network(x) for x in whitelist])
    remote_ip = remote_ip.strip()
    ip_addr = ipaddress.ip_address(remote_ip)
    if ip_addr.is_private:
        return True
    for network in _q2_networks:
        if ip_addr in network:
            return True
    return False
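
# Illustrative sketch (not part of the SDK): how is_q2_traffic treats private and
# public addresses. The CIDR block and addresses below are documentation (TEST-NET)
# values chosen only for this example; the results assume those ranges are not
# already present in the global INBOUND_IP_WHITELIST.
def _example_is_q2_traffic_usage() -> None:
    # Private addresses are always considered Q2 traffic
    assert is_q2_traffic("10.0.0.5") is True
    # Public addresses must fall inside a whitelisted network
    assert is_q2_traffic("203.0.113.10", whitelist=["203.0.113.0/24"]) is True
    assert is_q2_traffic("198.51.100.20", whitelist=["203.0.113.0/24"]) is False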


# Prometheus Metrics
def increment_current_requests(endpoint, method, reverse=False):
    """
    Increments the in-progress request gauge; if reverse is True, decrements it instead
    """
    operation = "inc"
    if reverse:
        operation = "dec"
    get_metric(
        MetricType.Gauge,
        "caliper_current_requests",
        "Requests in progress",
        {"endpoint": endpoint, "method": method},
        chain={"op": operation},
    )
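
# Illustrative sketch (not part of the SDK): the gauge above is bumped when a request
# starts and decremented when it finishes, so the calls should come in pairs. The
# endpoint and method values below are placeholders.
def _example_current_requests_usage(endpoint: str = "my_extension", method: str = "GET") -> None:
    increment_current_requests(endpoint, method)                  # request starts
    increment_current_requests(endpoint, method, reverse=True)    # request finishes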


@dataclass
class ConfigurationCheck:
    is_valid: bool
    error: Optional[str] = ""
    fix: Optional[Awaitable] = None


class classproperty:
    """
    This is used just like @property,
    but can also be used by an uninstantiated version of the class
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, _, owner):
        return self.func(owner)
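

# Illustrative sketch (not part of the SDK): classproperty can be read from the class
# itself, without creating an instance. The class below exists only for this example.
class _ExampleWidget:
    @classproperty
    def label(cls):
        # Accessible as _ExampleWidget.label == "_examplewidget" (no instance required)
        return cls.__name__.lower()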


class Q2BaseRequestHandler(RequestHandler):
    """
    Inherits from Tornado's RequestHandler, but adds a few Q2 specific niceties.

    - Handles both gets and posts
    - If REQUIRED_CONFIGURATIONS is set, ensures the entries are set in the extension's
      settings file or the webserver will not start. REQUIRED_CONFIGURATIONS is a
      dictionary of key value pairs where the key is the required name and the value is
      the default value when :code:`q2 generate_config` is called
    - If DYNAMIC_CORE_SELECTION is set to True, will prompt for a Core name on installation
    - OPTIONAL_CONFIGURATIONS work the same way as REQUIRED_CONFIGURATIONS, but will not
      stop the server from running if omitted
    - FRONTEND_OVERRIDE_EXTENSION allows using another extension as the frontend for this one
    """

    CONFIG_FILE_NAME = ""
    DB_PLAN = DbPlan()
    REQUIRED_CONFIGURATIONS = {}
    OPTIONAL_CONFIGURATIONS = {}
    LOGGING_FILTERS: List[logging.Filter] = []
    EVENT_FILTERS: List[BaseEventFilter] = []
    DYNAMIC_CORE_SELECTION = False
    INBOUND_IP_WHITELIST = []
    FRONTEND_OVERRIDE_EXTENSION = None
    CUSTOM_AUDIT_ACTIONS = []

    def __init__(self, application, request, **kwargs):
        self._logging_level = kwargs.pop("logging_level")
        super().__init__(application, request, **kwargs)
        self.config = configuration.get_configuration(
            self.CONFIG_FILE_NAME, self.REQUIRED_CONFIGURATIONS
        )
        self.hq_credentials = HqCredentials(
            configuration.settings.HQ_CREDENTIALS.hq_url,
            configuration.settings.HQ_CREDENTIALS.csr_user,
            configuration.settings.HQ_CREDENTIALS.csr_pwd,
            configuration.settings.HQ_CREDENTIALS.aba,
            "",
            customer_key=configuration.settings.HQ_CREDENTIALS.customer_key,
        )
        self._logger = None
        self.sdk_session_identifier: str = uuid.uuid4().hex
        self._cache: Optional[Q2CacheClient] = None
        self.bt_handle = None
        self._templater: Optional[Q2RequestHandlerTemplater] = None
        self.rate_limiters: List[RateLimiter] = []
        self._ui_text = None
        self._core: Optional[BaseCore] = None
        if self.config and self.config.extra_configurations:
            self.config.log_extra_configurations(self.logger)
        self.allow_non_q2_traffic = False
        if configuration.settings.FORK_REQUESTS:
            self.this_pid = os.getpid()
            if self.this_pid != configuration.settings.SERVER_PID:
                self.logger.debug(f"PID: {self.this_pid}")
                self.pid_path = (
                    f"{configuration.settings.FORKED_CHILD_PID_DIR}/{self.this_pid}"
                )
        if configuration.settings.ENABLE_OPEN_TELEMETRY:
            self.otel_handler_context_token = None
            self.otel_span: Optional["Span"] = None
            self.otel_inbound_context_token = None
        self._db = None
        self.enable_summary_log = True
        self.default_cache_level = StorageLevel.Service
        self.online_user: Optional[OnlineUser] = None
        self.account_list: Optional[AccountList] = None

    def data_received(self, _: bytes):
        self.logger.warning("Method data_received is not implemented")

    @classmethod
    async def validate_configuration(cls) -> List[ConfigurationCheck]:
        return [ConfigurationCheck(True)]

    async def run_async(self, funcname: Callable, *args, **kwargs):
        """Easy way to run a non_async function in a background thread

        Usage:

            def non_async_func(foo, spam="eggs"):
                time_consuming_thing()

            await self.run_async(non_async_func, 1, spam="eggs")

        :param funcname: Function reference (do not instantiate)
        :param args: Positional parameters
        :param kwargs: Keyword parameters
        """
        func = partial(funcname, *args, **kwargs)
        await asyncio.get_running_loop().run_in_executor(None, func)

    @property
    def default_summary_log(self) -> Dict[str, Any]:
        summary_log: Dict[str, Any] = {
            "status_code": self.get_status(),
        }
        if configuration.settings.ENABLE_OPEN_TELEMETRY:
            summary_log["trace_id"] = get_trace_id()
        if configuration.settings.MULTITENANT:
            if self.hq_credentials:
                if self.hq_credentials.customer_key:
                    summary_log["customer_key"] = self.hq_credentials.customer_key
                if self.hq_credentials.env_stack:
                    env_stack_list = self.hq_credentials.env_stack.split("-")
                    if len(env_stack_list) == 3:
                        summary_log["fi_num"] = f"{env_stack_list[0]}-{env_stack_list[1]}"
                    else:
                        summary_log["fi_num"] = f"{env_stack_list[0]}-00"
        return summary_log

    @property
    def override_summary_log(self):
        return {}

    @property
    def core(self) -> BaseCore:
        """
        Instantiates a communication object between this extension and the core
        listed in configuration.settings.CORE

        Must pip install q2_cores or populate settings.CUSTOM_CORES for this to be accessible

        :return: Subclass of q2_sdk.models.cores.BaseCore
        """
        core_options = utils.get_core_config_options()
        if self.DYNAMIC_CORE_SELECTION is True:
            core_name = self.db_config.get("q2_core", {}).get("name")
            if not core_name or core_name == "None":
                raise exceptions.CoreNotConfiguredError(
                    "Core must be configured in the database if DYNAMIC_CORE_SELECTION is set to True"
                )
            core_mapper_str = core_name
            if core_name not in configuration.settings.CUSTOM_CORES:
                core_options.setdefault(core_name, {})
                core_options[core_name]["path"] = f"q2_cores.{core_name}.core"
        elif configuration.settings.CORE is None:
            self.logger.error(
                "CORE must be set in the settings file if you are planning to use it"
            )
            raise exceptions.CoreNotConfiguredError(
                "configuration.settings.CORE not set to a subclass of q2_sdk.models.cores.BaseCore"
            )
        else:
            core_mapper_str = configuration.settings.CORE
            if core_mapper_str not in core_options:
                core_options.setdefault(core_mapper_str, {})
                core_options[core_mapper_str]["path"] = core_mapper_str

        if self._core is None:
            # For backwards compatibility
            core_mapper_str = core_mapper_str.split("q2_cores.")[-1]
            self.logger.debug("Setting core to %s", core_mapper_str)
            if configuration.settings.MOCK_BRIDGE_CALLS:
                self.logger.info("Core: Mocking bridge calls")

            core_config_info = core_options.get(core_mapper_str)
            if not core_config_info:
                self.logger.error("Unable to import %s" % core_mapper_str)
                raise ImportError
            core_mapper_path = core_options[core_mapper_str]["path"]
            try:
                core_module = import_module(core_mapper_path)
            except ImportError:
                self.logger.error("Unable to import %s" % core_mapper_path)
                raise

            if self.online_user is None:
                self.logger.warning(
                    "Core may not function without an online_user. Did you accidentally call this from __init__?"
                )
            context_param = CoreUser(
                self.online_user or OnlineUser(), self.account_list or AccountList()
            )

            if self.DYNAMIC_CORE_SELECTION:
                self._core = core_module.Core(
                    self.logger,
                    context_param,
                    hq_credentials=self.hq_credentials,
                    db_config_dict=self.db_config["q2_core"]["configs"],
                )
            else:
                self._core = core_module.Core(
                    self.logger, context_param, hq_credentials=self.hq_credentials
                )
        return self._core

    @property
    def db(self) -> "DbObjectFactory":
        if self._db is None:
            from q2_sdk.hq.db.db_object_factory import DbObjectFactory

            self._db = DbObjectFactory(self.logger, self.hq_credentials)
        return self._db

    @property
    def templater(self) -> Q2RequestHandlerTemplater:
        if not self._templater:
            self._templater = Q2RequestHandlerTemplater(self.logger, self.__class__)
        return self._templater

    @property
    def cache(self) -> Q2CacheClient:
        """
        Instantiates a communication object between this extension and the pymemcache library.
        Cache can be configured via configuration.settings.CACHE.

        This is a factory that instantiates the cache in self.default_cache_level.
        There is also self.stack_cache, self.service_cache, or self.session_cache
        for direct access based on your context.
        """
        if not self._cache:
            match self.default_cache_level:
                case StorageLevel.Service:
                    self._cache = self.service_cache
                case StorageLevel.Stack:
                    self._cache = self.stack_cache
                case StorageLevel.Session:
                    self._cache = self.session_cache
        return self._cache
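
    # Illustrative sketch (not part of the SDK): how a subclass might steer the cache
    # factory above. Setting default_cache_level changes what self.cache returns, while
    # the scoped properties can always be used directly, e.g.
    #
    #     self.default_cache_level = StorageLevel.Session
    #     self.cache.set("greeting", "hello", expire=60)            # session-scoped
    #     self.service_cache.get("greeting", default=None)          # service-scoped, separate entry
    #
    # The key name and expire value are placeholders.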

    @cached_property
    def service_cache(self) -> Q2CacheClient:
        """Q2CacheClient scoped to the loadbalanced running service."""
        return self.get_cache()

    @cached_property
    def stack_cache(self) -> Q2CacheClient:
        """Q2CacheClient scoped to the current financial institution stack."""
        raise NotImplementedError

    @cached_property
    def session_cache(self) -> Q2CacheClient:
        """Same as self.cache but limited to online user session"""
        raise NotImplementedError

    def get_cache(self, prefix=None, **kwargs) -> Q2CacheClient:
        """
        :param prefix: If defined will be prepended to all keys
        """
        return cache.get_cache(logger=self.logger, prefix=prefix, **kwargs)

    @property
    def vault(self) -> "vault.Q2Vault":
        return vault.get_client(logger=self.logger)

    @property
    def pinion(self) -> "Pinion":
        from q2_sdk.models.pinion import Pinion

        return Pinion(
            logger=self.logger, hq_credentials=self.hq_credentials, cache_obj=self.cache
        )

    @property
    def base_assets_url(self):
        """Url to use for creating a fully qualified URL to your extension assets"""
        return utils.get_base_assets_url(self.extension_name)

    @classproperty
    def extension_name(cls):
        split = cls.__module__.split(".")
        return ".".join(split[:-1])

    @property
    def logger(self):
        if self._logger is None:
            author = "Q2"
            if configuration.settings.IS_CUSTOMER_CREATED:
                author = "Cust"
            extra = {"guid": self.sdk_session_identifier, "author": author}
            if configuration.settings.INCLUDE_QUERY_PARAMS_IN_LOGS:
                name = self.request.uri
            else:
                name = str(urlparse(self.request.uri).path)
            name = name.lstrip("/")
            logger = logging.getLogger("extension.%s" % name)
            logger.setLevel(self._logging_level)
            logger = Q2LoggerAdapter(logger, extra)

            # Add Global filters
            for log_filter in configuration.settings.GLOBAL_LOGGING_FILTERS:
                logger.logger.addFilter(log_filter)

            # Add extension specific filters
            for log_filter in self.LOGGING_FILTERS:
                logger.logger.addFilter(log_filter)

            self._logger = logger
        return self._logger

    def _get_ui_text_prefix(self):
        prefix = f"{self.DB_PLAN.ui_text_prefix.rstrip('/')}/"
        return prefix

    async def _get_ui_text_from_db(self, prefix, language, ui_selection):
        values_dict = {}
        ui_text_module = import_module("q2_sdk.hq.db.ui_text")
        db_values = await ui_text_module.UiText(
            self.logger, hq_credentials=self.hq_credentials
        ).get(prefix=prefix, ui_selection=ui_selection)
        for item in db_values:
            if not item.Language.text == language:
                continue
            if not item.ShortName.text.startswith(prefix):
                continue
            key = item.ShortName.text
            key_minus_prefix = key[len(prefix) :]
            if ui_selection is None and item.findtext("UiSelectionID") is not None:
                # Ignore rows with populated UiSelectionIDs if we asked for None
                continue
            if values_dict.get(key_minus_prefix) and ui_selection:
                # Ensure an empty UiSelection does not override a more specific one
                if item.findtext("UiSelectionID") is None:
                    continue
            values_dict[key_minus_prefix] = item.findtext("TextValue", "")
        return values_dict

    def _get_ui_text_cache_key(
        self, prefix, language=None, ui_selection=None, cache_time=60
    ):
        language = self._get_ui_text_language(language)
        hash_txt = f"{self.hq_credentials.hq_url}.{prefix}.{language}.{ui_selection}.{cache_time}.uitext".encode()
        return hashlib.sha1(hash_txt).hexdigest()

    def _get_ui_text_language(self, language=None):
        if language is None:
            online_user = vars(self).get("online_user")
            if online_user is not None:
                language = online_user.language
            else:
                language = "USEnglish"
        return language

    def _get_ui_selection(self, ui_selection=None):
        if ui_selection is None:
            form_info = vars(self).get("form_info")
            if form_info is not None:
                ui_selection = form_info.ui_selection_short_name
        return ui_selection

    def _get_appd_bt_name(self):
        return f"{self.request.uri.lstrip('/')} - {self.request.method}"

    async def get_ui_text(
        self,
        language: Optional[str] = None,
        ui_selection: Optional[str] = None,
        cache_time: int = 60,
    ) -> dict:
        """
        Returns any UI Text elements from the database that were inserted as part of self.DB_PLAN
        """
        if not self._ui_text:
            if language is None:
                language = self._get_ui_text_language()
            if ui_selection is None:
                ui_selection = self._get_ui_selection()
            prefix = self._get_ui_text_prefix()
            cache_key = self._get_ui_text_cache_key(
                prefix, language, ui_selection=ui_selection, cache_time=cache_time
            )
            cached_ui_text = self.cache.get(cache_key, default={})
            if cached_ui_text:
                self.logger.debug("Getting UI Text from Cache")
                self._ui_text = cached_ui_text
            else:
                self.logger.debug("Getting UI Text from Database")
                self._ui_text = await self._get_ui_text_from_db(
                    prefix, language, ui_selection
                )
                try:
                    self.cache.set(cache_key, self._ui_text, expire=cache_time)
                except gaierror:
                    self.logger.error(
                        "Error connecting to the memcache server: %s",
                        traceback.format_exc(),
                    )
        return self._ui_text
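
    # Illustrative sketch (not part of the SDK): UI Text rows inserted via self.DB_PLAN
    # come back keyed without their prefix, so a handler might do something like
    #
    #     ui_text = await self.get_ui_text()
    #     title = ui_text.get("page_title", "Default Title")
    #
    # "page_title" and the fallback string are placeholders for whatever ShortNames the
    # extension's DB plan actually defines.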

    async def prepare(self):
        """
        Fires before any request handling code

        By default, will check any :ref:`rate_limiter` instances in self.rate_limiters
        and disallow further processing if any are over their specified limit
        """
        # This MUST occur during `prepare`, not `__init__`
        # (it is part of a separate task that will cause the context not to get GC)
        if not configuration.settings.TEST_MODE:
            contexts.set_current_request(self)

        increment_current_requests(self.extension_name, self.request.method)
        configured_allowed_list = (
            configuration.settings.ALLOW_TRAFFIC_TO_DEFAULT_EXTENSIONS
        )
        allowed_default_extensions = [
            x for x in configured_allowed_list if x.startswith("q2_sdk.extensions")
        ]
        if (
            self.allow_non_q2_traffic is False
            and self.extension_name not in allowed_default_extensions
        ):
            ip_list = self.request.headers.get(
                "X-Forwarded-For", str(self.request.remote_ip)
            )
            ip_list = ip_list.split(",")
            for remote_ip in ip_list:
                is_allowed = is_q2_traffic(
                    remote_ip, whitelist=self.INBOUND_IP_WHITELIST
                )
                if not is_allowed:
                    self.logger.warning("Blocked non-Q2 traffic %s" % remote_ip)
                    self.set_status(403)
                    self.finish()
                    raise Finish()

        if not self.extension_name.startswith("q2_sdk.extensions"):
            if (
                configuration.settings.FORK_REQUESTS
                and self.this_pid == configuration.settings.SERVER_PID
            ):
                self.write(RATE_LIMIT_ERROR_MSG)
                self.set_status(429, RATE_LIMIT_ERROR_MSG)
                self.finish()
                return

        for rate_limiter in self.rate_limiters:
            is_allowed = rate_limiter.is_allowed()
            if not is_allowed:
                self.write(RATE_LIMIT_ERROR_MSG)
                self.set_status(429, RATE_LIMIT_ERROR_MSG)
                self.finish()
                break
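
    # Illustrative sketch (not part of the SDK): rate limiting hooks into prepare()
    # through self.rate_limiters. A subclass would typically build its RateLimiter
    # instances in __init__ and append them, e.g.
    #
    #     self.rate_limiters.append(my_rate_limiter)  # a configured RateLimiter
    #
    # The RateLimiter constructor arguments are not shown in this module; see
    # q2_sdk.core.rate_limiter for the actual signature. Any limiter whose is_allowed()
    # returns False causes prepare() to respond with 429 "Too many requests".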

    def _request_summary(self):
        return f"{self.logger.extra['guid']} {super()._request_summary()}"

    async def _execute(self, transforms, *args: bytes, **kwargs: bytes) -> None:
        """
        This is the last function before the request ends.
        Also a convenient place to kill the process if running forked.
        """
        if (
            configuration.settings.FORK_REQUESTS
            and self.this_pid != configuration.settings.SERVER_PID
        ):
            retry_count = 0
            while not os.path.exists(self.pid_path):
                await asyncio.sleep(0.1)
                retry_count += 1
                if retry_count > 10:
                    self.logger.error(
                        f"Initial PID file never written. Cancelling request for {self.extension_name}"
                    )
                    exit(-1)
            with open(self.pid_path, "w") as handle:
                handle.write("endpoint: " + self.extension_name + "\n")
                handle.write("method: " + self.request.method + "\n")
                handle.write("path: " + self.request.path + "\n")
                handle.write("query: " + self.request.query + "\n")

        with SpanContext(self) as span:
            if span:
                span.set_attribute(
                    Q2SpanAttributes.POINT_OF_INTEREST, self.sdk_session_identifier
                )
            await super()._execute(transforms, *args, **kwargs)

        if configuration.settings.FORK_REQUESTS:
            while not self.request.server_connection.stream.closed():
                await asyncio.sleep(0.1)
            if self.this_pid != configuration.settings.SERVER_PID:
                self.logger.debug(f"Exiting process {self.this_pid}")
                os.remove(self.pid_path)
                exit(0)

    def on_finish(self):
        """Fires as the request is ending"""
        try:
            if self.enable_summary_log:
                if not isinstance(self.override_summary_log, dict) or not isinstance(
                    self.default_summary_log, dict
                ):
                    raise ValueError(
                        "self.default_summary_log and self.override_summary_log must be dictionaries"
                    )
                default_log = self._cleanse_summary_dict(self.default_summary_log)
                override_log = self._cleanse_summary_dict(self.override_summary_log)
                summary_line = default_log | override_log
                summary_log_str = ""
                for key, value in summary_line.items():
                    summary_log_str += " {}='{}'".format(str(key), str(value))
                if summary_log_str:
                    self.logger.summary(
                        f"summary line:{summary_log_str}",
                        add_to_buffer=False,
                        _danger_unencrypted_search_fields={
                            field: summary_line[key]
                            for key, field in {
                                "HQ_ID": "sessionId",
                                "fi_num": "fi_num",
                                "customer_key": "customer_key",
                                # Caliper API key, added for compat. In the future, this
                                # should be a config at the project level, not hardcoded
                                # into the SDK
                                "api_key": "api_key",
                                "trace_id": "trace_id",
                            }.items()
                            if key in summary_line
                        },
                    )
        except Exception:
            summary_line = {}

        if self.get_status() == 500:
            self.logger.replay_buffer()

        self._set_runtime_metric()
        increment_current_requests(
            self.extension_name, self.request.method, reverse=True
        )

    def _cleanse_summary_dict(self, summary_dict: dict) -> dict:
        updated_summary_dict = {}
        # check for cases when key value passed in with double quotes
        for key, value in summary_dict.items():
            updated_key = key
            # spaces should be updated to underscores
            if " " in updated_key:
                updated_key = "_".join(updated_key.split(" "))
            # ' cannot exist in key
            if "'" in updated_key:
                updated_key = "".join(updated_key.split("'"))
                self.logger.warning(
                    f"' not accepted in summary dictionary key. Updating '{key}' to '{updated_key}'"
                )
            updated_summary_dict[updated_key] = value
        return updated_summary_dict

    def log_exception(self, typ, value: Optional[BaseException], tb) -> None:
        if configuration.settings.ENABLE_OPEN_TELEMETRY:
            SpanContext.log_traceback(tb)
        self.logger.error(
            "Uncaught exception:\n\n%s\n\n",
            self.request,
            exc_info=(typ, value, tb),
            add_to_buffer=False,
        )

    def get_template(self, template_name: str, replacements_dict: dict) -> str:
        """
        Used for loading static blocks of text with substitutions. Great for storing
        large HTML blocks in separate files.

        Uses jinja behind the scenes (http://jinja.pocoo.org/)

        Will search in all templates folders starting with the calling extension,
        then moving on to the parent class, and so on, preferring the lowest.

        :param template_name: file_name including suffix. ex. initial.html
        :param replacements_dict: Key value pairs to replace within template_name
        :return: Body of template_name file replaced with values in replacements_dict
        """
        replacements_dict = self._replacement_dict_update(replacements_dict)
        return self.templater.get_template(template_name, replacements_dict)
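
    # Illustrative sketch (not part of the SDK): a typical call to get_template from a
    # handler, assuming a templates/initial.html file exists in the extension. The
    # template name and replacement keys are placeholders.
    #
    #     html = self.get_template("initial.html", {"user_name": "Jane"})
    #
    # The templater also receives a "this" entry (see _replacement_dict_update below),
    # so templates can reference handler attributes such as {{ this.base_assets_url }}.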

    def _replacement_dict_update(self, replacements_dict: dict):
        if replacements_dict is None:
            replacements_dict = {}
        this = replacements_dict.get("this")
        if this is not None and not isinstance(this, dict):
            # If 'this' is passed as a non-dict, just pass it along without
            # populating it with vars(self)
            pass
        else:
            if this is not None:
                this = {**vars(self), **this}
            else:
                this = vars(self)
            this["base_assets_url"] = self.base_assets_url
        replacements_dict.update({"this": this})
        return replacements_dict

    def _set_runtime_metric(self):
        get_metric(
            MetricType.Histogram,
            "caliper_endpoints",
            "Endpoint runtime",
            {
                "endpoint": self.extension_name,
                "status_code": self._status_code,
                "method": self.request.method,
            },
            chain={"op": "observe", "params": [self.request.request_time()]},
        )


def get_trace_id(generate=False):
    trace_id = ""
    if settings.ENABLE_OPEN_TELEMETRY:
        trace_id = hex(
            trace.get_current_span().get_span_context().trace_id
        ).removeprefix("0x")
    elif generate:
        trace_id = str(uuid4())
    return trace_id