import asyncio
import hashlib
import ipaddress
import logging
import os
import traceback
import uuid
from dataclasses import dataclass
from functools import cached_property, partial
from importlib import import_module
from socket import gaierror
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
from uuid import uuid4
from urllib.parse import urlparse
from opentelemetry import trace
from tornado.web import Finish, RequestHandler
from q2_sdk.core import cache, configuration, contexts, exceptions, vault
from q2_sdk.core.cache import Q2CacheClient, StorageLevel
from q2_sdk.core.configuration import settings
from q2_sdk.core.install_steps.db_plan import DbPlan
from q2_sdk.core.opentelemetry.filters import BaseEventFilter
from q2_sdk.core.opentelemetry.handler import SpanContext
from q2_sdk.core.opentelemetry.span import Q2SpanAttributes
from q2_sdk.core.prometheus import MetricType, get_metric
from q2_sdk.core.q2_logging.logger import Q2LoggerAdapter
from q2_sdk.core.rate_limiter import RateLimiter
from q2_sdk.core.request_handler.templating import Q2RequestHandlerTemplater
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.models.cores.base_core import BaseCore
from q2_sdk.models.cores.models.core_user import CoreUser
from q2_sdk.hq.models.account_list import AccountList
from q2_sdk.hq.models.online_user import OnlineUser
from q2_sdk.tools import utils
if TYPE_CHECKING:
from opentelemetry.sdk.trace import Span
from q2_sdk.models.pinion import Pinion
from q2_sdk.hq.db.db_object_factory import DbObjectFactory
# Response body / HTTP 429 reason phrase used when a request is rejected by rate limiting.
RATE_LIMIT_ERROR_MSG = "Too many requests"
def is_q2_traffic(remote_ip: str, whitelist: Optional[List[str]] = None) -> bool:
    """
    Returns True if IP belongs to a known Q2 source

    Private addresses are always considered Q2 traffic; otherwise the address
    must fall inside one of the configured networks.

    :param remote_ip: IpAddr as str (IPv4 or IPv6)
    :param whitelist: If exists, will merge with addresses in global INBOUND_IP_WHITELIST
    :raises ValueError: if remote_ip or any configured network string is invalid
    """
    inbound_ip_whitelist = configuration.settings.INBOUND_IP_WHITELIST
    # ip_network (instead of IPv4Network) accepts both IPv4 and IPv6 entries,
    # so IPv6 CIDRs in either whitelist no longer raise AddressValueError.
    _q2_networks = [
        ipaddress.ip_network(entry)
        for entry in [*inbound_ip_whitelist, *(whitelist or [])]
    ]
    ip_addr = ipaddress.ip_address(remote_ip.strip())
    if ip_addr.is_private:
        return True
    # Membership only matches when address and network share the same IP version.
    return any(ip_addr in network for network in _q2_networks)
# Prometheus Metrics
def increment_current_requests(endpoint, method, reverse=False):
    """
    Adjust the 'caliper_current_requests' gauge for one endpoint/method pair.

    if reverse, will decrement
    """
    operation = "dec" if reverse else "inc"
    labels = {"endpoint": endpoint, "method": method}
    get_metric(
        MetricType.Gauge,
        "caliper_current_requests",
        "Requests in progress",
        labels,
        chain={"op": operation},
    )
@dataclass
class ConfigurationCheck:
    """Result of a single check performed by ``validate_configuration``."""

    # Whether the configuration check passed.
    is_valid: bool
    # Human-readable description of the problem when is_valid is False.
    error: Optional[str] = ""
    # Awaitable that presumably attempts to correct the problem when run -- TODO confirm caller behavior.
    fix: Optional[Awaitable] = None
