from argparse import _SubParsersAction
import logging
from typing import Optional, Union, Type
from q2_sdk.core.default_settings import RUNNING_ENTRYPOINT
from q2_sdk.core.exceptions import DatabaseDataError, HqResponseError
from q2_sdk.core.configuration import settings
from q2_sdk.core.opentelemetry.span import Q2Span, Q2SpanAttributes
from q2_sdk.hq.db.representation_row_base import RepresentationRowBase
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.hq.models.hq_response import HqResponse
from q2_sdk.core.dynamic_imports import (
api_ExecuteStoredProcedure,
ob_ExecuteStoredProcedure,
)
from q2_sdk.tools.sentinel import Sentinel
from q2_sdk.tools.utils import serialize_for_cli
# Root logger; DbObject falls back to it when constructed with a falsy ``logger``.
DEFAULT_LOGGER = logging.getLogger()
# Sentinel labelled "DEFAULT_DB_PARAM". It is not referenced in this part of the file;
# presumably it lets callers tell "argument omitted" apart from an explicit None
# (TODO confirm against the subclasses that import it).
DEFAULT = Sentinel("DEFAULT_DB_PARAM")
class DbObject:
    """
    Base class wrapping HQ stored-procedure calls for a Q2 database object.

    ``call_hq`` executes a stored procedure and parses the response, using
    ``REPRESENTATION_ROW_CLASS`` unless an override is supplied. ``get_by_name``
    only works when ``NAME`` and a lookup key are non-empty and a getter is available.
    """

    # Field name whose text is compared against the requested name in get_by_name
    GET_BY_NAME_KEY = ""
    # Object name interpolated into get_by_name error messages
    NAME = ""
    # Row class handed to HqResponse.parse_sproc_return when table objects are requested
    REPRESENTATION_ROW_CLASS: Optional[Type[RepresentationRowBase]] = None

    def __init__(
        self,
        logger,
        hq_credentials: Optional[HqCredentials] = None,
        ret_table_obj: Optional[bool] = None,
    ):
        """
        Programmatic access to the Q2 database. Not as flexible as a true ORM, but takes the
        guesswork out of database schemas and ensures safety in the transactions.

        :param logger: Reference to calling request's logger (self.logger in your extension)
        :param hq_credentials: HQ Connectivity Information (Defaults to settings file)
        :param ret_table_obj: Flag to return list of LXML elements if ``False`` or
            TableRow objects from DB calls if ``True`` (Defaults to settings file)
        """
        # Only None means "not specified"; an explicit False must be honoured.
        if ret_table_obj is None:
            ret_table_obj = settings.RETURN_TABLE_OBJECTS_FROM_DB

        self.logger = logger or DEFAULT_LOGGER
        # Populated by call_hq with the most recent response.
        self.hq_response: Optional[HqResponse] = None
        self.ret_table_obj = ret_table_obj
        # Falsy credentials are normalised to None so the property can fall back lazily.
        self._hq_credentials: Optional[HqCredentials] = hq_credentials or None

    def add_arguments(self, parser: _SubParsersAction):
        """
        Hook for subclassed DbObjects to add custom arguments.
        """

    @property
    def hq_credentials(self) -> HqCredentials:
        """
        Credentials used for HQ calls.

        When none were supplied at construction, ``settings.HQ_CREDENTIALS`` is read on
        first access and kept on the instance for subsequent accesses.
        """
        if not self._hq_credentials:
            self._hq_credentials = settings.HQ_CREDENTIALS
        return self._hq_credentials

    @Q2Span.instrument(skip=["stored_proc_short_name"])
    async def call_hq(
        self,
        stored_proc_short_name: str,
        sql_parameters: Optional[
            Union[
                api_ExecuteStoredProcedure.SqlParameters,
                ob_ExecuteStoredProcedure.SqlParameters,
            ]
        ] = None,
        specific_table: str = "Table",
        representation_class_override=None,
        use_json=True,
        **kwargs,
    ):
        """
        Execute a stored procedure through HQ and return the parsed rows.

        The response is also stored on ``self.hq_response``.

        :param stored_proc_short_name: Name of the stored procedure to execute
        :param sql_parameters: Optional parameters passed to the stored procedure
        :param specific_table: Forwarded to ``HqResponse.parse_sproc_return``
        :param representation_class_override: Row class to use instead of
            ``REPRESENTATION_ROW_CLASS``
        :param use_json: Forwarded unchanged to the execute call
        :param kwargs: Extra keyword arguments forwarded to the execute call
        :return: Whatever ``HqResponse.parse_sproc_return`` yields; the row class is only
            passed along when ``self.ret_table_obj`` is truthy
        :raises HqResponseError: If the response's ``success`` attribute is ``False``
        """
        Q2Span.set_attribute(
            Q2SpanAttributes.STORED_PROCEDURE_NAME, stored_proc_short_name
        )
        representation_row_class = (
            representation_class_override or self.REPRESENTATION_ROW_CLASS
        )

        self.hq_response = await self._call_execute_stored_procedure(
            stored_proc_short_name, sql_parameters, use_json=use_json, **kwargs
        )
        # Identity check on purpose: only an explicit False counts as failure.
        if self.hq_response.success is False:
            raise HqResponseError(
                f'HQ Request returned with error message: "{self.hq_response.error_message}"'
            )

        if self.ret_table_obj:
            return self.hq_response.parse_sproc_return(
                representation_row_class=representation_row_class,
                specific_table=specific_table,
            )
        return self.hq_response.parse_sproc_return(specific_table=specific_table)

    async def _call_execute_stored_procedure(
        self,
        stored_proc_short_name: str,
        sql_parameters: Optional[
            Union[
                api_ExecuteStoredProcedure.SqlParameters,
                ob_ExecuteStoredProcedure.SqlParameters,
            ]
        ] = None,
        use_json=True,
        **kwargs,
    ) -> HqResponse:
        """
        Run the ExecuteStoredProcedure HQ endpoint and return its response.

        ``ob_ExecuteStoredProcedure`` is used when the credentials carry an
        ``auth_token``; otherwise ``api_ExecuteStoredProcedure`` is used.

        :param stored_proc_short_name: Name of the stored procedure to execute
        :param sql_parameters: Optional parameters passed to the stored procedure
        :param use_json: Forwarded unchanged to ``execute``
        :param kwargs: Extra keyword arguments forwarded to ``execute``
        :return: The response returned by the selected module's ``execute``
        """
        if self.hq_credentials.auth_token:
            execute_module = ob_ExecuteStoredProcedure
        else:
            execute_module = api_ExecuteStoredProcedure

        # The informational log line is emitted only under the "run" entrypoint.
        if RUNNING_ENTRYPOINT == "run":
            self.logger.info(f"Executing Stored Procedure: {stored_proc_short_name}")

        params = execute_module.ParamsObj(
            logger=self.logger,
            stored_proc_short_name=stored_proc_short_name,
            sql_parameters=sql_parameters,
            hq_credentials=self.hq_credentials,
        )
        return await execute_module.execute(params, use_json=use_json, **kwargs)

    @staticmethod
    def serialize_for_cli(
        rows: list,
        fields_to_display: Optional[list[str]] = None,
        fields_to_truncate: Optional[list[str]] = None,
    ):
        """
        Tab delimits response for printing to a terminal

        :param rows: XML elements from HQ response
        :param fields_to_display: Optional. Displays all fields otherwise
        :param fields_to_truncate: Optional. Limits display of these fields to 15 characters
        :return: Tab delimited database rows
        """
        # Inside the function body this name resolves to the module-level helper,
        # not to this method, so there is no recursion.
        return serialize_for_cli(rows, fields_to_display, fields_to_truncate)

    async def get_by_name(self, name, get_by_name_key=None, get_func=None, **kwargs):
        """
        Return the single row whose ``get_by_name_key`` field equals ``name``.

        Rows are fetched with ``get_func`` and filtered using each row's ``findtext``.

        :param name: Value the lookup field must equal
        :param get_by_name_key: Field to compare; defaults to ``GET_BY_NAME_KEY``
        :param get_func: Awaitable callable producing the rows to search; defaults to
            ``self.get``
        :param kwargs: Forwarded to ``get_func``
        :return: The one matching row
        :raises NotImplementedError: If no lookup key, ``NAME`` or getter is available
        :raises DatabaseDataError: If no row, or more than one row, matches
        """
        if not get_by_name_key:
            get_by_name_key = self.GET_BY_NAME_KEY
        # Validate before touching self.get so an unconfigured class reports
        # NotImplementedError instead of AttributeError.
        if not (get_by_name_key and self.NAME):
            raise NotImplementedError

        fetch_rows = get_func or getattr(self, "get", None)
        if fetch_rows is None:
            raise NotImplementedError(
                f"{type(self).__name__} does not define get(); pass get_func explicitly"
            )

        full_response = await fetch_rows(**kwargs)
        matches = [
            row for row in full_response if row.findtext(get_by_name_key) == name
        ]
        if not matches:
            raise DatabaseDataError(f'No {self.NAME} with name "{name}"')
        if len(matches) > 1:
            raise DatabaseDataError(f'More than one {self.NAME} with name "{name}"')
        return matches[0]