from dataclasses import dataclass
from datetime import datetime
from lxml.objectify import E, StringElement
from lxml.etree import tostring, fromstring, XMLSyntaxError
from typing import List, Optional
from .db_object import DbObject
from .representation_row_base import RepresentationRowBase
from q2_sdk.core.dynamic_imports import (
api_ExecuteStoredProcedure as ExecuteStoredProcedure,
)
from q2_sdk.core.exceptions import HqResponseError
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.hq.hq_api.q2_api import GetHqVersion
from q2_sdk.hq.models.hq_params.stored_procedure import Param
from q2_sdk.tools.utils import float_to_string
# Short alias for the stored-procedure parameter data types used below
# (e.g. ``D_TYPES.VarChar`` and ``D_TYPES.Xml``).
D_TYPES = ExecuteStoredProcedure.DataType
[docs]
@dataclass
class RateRequest: # pylint: disable=invalid-name
StartDate: str
ExpireDate: str
SourceAuthority: str
SourceCurrency: str
ExchangeRate: str | float
ExchangeCurrency: str
def __post_init__(self):
if isinstance(self.ExchangeRate, float):
self.ExchangeRate = float_to_string(self.ExchangeRate)
@dataclass
class ImportRates:
    """Batch of :class:`RateRequest` items passed to ``CurrencyExchangeRate.create``."""

    # Rates to import; entries whose dates fail to parse are skipped by create.
    rates: List[RateRequest]
@dataclass
class ExchangeRateResponse:
    """Outcome of ``CurrencyExchangeRate.create``."""

    # True once the import stored procedure call itself succeeded; row-level
    # errors reported by the procedure do not reset it.
    success: bool = False
    # Collected error strings. ``None`` until a create call initialises it,
    # which is how ``CurrencyExchangeRate.__init__`` constructs this object.
    error_messages: Optional[List[str]] = None
class CurrencyExchangeRateRow(RepresentationRowBase):
    """Row representation for ``CurrencyExchangeRate`` results.

    Each attribute defaults to its own column name string.
    NOTE(review): presumably ``RepresentationRowBase`` uses these names to
    expose the matching lxml ``StringElement`` columns; confirm in the base class.
    """

    StartDate: StringElement = "StartDate"
    ExpireDate: StringElement = "ExpireDate"
    SourceAuthority: StringElement = "SourceAuthority"
    SourceCurrency: StringElement = "SourceCurrency"
    ExchangeRate: StringElement = "ExchangeRate"
    ExchangeCurrency: StringElement = "ExchangeCurrency"