class classproperty:
    """
    Descriptor that mirrors @property, except the getter also works when
    accessed on the class itself (no instance required).
    """

    def __init__(self, func):
        # Wrapped getter; always invoked with the owner class as its argument.
        self.func = func

    def __get__(self, instance, owner):
        # The instance (if any) is deliberately ignored.
        return self.func(owner)
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
class Q2BaseRequestHandler(RequestHandler):
    """
    Inherits from Tornado's RequestHandler, but adds a few Q2 specific niceties.

    - Handles both gets and posts
    - If REQUIRED_CONFIGURATIONS is set, ensures the entries are set in the
      extension's settings file or the webserver will not start.
      REQUIRED_CONFIGURATIONS is a dictionary of key value pairs where the
      key is the required name and the value is the default value when
      :code:`q2 generate_config` is called
    - If DYNAMIC_CORE_SELECTION is set to True, will prompt for a Core name on installation
    - OPTIONAL_CONFIGURATIONS work the same way as REQUIRED_CONFIGURATIONS,
      but will not stop the server from running if omitted
    - FRONTEND_OVERRIDE_EXTENSION allows using another extension as the frontend
      for this one
    """

    # Name of this extension's settings file; "" means no extension-specific config.
    CONFIG_FILE_NAME = ""
    # Database install plan (supplies e.g. the UI Text prefix used below).
    DB_PLAN = DbPlan()
    # {name: default} pairs required at startup (see class docstring).
    REQUIRED_CONFIGURATIONS = {}
    # Same shape as REQUIRED_CONFIGURATIONS, but non-fatal when omitted.
    OPTIONAL_CONFIGURATIONS = {}
    # logging.Filter instances attached to this extension's logger.
    LOGGING_FILTERS: List[logging.Filter] = []
    # Filters applied to OpenTelemetry events -- TODO confirm where consumed.
    EVENT_FILTERS: List[BaseEventFilter] = []
    # When True, the core name is read from the database configuration (see `core`).
    DYNAMIC_CORE_SELECTION = False
    # Extra networks merged with the global whitelist in the `prepare` IP check.
    INBOUND_IP_WHITELIST = []
    # Name of another extension to use as this one's frontend (see class docstring).
    FRONTEND_OVERRIDE_EXTENSION = None
    # Extension-defined audit actions -- presumably registered elsewhere; TODO confirm.
    CUSTOM_AUDIT_ACTIONS = []
def __init__(self, application, request, **kwargs):
    """
    Wire up per-request state: configuration, HQ credentials, logging,
    caching hooks, rate limiters, and optional fork/OpenTelemetry bookkeeping.

    :param application: Tornado Application (passed through to RequestHandler)
    :param request: Tornado HTTPServerRequest (passed through to RequestHandler)
    :param kwargs: Must include 'logging_level'; remainder forwarded to RequestHandler
    """
    # Pop before super().__init__ so Tornado never sees this custom kwarg.
    self._logging_level = kwargs.pop("logging_level")
    super().__init__(application, request, **kwargs)
    self.config = configuration.get_configuration(
        self.CONFIG_FILE_NAME, self.REQUIRED_CONFIGURATIONS
    )
    # Credentials for HQ calls; the "" positional presumably leaves env_stack
    # blank here (populated later) -- TODO confirm against HqCredentials signature.
    self.hq_credentials = HqCredentials(
        configuration.settings.HQ_CREDENTIALS.hq_url,
        configuration.settings.HQ_CREDENTIALS.csr_user,
        configuration.settings.HQ_CREDENTIALS.csr_pwd,
        configuration.settings.HQ_CREDENTIALS.aba,
        "",
        customer_key=configuration.settings.HQ_CREDENTIALS.customer_key,
    )
    # Built lazily by the `logger` property.
    self._logger = None
    # Unique id for this request; surfaces in log lines as 'guid'.
    self.sdk_session_identifier: str = uuid.uuid4().hex
    # Built lazily by the `cache` property.
    self._cache: Optional[Q2CacheClient] = None
    # NOTE(review): looks like an AppDynamics business-transaction handle -- confirm.
    self.bt_handle = None
    # Built lazily by the `templater` property.
    self._templater: Optional[Q2RequestHandlerTemplater] = None
    # Checked during `prepare`; subclasses may append RateLimiter instances.
    self.rate_limiters: List[RateLimiter] = []
    # Memoized result of `get_ui_text`.
    self._ui_text = None
    # Built lazily by the `core` property.
    self._core: Optional[BaseCore] = None
    if self.config and self.config.extra_configurations:
        self.config.log_extra_configurations(self.logger)
    # When False, `prepare` rejects traffic from outside known Q2 networks.
    self.allow_non_q2_traffic = False
    if configuration.settings.FORK_REQUESTS:
        # Track which process handles this request; children get a PID file path.
        self.this_pid = os.getpid()
        if self.this_pid != configuration.settings.SERVER_PID:
            self.logger.debug(f"PID: {self.this_pid}")
            self.pid_path = (
                f"{configuration.settings.FORKED_CHILD_PID_DIR}/{self.this_pid}"
            )
    if configuration.settings.ENABLE_OPEN_TELEMETRY:
        # Populated by the OpenTelemetry machinery during the request lifecycle.
        self.otel_handler_context_token = None
        self.otel_span: Optional["Span"] = None
        self.otel_inbound_context_token = None
    # Built lazily by the `db` property.
    self._db = None
    # When False, `on_finish` skips the summary log line.
    self.enable_summary_log = True
    # Which scope the `cache` factory property resolves to.
    self.default_cache_level = StorageLevel.Service
    # Populated elsewhere in the request lifecycle; consumed by `core`.
    self.online_user: Optional[OnlineUser] = None
    self.account_list: Optional[AccountList] = None
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
def data_received(self, _: bytes):
    """Tornado streaming-upload hook; not supported here, so a warning is logged and the chunk is dropped."""
    self.logger.warning("Method data_received is not implemented")
@classmethod
async def validate_configuration(cls) -> List[ConfigurationCheck]:
    """Hook for subclasses to verify extension configuration; the default reports a single passing check."""
    return [ConfigurationCheck(True)]
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
async def run_async(self, funcname: Callable, *args, **kwargs):
    """Easy way to run a non-async function in a background thread

    Usage:

    def non_async_func(foo, spam="eggs"):
        time_consuming_thing()

    result = await self.run_async(non_async_func, 1, spam="eggs")

    :param funcname: Function reference (do not instantiate)
    :param args: Positional parameters
    :param kwargs: Keyword parameters
    :return: Whatever ``funcname`` returns (previously the result was discarded;
        existing callers that ignore the return value are unaffected)
    """
    func = partial(funcname, *args, **kwargs)
    # Propagate the callable's return value back to the awaiting caller.
    return await asyncio.get_running_loop().run_in_executor(None, func)
@property
def default_summary_log(self) -> Dict[str, Any]:
    """
    Baseline key/value pairs for the end-of-request summary line:
    always status_code; trace_id when OpenTelemetry is enabled;
    customer_key / fi_num when running multitenant.
    """
    summary_log: Dict[str, Any] = {
        "status_code": self.get_status(),
    }
    if configuration.settings.ENABLE_OPEN_TELEMETRY:
        summary_log["trace_id"] = get_trace_id()
    if configuration.settings.MULTITENANT:
        if self.hq_credentials:
            if self.hq_credentials.customer_key:
                summary_log["customer_key"] = self.hq_credentials.customer_key
            if self.hq_credentials.env_stack:
                # env_stack of the form "<a>-<b>-<c>" yields fi_num "<a>-<b>";
                # any other shape falls back to "<first segment>-00".
                env_stack_list = self.hq_credentials.env_stack.split("-")
                if len(env_stack_list) == 3:
                    summary_log["fi_num"] = (
                        f"{env_stack_list[0]}-{env_stack_list[1]}"
                    )
                else:
                    summary_log["fi_num"] = f"{env_stack_list[0]}-00"
    return summary_log
@property
def override_summary_log(self):
    """Subclass hook: extra key/value pairs merged over default_summary_log in on_finish; empty by default."""
    return dict()