class CurrencyExchangeRate(DbObject):
    """Read and import currency exchange rates through HQ stored procedures.

    ``get`` queries ``sdk_GetExchangeRates``; ``create`` builds an ``FERates``
    XML payload and submits it to ``bsp_ImportExchangeRates``.
    """

    NAME = "CurrencyExchangeRate"
    REPRESENTATION_ROW_CLASS = CurrencyExchangeRateRow

    def __init__(
        self,
        logger=None,
        hq_credentials: Optional[HqCredentials] = None,
        ret_table_obj: Optional[bool] = None,
    ):
        super().__init__(logger, hq_credentials, ret_table_obj)
        # Result of the most recent ``create`` call; ``create`` replaces it
        # with a fresh object each time it runs.
        self.response_obj = ExchangeRateResponse(success=False, error_messages=None)

    async def get(self, source: str = "", destination: str = ""):
        """Fetch exchange rates filtered by source or destination currency.

        Only one filter is sent: ``source`` takes precedence when both are given.

        :param source: currency code sent as ``source_currency``.
        :param destination: currency code sent as ``destination_currency``;
            used only when ``source`` is empty.
        :return: ``self.hq_response`` after the call (also when HQ raises
            ``HqResponseError``), or ``None`` when neither filter is supplied.
        """
        if not any([source, destination]):
            self.logger.info("source or destination parameters must be passed in")
            return None
        params = []
        if source:
            Param(source, D_TYPES.VarChar, "source_currency").add_to_param_list(params)
        else:
            Param(
                destination, D_TYPES.VarChar, "destination_currency"
            ).add_to_param_list(params)
        try:
            await self.call_hq(
                "sdk_GetExchangeRates", ExecuteStoredProcedure.SqlParameters(params)
            )
        except HqResponseError:
            # NOTE(review): the error is swallowed and hq_response returned
            # as-is; presumably callers inspect it — confirm.
            pass
        return self.hq_response

    async def create(self, rate_data: ImportRates) -> ExchangeRateResponse:
        """Import exchange rates via ``bsp_ImportExchangeRates``.

        :param rate_data: rates to import. Rates whose dates cannot be parsed
            are skipped and reported in ``error_messages``.
        :return: a fresh :class:`ExchangeRateResponse`, also stored on
            ``self.response_obj``. ``success`` becomes True once the stored
            procedure call succeeds, even if it reported row-level errors;
            those are appended to ``error_messages``.
        """
        # Start from a clean result so a previous call's errors / success flag
        # cannot leak into this call, and results returned earlier stay intact.
        self.response_obj = ExchangeRateResponse(success=False, error_messages=[])
        xml_import_string = await self._build_create_xml_request(rate_data)
        if not xml_import_string:
            self.response_obj.error_messages.append("Failed to build xml shape")
            return self.response_obj
        response = await self._call_import_exchange_rates(xml_import_string)
        if not response.success:
            self.response_obj.error_messages.append(
                f"Stored proc failure: {response.error_message}"
            )
            return self.response_obj
        self.response_obj.success = True
        error_messages = self._extract_import_errors(response)
        if error_messages:
            self.response_obj.error_messages.append(f"Errors: {error_messages}")
        return self.response_obj

    @staticmethod
    def _extract_import_errors(response) -> Optional[str]:
        """Return the error text embedded in a successful import response, if any.

        The text is read from the ``<Error>`` element of the XML found in
        ``result_node.Data.NewDataSet.Table.Column1``.
        """
        try:
            return_block = response.result_node.Data.NewDataSet.Table.Column1.text
            info = fromstring(return_block)
        except (AttributeError, ValueError, XMLSyntaxError):
            # Missing result table, no text (fromstring(None) raises
            # ValueError) or unparseable XML: nothing to report.
            return None
        error_text = info.findtext("Error")
        if not error_text:
            return None
        parts = error_text.split("Message:")
        # Keep the text following the first "Message:" marker; when the
        # marker is absent, report the whole text instead of raising IndexError.
        return parts[1] if len(parts) > 1 else error_text

    async def _get_server_time(self, logger, hq_credentials):
        """Return HQ's server date/time as ``server_date_time.isoformat()``."""
        get_version_obj = GetHqVersion.ParamsObj(logger, hq_credentials=hq_credentials)
        get_version = await GetHqVersion.execute(get_version_obj)
        return get_version.server_date_time.isoformat()

    async def _build_create_xml_request(self, request: ImportRates):
        """Build the ``FERates`` XML string, or return None if no row was built.

        Each valid rate becomes an ``ImportRow`` child; ``InsertDate`` is set
        from the HQ server time. The server is only queried when at least one
        row is valid.
        """
        rows = [self._build_import_row(rate) for rate in request.rates]
        # Compare against None explicitly: lxml element truthiness depends on
        # child count, not on existence.
        exchange_rates = [row for row in rows if row is not None]
        if not exchange_rates:
            return None
        now = await self._get_server_time(self.logger, self.hq_credentials)
        return tostring(E.FERates(*exchange_rates, InsertDate=now), encoding="unicode")

    def _build_import_row(self, rate: RateRequest):
        """Convert one :class:`RateRequest` into an ``ImportRow`` element.

        Dates are re-formatted from ``%Y-%m-%d %H:%M:%S`` to
        ``%Y-%m-%dT%H:%M:%S``. ``SourceAuthority`` maps to ``Description`` and
        ``SourceCurrency`` to ``USCurrency``. On a date parse failure the
        problem is recorded in ``self.response_obj.error_messages`` and None
        is returned.
        """
        try:
            start_date = datetime.strptime(rate.StartDate, "%Y-%m-%d %H:%M:%S")
            expire_date = datetime.strptime(rate.ExpireDate, "%Y-%m-%d %H:%M:%S")
        except (TypeError, ValueError):
            self.response_obj.error_messages.append(
                f"Error parsing a date: {rate.StartDate} {rate.ExpireDate}"
            )
            return None
        return E.ImportRow(
            StartDate=start_date.strftime("%Y-%m-%dT%H:%M:%S"),
            ExpireDate=expire_date.strftime("%Y-%m-%dT%H:%M:%S"),
            Description=rate.SourceAuthority,
            USCurrency=rate.SourceCurrency,
            ExchangeRate=rate.ExchangeRate,
            ExchangeCurrency=rate.ExchangeCurrency,
        )

    async def _call_import_exchange_rates(self, payload):
        """Run ``bsp_ImportExchangeRates`` with ``payload`` as the ``rates`` Xml parameter."""
        sql_params = []
        Param(payload, D_TYPES.Xml, "rates").add_to_param_list(sql_params)
        sql_parameters = ExecuteStoredProcedure.SqlParameters(sql_params)
        param_obj = ExecuteStoredProcedure.ParamsObj(
            self.logger,
            "bsp_ImportExchangeRates",
            sql_parameters=sql_parameters,
            hq_credentials=self.hq_credentials,
        )
        return await ExecuteStoredProcedure.execute(param_obj)