@property
def core(self) -> BaseCore:
    """
    Instantiates a communication object between this extension and the core listed in configuration.settings.CORE

    Must pip install q2_cores or populate settings.CUSTOM_CORES for this to be accessible

    :return: Subclass of q2_sdk.models.cores.BaseCore
    :raises exceptions.CoreNotConfiguredError: when no core name can be resolved
    :raises ImportError: when the resolved core module cannot be imported
    """
    core_options = utils.get_core_config_options()
    if self.DYNAMIC_CORE_SELECTION is True:
        # Core name comes from this install's database configuration.
        core_name = self.db_config.get("q2_core", {}).get("name")
        if not core_name or core_name == "None":
            raise exceptions.CoreNotConfiguredError(
                "Core must be configured in the database if DYNAMIC_CORE_SELECTION is set to True"
            )
        core_mapper_str = core_name
        if core_name not in configuration.settings.CUSTOM_CORES:
            # Non-custom cores resolve to the installed q2_cores package.
            core_options.setdefault(core_name, {})
            core_options[core_name]["path"] = f"q2_cores.{core_name}.core"
    elif configuration.settings.CORE is None:
        self.logger.error(
            "CORE must be set in the settings file if you are planning to use it"
        )
        raise exceptions.CoreNotConfiguredError(
            "configuration.settings.CORE not set to a subclass of q2_sdk.models.cores.BaseCore"
        )
    else:
        # Static selection: settings.CORE doubles as the import path.
        core_mapper_str = configuration.settings.CORE
        if core_mapper_str not in core_options:
            core_options.setdefault(core_mapper_str, {})
            core_options[core_mapper_str]["path"] = core_mapper_str
    if self._core is None:
        core_mapper_str = core_mapper_str.split("q2_cores.")[
            -1
        ]  # For backwards compatibility
        self.logger.debug("Setting core to %s", core_mapper_str)
        if configuration.settings.MOCK_BRIDGE_CALLS:
            self.logger.info("Core: Mocking bridge calls")
        core_config_info = core_options.get(core_mapper_str)
        if not core_config_info:
            self.logger.error("Unable to import %s" % core_mapper_str)
            raise ImportError
        core_mapper_path = core_options[core_mapper_str]["path"]
        try:
            core_module = import_module(core_mapper_path)
        except ImportError:
            self.logger.error("Unable to import %s" % core_mapper_path)
            raise
        if self.online_user is None:
            self.logger.warning(
                "Core may not function without an online_user. Did you accidentally call this from __init__?"
            )
        # Fall back to empty user/account objects so Core construction succeeds.
        context_param = CoreUser(
            self.online_user or OnlineUser(), self.account_list or AccountList()
        )
        if self.DYNAMIC_CORE_SELECTION:
            # Dynamic cores additionally receive their DB-stored configuration.
            self._core = core_module.Core(
                self.logger,
                context_param,
                hq_credentials=self.hq_credentials,
                db_config_dict=self.db_config["q2_core"]["configs"],
            )
        else:
            self._core = core_module.Core(
                self.logger, context_param, hq_credentials=self.hq_credentials
            )
    return self._core
@property
def db(self) -> "DbObjectFactory":
    """Lazily-constructed factory for HQ DB accessor objects, bound to this request's logger and HQ credentials."""
    if self._db is None:
        # Local import defers loading until first use -- presumably avoids an import cycle; TODO confirm.
        from q2_sdk.hq.db.db_object_factory import DbObjectFactory
        self._db = DbObjectFactory(self.logger, self.hq_credentials)
    return self._db
@property
def templater(self) -> Q2RequestHandlerTemplater:
    """Lazily-built templater bound to this handler class (drives get_template's search path)."""
    if not self._templater:
        self._templater = Q2RequestHandlerTemplater(self.logger, self.__class__)
    return self._templater
@property
def cache(self) -> Q2CacheClient:
    """
    Instantiates a communication object between this extension and the pymemcache library.

    Cache can be configured via configuration.settings.CACHE.
    Acts as a factory: resolves to the client matching self.default_cache_level.
    Use self.stack_cache, self.service_cache, or self.session_cache directly
    when you need a specific scope.
    """
    if not self._cache:
        level = self.default_cache_level
        if level == StorageLevel.Service:
            self._cache = self.service_cache
        elif level == StorageLevel.Stack:
            self._cache = self.stack_cache
        elif level == StorageLevel.Session:
            self._cache = self.session_cache
    return self._cache
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
@cached_property
def service_cache(self) -> Q2CacheClient:
    """Q2CacheClient scoped to the loadbalanced running service (default client, no key prefix)."""
    return self.get_cache()
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
@cached_property
def stack_cache(self) -> Q2CacheClient:
    """Q2CacheClient scoped to the current financial institution stack.

    Not implemented in this base class; subclasses with stack context must override.
    """
    raise NotImplementedError
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
@cached_property
def session_cache(self) -> Q2CacheClient:
    """Same as self.cache but limited to online user session.

    Not implemented in this base class; subclasses with session context must override.
    """
    raise NotImplementedError
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
def get_cache(self, prefix=None, **kwargs) -> Q2CacheClient:
    """
    Build a new cache client bound to this request's logger.

    :param prefix: If defined will be prepended to all keys
    :param kwargs: Passed through to q2_sdk.core.cache.get_cache
    :return: Q2CacheClient
    """
    return cache.get_cache(logger=self.logger, prefix=prefix, **kwargs)
@property
def vault(self) -> "vault.Q2Vault":
    """Vault client bound to this request's logger (fetched on every access; not cached on the handler)."""
    return vault.get_client(logger=self.logger)
@property
def pinion(self) -> "Pinion":
    """New Pinion client per access, sharing this request's logger, HQ credentials, and cache."""
    # Local import defers loading until first use -- presumably avoids an import cycle; TODO confirm.
    from q2_sdk.models.pinion import Pinion
    return Pinion(
        logger=self.logger, hq_credentials=self.hq_credentials, cache_obj=self.cache
    )
@property
def base_assets_url(self):
    """Url to use for creating a fully qualified URL to your extension assets (derived from extension_name)."""
    return utils.get_base_assets_url(self.extension_name)
@classproperty
def extension_name(cls):
    """Dotted package containing this handler: the defining module's path with its final segment dropped."""
    # rpartition yields "" when the module name has no dot, matching the
    # previous split/join behavior.
    return cls.__module__.rpartition(".")[0]
@property
def logger(self):
    """
    Lazily-built Q2LoggerAdapter named 'extension.<request path>', tagged
    with this request's guid and author, and carrying both global and
    extension-specific logging filters.
    """
    if self._logger is None:
        author = "Q2"
        if configuration.settings.IS_CUSTOMER_CREATED:
            author = "Cust"
        extra = {"guid": self.sdk_session_identifier, "author": author}
        if configuration.settings.INCLUDE_QUERY_PARAMS_IN_LOGS:
            name = self.request.uri
        else:
            # Strip the query string from the logger name -- presumably to keep
            # query params out of logger names unless explicitly enabled.
            name = str(urlparse(self.request.uri).path)
        name = name.lstrip("/")
        logger = logging.getLogger("extension.%s" % name)
        logger.setLevel(self._logging_level)
        # Adapter injects the `extra` fields into every record.
        logger = Q2LoggerAdapter(logger, extra)
        # Add Global filters (attached to the underlying logger, not the adapter)
        for log_filter in configuration.settings.GLOBAL_LOGGING_FILTERS:
            logger.logger.addFilter(log_filter)
        # Add extension specific filters
        for log_filter in self.LOGGING_FILTERS:
            logger.logger.addFilter(log_filter)
        self._logger = logger
    return self._logger
def _get_ui_text_prefix(self):
    """UI Text ShortName prefix from the DB plan, normalized to end with exactly one '/'."""
    return f"{self.DB_PLAN.ui_text_prefix.rstrip('/')}/"
async def _get_ui_text_from_db(self, prefix, language, ui_selection):
    """
    Fetch UI Text rows from HQ and reduce them to {short_name_minus_prefix: text}.

    Rows are kept only when they match ``language`` and their ShortName starts
    with ``prefix``. When ``ui_selection`` is None, rows tied to a specific
    UiSelectionID are skipped; otherwise a selection-specific row takes
    precedence over a generic one for the same key.
    """
    values_dict = {}
    # Imported by name at call time -- presumably to dodge an import cycle; TODO confirm.
    ui_text_module = import_module("q2_sdk.hq.db.ui_text")
    db_values = await ui_text_module.UiText(
        self.logger, hq_credentials=self.hq_credentials
    ).get(prefix=prefix, ui_selection=ui_selection)
    for item in db_values:
        if not item.Language.text == language:
            continue
        if not item.ShortName.text.startswith(prefix):
            continue
        key = item.ShortName.text
        key_minus_prefix = key[len(prefix) :]
        if ui_selection is None and item.findtext("UiSelectionID") is not None:
            # Ignore rows with populated UiSelectionIDs if we asked for None
            continue
        if values_dict.get(key_minus_prefix) and ui_selection:
            # Ensure an empty UiSelection does not override a more specific one
            if item.findtext("UiSelectionID") is None:
                continue
        values_dict[key_minus_prefix] = item.findtext("TextValue", "")
    return values_dict
def _get_ui_text_cache_key(
    self, prefix, language=None, ui_selection=None, cache_time=60
):
    """Deterministic SHA1 cache key covering HQ url, prefix, language, ui_selection, and TTL."""
    language = self._get_ui_text_language(language)
    components = (
        self.hq_credentials.hq_url,
        prefix,
        language,
        ui_selection,
        cache_time,
        "uitext",
    )
    hash_txt = ".".join(str(part) for part in components).encode()
    return hashlib.sha1(hash_txt).hexdigest()
def _get_ui_text_language(self, language=None):
    """Resolve the UI Text language: explicit argument wins, then the online user's language, then 'USEnglish'."""
    if language is not None:
        return language
    # vars() lookup avoids triggering property machinery / AttributeError
    # when the attribute was never set on this instance.
    online_user = vars(self).get("online_user")
    if online_user is None:
        return "USEnglish"
    return online_user.language
def _get_ui_selection(self, ui_selection=None):
    """Resolve the UI selection: explicit argument wins, then the current form's ui_selection_short_name, else None."""
    if ui_selection is not None:
        return ui_selection
    form_info = vars(self).get("form_info")
    return form_info.ui_selection_short_name if form_info is not None else None
def _get_appd_bt_name(self):
    """Business-transaction name built from the request path and HTTP method."""
    path = self.request.uri.lstrip("/")
    return f"{path} - {self.request.method}"
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
async def get_ui_text(
    self,
    language: Optional[str] = None,
    ui_selection: Optional[str] = None,
    cache_time: int = 60,
) -> dict:
    """
    Returns any UI Text elements from the database that were inserted as part of self.DB_PLAN

    :param language: Defaults to the online user's language, else "USEnglish"
    :param ui_selection: Defaults to the current form's ui_selection_short_name when available
    :param cache_time: Seconds to keep the DB result in the cache
    :return: {short_name_without_prefix: text_value}

    Note: the result is memoized on the handler (self._ui_text), so arguments
    passed on later calls within the same request are ignored.
    """
    if not self._ui_text:
        if language is None:
            language = self._get_ui_text_language()
        if ui_selection is None:
            ui_selection = self._get_ui_selection()
        prefix = self._get_ui_text_prefix()
        cache_key = self._get_ui_text_cache_key(
            prefix, language, ui_selection=ui_selection, cache_time=cache_time
        )
        cached_ui_text = self.cache.get(cache_key, default={})
        if cached_ui_text:
            self.logger.debug("Getting UI Text from Cache")
            self._ui_text = cached_ui_text
        else:
            self.logger.debug("Getting UI Text from Database")
            self._ui_text = await self._get_ui_text_from_db(
                prefix, language, ui_selection
            )
            try:
                self.cache.set(cache_key, self._ui_text, expire=cache_time)
            except gaierror:
                # DNS failure reaching memcache: log it and serve the DB result uncached.
                self.logger.error(
                    "Error connecting to the memcache server: %s",
                    traceback.format_exc(),
                )
    return self._ui_text
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
async def prepare(self):
    """
    Fires before any request handling code

    By default, will check any :ref:`rate_limiter` instances in
    self.rate_limiters and disallow further processing if any
    are over their specified limit
    """
    # This MUST occur during `prepare`, not `__init__` (it is part of a separate task that will cause the context not to get GC)
    if not configuration.settings.TEST_MODE:
        contexts.set_current_request(self)
    increment_current_requests(self.extension_name, self.request.method)
    configured_allowed_list = (
        configuration.settings.ALLOW_TRAFFIC_TO_DEFAULT_EXTENSIONS
    )
    # Only SDK-shipped default extensions can be exempted from the IP check.
    allowed_default_extensions = [
        x for x in configured_allowed_list if x.startswith("q2_sdk.extensions")
    ]
    if (
        self.allow_non_q2_traffic is False
        and self.extension_name not in allowed_default_extensions
    ):
        # Inspect every hop in X-Forwarded-For (falling back to the socket
        # peer); a single non-Q2 address anywhere in the chain rejects the request.
        ip_list = self.request.headers.get(
            "X-Forwarded-For", str(self.request.remote_ip)
        )
        ip_list = ip_list.split(",")
        for remote_ip in ip_list:
            is_allowed = is_q2_traffic(
                remote_ip, whitelist=self.INBOUND_IP_WHITELIST
            )
            if not is_allowed:
                self.logger.warning("Blocked non-Q2 traffic %s" % remote_ip)
                self.set_status(403)
                self.finish()
                # Finish aborts the rest of the request pipeline in Tornado.
                raise Finish()
    if not self.extension_name.startswith("q2_sdk.extensions"):
        if (
            configuration.settings.FORK_REQUESTS
            and self.this_pid == configuration.settings.SERVER_PID
        ):
            # Forked mode: the parent server process refuses user-extension
            # traffic -- NOTE(review): it reuses the 429/rate-limit message.
            self.write(RATE_LIMIT_ERROR_MSG)
            self.set_status(429, RATE_LIMIT_ERROR_MSG)
            self.finish()
            return
        for rate_limiter in self.rate_limiters:
            is_allowed = rate_limiter.is_allowed()
            if not is_allowed:
                self.write(RATE_LIMIT_ERROR_MSG)
                self.set_status(429, RATE_LIMIT_ERROR_MSG)
                self.finish()
                break
def _request_summary(self):
    """Prefix Tornado's request summary with this request's session guid."""
    guid = self.logger.extra["guid"]
    return f"{guid} {super()._request_summary()}"
async def _execute(self, transforms, *args: bytes, **kwargs: bytes) -> None:
    """
    This is the last function before the request ends.
    Also a convenient place to kill the process if running forked.
    """
    if (
        configuration.settings.FORK_REQUESTS
        and self.this_pid != configuration.settings.SERVER_PID
    ):
        # Forked child: wait (up to ~1s in 0.1s polls) for the parent to create
        # our PID file before serving; abandon the process if it never appears.
        retry_count = 0
        while not os.path.exists(self.pid_path):
            await asyncio.sleep(0.1)
            retry_count += 1
            if retry_count > 10:
                self.logger.error(
                    f"Initial PID file never written. Cancelling request for {self.extension_name}"
                )
                exit(-1)
        # Record what this child is handling, for observability of forked workers.
        with open(self.pid_path, "w") as handle:
            handle.write("endpoint: " + self.extension_name + "\n")
            handle.write("method: " + self.request.method + "\n")
            handle.write("path: " + self.request.path + "\n")
            handle.write("query: " + self.request.query + "\n")
    # Run the normal Tornado pipeline inside an OpenTelemetry span context
    # (SpanContext yields None when telemetry is disabled).
    with SpanContext(self) as span:
        if span:
            span.set_attribute(
                Q2SpanAttributes.POINT_OF_INTEREST, self.sdk_session_identifier
            )
        await super()._execute(transforms, *args, **kwargs)
    if configuration.settings.FORK_REQUESTS:
        # Wait for the client connection to fully close before tearing down.
        while not self.request.server_connection.stream.closed():
            await asyncio.sleep(0.1)
        if self.this_pid != configuration.settings.SERVER_PID:
            self.logger.debug(f"Exiting process {self.this_pid}")
            os.remove(self.pid_path)
            # NOTE(review): exit() is the interactive site helper; sys.exit or
            # os._exit is usually preferred in server code -- confirm intent.
            exit(0)
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
def on_finish(self):
    """Fires as the request is ending

    Emits the summary log line (default_summary_log merged with
    override_summary_log), replays the log buffer on 500s, records the
    runtime metric, and decrements the in-progress request gauge.
    """
    try:
        if self.enable_summary_log:
            if not isinstance(self.override_summary_log, dict) or not isinstance(
                self.default_summary_log, dict
            ):
                raise ValueError(
                    "self.default_summary_log and self.override_summary_log must be dictionaries"
                )
            default_log = self._cleanse_summary_dict(self.default_summary_log)
            override_log = self._cleanse_summary_dict(self.override_summary_log)
            # Override values win over defaults for duplicate keys.
            summary_line = default_log | override_log
            summary_log_str = ""
            for key, value in summary_line.items():
                summary_log_str += " {}='{}'".format(str(key), str(value))
            if summary_log_str:
                self.logger.summary(
                    f"summary line:{summary_log_str}",
                    add_to_buffer=False,
                    _danger_unencrypted_search_fields={
                        field: summary_line[key]
                        for key, field in {
                            "HQ_ID": "sessionId",
                            "fi_num": "fi_num",
                            "customer_key": "customer_key",
                            "api_key": "api_key",  # Caliper API key, added for compat. In the future, this should be a config at the project level, not hardcoded into the SDK
                            "trace_id": "trace_id",
                        }.items()
                        if key in summary_line
                    },
                )
    except Exception:
        # Summary logging is best-effort and must never break request teardown,
        # but record why it failed instead of swallowing the error silently
        # (previously this branch only assigned an unused empty dict).
        self.logger.error(
            "Failed to emit summary line: %s", traceback.format_exc()
        )
    if self.get_status() == 500:
        self.logger.replay_buffer()
    self._set_runtime_metric()
    increment_current_requests(
        self.extension_name, self.request.method, reverse=True
    )
def _cleanse_summary_dict(self, summary_dict: dict) -> dict:
    """
    Normalize summary-line keys so they render safely: spaces become
    underscores, and single quotes are stripped (with a warning, since
    the key='value' format depends on quotes).
    """
    cleansed = {}
    for key, value in summary_dict.items():
        updated_key = key.replace(" ", "_")
        if "'" in updated_key:
            updated_key = updated_key.replace("'", "")
            self.logger.warning(
                f"' not accepted in summary dictionary key. Updating '{key}' to '{updated_key}'"
            )
        cleansed[updated_key] = value
    return cleansed
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
def log_exception(self, typ, value: Optional[BaseException], tb) -> None:
    """Tornado hook for uncaught exceptions: record the traceback on the OTel span (when enabled) and log with full exc_info."""
    if configuration.settings.ENABLE_OPEN_TELEMETRY:
        SpanContext.log_traceback(tb)
    self.logger.error(
        "Uncaught exception:\n\n%s\n\n",
        self.request,
        exc_info=(typ, value, tb),
        add_to_buffer=False,
    )
# [docs] -- Sphinx "view source" link artifact from documentation scraping; not part of the module
def get_template(self, template_name: str, replacements_dict: dict) -> str:
    """
    Render a static template with substitutions via jinja
    (http://jinja.pocoo.org/) -- handy for keeping large HTML blocks in
    separate files.

    Templates folders are searched starting with the calling extension,
    then each parent class in turn, preferring the lowest match.

    :param template_name: file_name including suffix. ex. initial.html
    :param replacements_dict: Key value pairs to replace within template_name
    :return: Body of template_name file replaced with values in replacements_dict
    """
    prepared = self._replacement_dict_update(replacements_dict)
    return self.templater.get_template(template_name, prepared)
def _replacement_dict_update(self, replacements_dict: dict):
    """
    Ensure replacements_dict carries a 'this' entry for templates.

    - 'this' absent: populated from this handler's attributes (plus base_assets_url).
    - 'this' a dict: handler attributes merged underneath the caller's keys.
    - 'this' any other type: passed through untouched.
    """
    if replacements_dict is None:
        replacements_dict = {}
    this = replacements_dict.get("this")
    if this is not None and not isinstance(this, dict):
        # If 'this' is passed as a non-dict, just pass it along without
        # populating it with vars(self)
        pass
    else:
        if this is not None:
            this = {**vars(self), **this}
        else:
            # Copy: vars(self) returns the live instance __dict__, so writing
            # template-only keys into it would permanently add attributes to
            # the handler (previous behavior leaked 'base_assets_url' onto self).
            this = dict(vars(self))
        this["base_assets_url"] = self.base_assets_url
    replacements_dict.update({"this": this})
    return replacements_dict
def _set_runtime_metric(self):
    """Record this request's wall-clock duration in the endpoint histogram, labeled by extension, status, and method."""
    get_metric(
        MetricType.Histogram,
        "caliper_endpoints",
        "Endpoint runtime",
        {
            "endpoint": self.extension_name,
            "status_code": self._status_code,
            "method": self.request.method,
        },
        # request_time() is the elapsed seconds for this request.
        chain={"op": "observe", "params": [self.request.request_time()]},
    )
def get_trace_id(generate=False):
    """
    Return the current OpenTelemetry trace id as lowercase hex.

    When OpenTelemetry is disabled, returns a random UUID string if
    ``generate`` is True, otherwise "".
    """
    if settings.ENABLE_OPEN_TELEMETRY:
        span_context = trace.get_current_span().get_span_context()
        # format(x, "x") == hex(x) without the "0x" prefix.
        return format(span_context.trace_id, "x")
    if generate:
        return str(uuid4())
    return ""