"""
This module provides an easy interface for asynchronous web requests.

The Requests library (https://requests.readthedocs.io/en/master/)
is a beautiful and well-known interface for making web requests
in Python. With that said, it doesn't play nicely with the asynchronous
features of Antilles. This is a thin wrapper module to achieve
the best of both worlds.
To utilize the Requests functionality with an existing ``requests.Session``
object (https://requests.readthedocs.io/en/master/user/advanced/?highlight=session#session-objects)
simply pass the session keyword argument to each HTTP method in this module.
If the "session" argument is omitted from the function call, then it will
use the default, non-session-aware requests method.
"""
import asyncio
import functools
import inspect
import logging
import re
import secrets
from http import HTTPStatus
from timeit import default_timer
from urllib.parse import urlparse
import requests
from q2_sdk.models.oauth_client import OAuthClient
from q2_sdk.core.configuration import settings
from q2_sdk.core.exceptions import BlacklistedUrlError
from q2_sdk.core.prometheus import MetricType, get_metric
from q2_sdk.core.opentelemetry.q2_requests import SpanContext
from q2_sdk.version import __version__
# When set, ``Q2RequestInterface.call`` routes requests through
# ``Q2RequestMock`` instead of ``requests``.
TEST_MODE = settings.TEST_MODE  # True when Q2SDK_TEST_MODE env var is True
# Template for the error logged and raised by ``verify_whitelist``; the single
# ``%s`` is filled with the network location (netloc) of the requested URL.
OUTBOUND_FAILURE_MSG = (
    "%s not in OUTBOUND_WHITELIST. "
    "Outbound calls are allowed at development time, "
    "but will be blocked in the Q2 datacenter. "
    "Please create a Q2 Support ticket at "
    "https://q2developer.com/support/create?ticketType=Whitelist "
    "to whitelist this url if you are planning "
    "to use it in staging/production, "
    "then add it to the OUTBOUND_WHITELIST variable in your "
    "settings file to suppress this message."
)
# Content types whose successful response headers/bodies are debug-logged by
# ``Q2RequestInterface.call``. Each entry is matched as a substring of the
# response's ``content-type`` header.
TEXT_CONTENT_TYPES = {
    "text/css",
    "text/csv",
    "text/html",
    "text/javascript",
    "text/plain",
    "text/xml",
    "application/json",
    "application/xml",
    "application/javascript",
    "application/soap+xml",
    "application/vnd.kafka.v2+json",
    "application/vnd.kafka.json.v2+json",
}
# Default ``User-Agent`` header value, applied when the outgoing headers do not
# already contain a ``User-Agent`` key.
USER_AGENT = f"python-q2-requests/{__version__}"
class CallIdLogAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with a call identifier.

    The adapter's ``extra`` mapping must contain a ``call_id`` key; its value
    is prepended to each message so the log lines belonging to one outbound
    request can be correlated.
    """

    def process(self, msg, kwargs):
        """Return ``msg`` prefixed with the call id, and ``kwargs`` unchanged.

        :param msg: The log message being emitted.
        :param kwargs: Keyword arguments of the logging call; passed through as-is.
        :return: Tuple of ``(prefixed message, kwargs)``.
        """
        return f"{self.extra['call_id']} {msg}", kwargs
 
class Q2RequestInterface:
    """Send one outbound HTTP request through ``requests`` and log the exchange.

    :param logger: Logger used for request/response messages; when falsy,
        nothing is logged.
    :param minimal_logging: When True, skip debug logging of request and
        response headers and bodies.
    """

    def __init__(self, logger, minimal_logging: bool = False):
        # The one-line request/response summaries are logged at INFO when the
        # running entrypoint is an installed one (or "run"), otherwise at DEBUG.
        info_level_entrypoints = settings.INSTALLED_ENTRYPOINTS + ["run"]
        if settings.RUNNING_ENTRYPOINT in info_level_entrypoints:
            self.dynamic_log_level = logging.INFO
        else:
            self.dynamic_log_level = logging.DEBUG
        self.logger = logger
        self._minimal_logging = minimal_logging

    @staticmethod
    def _rewind_files(files) -> None:
        """Seek every seekable file object in a ``files`` argument back to 0.

        Handles the shapes ``requests`` accepts for ``files``: a mapping or a
        list of ``(name, value)`` pairs, where each value is either a file
        object or a ``(filename, fileobj, ...)`` tuple. Values without a
        ``seek`` method (e.g. ``str``/``bytes`` content) are left alone.

        :param files: The ``files`` argument of the request; may be ``None``.
        """
        if not files:
            return
        if hasattr(files, "values"):
            values = files.values()
        else:
            values = (pair[1] for pair in files)
        for value in values:
            if isinstance(value, (tuple, list)):
                # (filename, fileobj[, content_type[, headers]])
                value = value[1] if len(value) > 1 else None
            seek = getattr(value, "seek", None)
            if callable(seek):
                seek(0)

    async def call(self, url: str, verb: str, **kwargs):
        """Send an HTTP request and return its response.

        Any keyword argument not listed below is forwarded to the underlying
        ``requests`` (or session, or mock) call. When no ``timeout`` is given,
        the connect/read defaults from settings are used.

        :param url: Target URL.
        :param verb: HTTP method name (case-insensitive); one of get, post,
            put, delete, options, head, patch.
        :param session: (optional) ``requests.Session`` to send the request
            with. Its ``headers`` attribute is replaced by the effective
            headers computed here.
        :param oauth_client: (optional) ``OAuthClient`` whose serialized token
            header is merged into the request headers.
        :param retry_if_unauth: (optional) Retry once with a fresh token when
            the response status is in ``oauth_client.unauthorized_status_codes``.
            Defaults to True when ``oauth_client`` is given.
        :param bt_handle: (optional) Passed through to ``make_request``.
        :param mock_response: (optional) Handed to ``Q2RequestMock`` in TEST_MODE.
        :param return_success: (optional) Handed to ``Q2RequestMock`` in TEST_MODE.
        :param status_code: (optional) Handed to ``Q2RequestMock`` in TEST_MODE.
        :return: The response object produced by the request callable.
        :raises AssertionError: If ``verb`` is not a supported HTTP method.
        :raises TypeError: If ``oauth_client`` is not an ``OAuthClient``.
        """
        call_id = secrets.token_hex(4)
        logger = None
        if self.logger:
            logger = CallIdLogAdapter(self.logger, {"call_id": call_id})

        verb = verb.lower()
        available_verbs = ("get", "post", "put", "delete", "options", "head", "patch")
        if verb not in available_verbs:
            # Raised explicitly rather than through ``assert`` so the check
            # still runs under ``python -O``; the exception type is unchanged.
            raise AssertionError("Verb must be in %s" % str(available_verbs))

        if logger:
            params_msg = (
                f" Params: {kwargs.get('params')}" if kwargs.get("params") else ""
            )
            logger.log(
                self.dynamic_log_level,
                f"Sending HTTP request {verb.upper()} {url}{params_msg}",
            )
            if logger.isEnabledFor(logging.DEBUG) and not self._minimal_logging:
                # Build a throwaway prepared request purely to log the
                # effective headers, body and URL.
                req_params = inspect.signature(requests.Request).parameters.keys()
                req = requests.Request(
                    method=verb,
                    url=url,
                    **{k: v for k, v in kwargs.items() if k in req_params},
                )
                req = req.prepare()
                logger.debug("Request Headers: %s", req.headers)
                logger.debug("Request Body: %s", req.body)
                logger.debug("Effective URL: %s", req.url)
                # Preparing the request for logging can consume file-like
                # objects passed in ``files``; rewind them so the real request
                # below sends their full content.
                self._rewind_files(kwargs.get("files"))

        # Arguments consumed here; everything left in kwargs goes to requests.
        session: requests.Session | None = kwargs.pop("session", None)
        oauth_client = kwargs.pop("oauth_client", None)
        headers = session.headers if session else kwargs.get("headers") or {}
        retry_if_unauth = kwargs.pop("retry_if_unauth", oauth_client is not None)
        bt_handle = kwargs.pop("bt_handle", None)
        mock_response = kwargs.pop("mock_response", None)
        return_success = kwargs.pop("return_success", True)
        status_code = kwargs.pop("status_code", None)

        if not kwargs.get("timeout"):
            kwargs["timeout"] = (
                settings.Q2REQUESTS_DEFAULT_CONNECT_TIMEOUT,
                settings.Q2REQUESTS_DEFAULT_TIMEOUT,
            )

        if "User-Agent" not in headers:
            headers["User-Agent"] = USER_AGENT

        if oauth_client:
            if not isinstance(oauth_client, OAuthClient):
                raise TypeError("oauth_client must inherit from OAuthClient base class")
            headers = {
                **headers,
                **(await oauth_client.get_token_obj()).serialize_as_header(),
            }

        if session:
            session.headers = headers
        else:
            kwargs["headers"] = headers

        if TEST_MODE:
            from q2_sdk.tools.testing.models import Q2RequestMock

            request_mock = Q2RequestMock(
                verb,
                mock_response=mock_response,
                return_success=return_success,
                status_code=status_code,
                headers=kwargs.get("headers"),
            )
            func = functools.partial(request_mock.call, url, **kwargs)
        elif session:
            func = functools.partial(getattr(session, verb), url, **kwargs)
        else:
            func = functools.partial(getattr(requests, verb), url, **kwargs)

        start_time = default_timer()
        response = await make_request(func, bt_handle=bt_handle)
        end_time = default_timer()
        request_time = end_time - start_time

        if (
            oauth_client is not None
            and response.status_code in oauth_client.unauthorized_status_codes
            and retry_if_unauth
        ):
            oauth_client.clear_token()
            # ``session`` and ``bt_handle`` were popped from kwargs above, so
            # they must be passed again explicitly; otherwise the retry would
            # silently run outside the caller's session.
            # NOTE(review): the TEST_MODE mock arguments (mock_response,
            # return_success, status_code) are still not forwarded, as before.
            return await self.call(
                url,
                verb,
                oauth_client=oauth_client,
                session=session,
                bt_handle=bt_handle,
                retry_if_unauth=False,
                **kwargs,
            )

        parsed_url = urlparse(url)
        # Strip any "user:password@" prefix so credentials never become a
        # metric label.
        clean_url = parsed_url.netloc.split("@")[-1]
        get_metric(
            MetricType.Histogram,
            "caliper_http_requests",
            "Outbound HTTP traffic",
            {"method": verb, "endpoint": clean_url, "scheme": parsed_url.scheme},
            chain={"op": "observe", "params": [request_time]},
        )

        if logger:
            if response.status_code < 400:
                log_method = logger.info
            elif response.status_code < 500:
                log_method = logger.warning
            else:
                log_method = logger.error
            try:
                http_phrase = HTTPStatus(response.status_code).phrase
            except ValueError:
                # Non-standard status code: log it without a reason phrase.
                http_phrase = ""
            msg = (
                f"Received HTTP response after {1000 * request_time:.2f}ms - "
                f"{response.status_code} {http_phrase}"
            )
            logger.log(self.dynamic_log_level, msg)
            if not response.ok:
                # Failed responses are always logged in full, regardless of
                # the minimal_logging flag.
                log_method("Response URL: %s", url)
                log_method("Response Headers: %s", response.headers)
                log_method("Response Body: %s", response.content)
            elif not self._minimal_logging:
                content_type = response.headers.get("content-type")
                if not content_type:
                    logger.debug(
                        "Response did not have a content-type header, will not log content"
                    )
                elif any(
                    valid_type in content_type for valid_type in TEXT_CONTENT_TYPES
                ):
                    logger.debug("Response Headers: %s", response.headers)
                    logger.debug("Response Body: %s", response.content)
                else:
                    logger.debug(
                        'Response of type "%s" not logging content', content_type
                    )
        return response
def verify_whitelist(function):
    """Decorate an async request function with an outbound-whitelist check.

    The wrapped function must accept ``logger`` and ``url`` arguments. Callers
    may pass ``verify_whitelist=False`` to skip the check; that keyword is
    consumed here and never reaches the wrapped function. The check only runs
    when ``settings.DEBUG`` is truthy.

    :param function: Coroutine function to wrap.
    :return: The wrapping coroutine function.
    :raises BlacklistedUrlError: (from the wrapper) when the URL's host matches
        no entry of ``settings.OUTBOUND_WHITELIST``.
    """

    def _check_whitelist(url):
        """
        The purpose of this is not security, but rather to give an early
        warning of urls that will need to be whitelisted through Q2's networking
        layer
        """
        for raw_pattern in settings.OUTBOUND_WHITELIST:
            # Whitelist entries may carry a scheme; compare hosts only.
            host_pattern = re.sub(r"https?://", "", raw_pattern)
            if re.search(host_pattern, url):
                return
        raise BlacklistedUrlError

    @functools.wraps(function)
    async def wrapper(*args, **kwargs):
        should_verify = kwargs.pop("verify_whitelist", True)
        if settings.DEBUG and should_verify:
            # Bind the call to the wrapped signature so logger/url are found
            # whether they were passed positionally or by keyword.
            call_args = inspect.signature(function).bind(*args, **kwargs).arguments
            logger = call_args["logger"]
            url = call_args["url"]
            try:
                _check_whitelist(urlparse(url).netloc)
            except BlacklistedUrlError as err:
                error_msg = OUTBOUND_FAILURE_MSG % urlparse(url).netloc
                if logger:
                    logger.error(error_msg)
                raise BlacklistedUrlError(error_msg) from err
        return await function(*args, **kwargs)

    return wrapper
async def make_request(func: functools.partial, bt_handle=None):
    """Run a blocking request callable in the default executor and await it.

    The call is made inside a ``SpanContext``, which is also handed the
    response once it is available.

    :param func: Zero-argument callable (a ``functools.partial`` wrapping the
        actual request call) to execute off the event loop.
    :param bt_handle: Not used by this function; kept in the signature so
        existing callers that pass it keep working.
    :return: Whatever ``func`` returns.
    """
    response = None
    with SpanContext(func) as ctx:
        # get_running_loop() is the preferred way to reach the loop from inside
        # a coroutine; get_event_loop() has more complex, partly deprecated
        # behavior.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, func)
        ctx.record_response(response)
    return response
@verify_whitelist
async def get(
    logger, url, params=None, session=None, minimal_logging: bool = False, **kwargs
) -> requests.Response:
    r"""Issue an asynchronous HTTP GET.

    :param logger: Logger of the calling request (``self.logger`` in your extension).
    :param url: Address the request is sent to.
    :param params: (optional) Dictionary or bytes to send as the query string.
    :param session: (optional) :class:`requests.Session` to send the request through.
    :param minimal_logging: (optional) Set to True to suppress debug logging of headers, bodies, etc.
    :param \*\*kwargs: Any further arguments understood by ``requests``.
    :return: The response to the request.
    """
    kwargs.update(params=params, session=session)
    interface = Q2RequestInterface(logger, minimal_logging)
    return await interface.call(url, "get", **kwargs)
@verify_whitelist
async def options(
    logger, url, session=None, minimal_logging: bool = False, **kwargs
) -> requests.Response:
    r"""Sends an OPTIONS request.

    :param logger: Reference to calling request's logger (self.logger in your extension)
    :param url: URL for the new :class:`Request` object.
    :param session: (optional) :class:`requests.Session` object to base the request off of.
    :param minimal_logging: (optional) flag to turn off debug logging of request header/body/etc.
    :param \*\*kwargs: Optional arguments that ``request`` takes.
    :return: The response to the request.
    """
    kwargs["session"] = session
    request_obj = Q2RequestInterface(logger, minimal_logging)
    return await request_obj.call(url, "options", **kwargs)
@verify_whitelist
async def post(
    logger,
    url,
    data=None,
    json=None,
    session=None,
    minimal_logging: bool = False,
    **kwargs,
) -> requests.Response:
    r"""Issue an asynchronous HTTP POST.

    :param logger: Logger of the calling request (``self.logger`` in your extension).
    :param url: Address the request is sent to.
    :param data: (optional) Dictionary (form-encoded), bytes, or file-like object sent as the body.
    :param json: (optional) JSON-serializable data sent as the body.
    :param session: (optional) :class:`requests.Session` to send the request through.
    :param minimal_logging: (optional) Set to True to suppress debug logging of headers, bodies, etc.
    :param \*\*kwargs: Any further arguments understood by ``requests``.
    :return: The response to the request.
    """
    kwargs.update(session=session, json=json, data=data)
    interface = Q2RequestInterface(logger, minimal_logging)
    return await interface.call(url, "post", **kwargs)
@verify_whitelist
async def head(
    logger, url, session=None, minimal_logging: bool = False, **kwargs
) -> requests.Response:
    r"""Issue an asynchronous HTTP HEAD.

    :param logger: Logger of the calling request (``self.logger`` in your extension).
    :param url: Address the request is sent to.
    :param session: (optional) :class:`requests.Session` to send the request through.
    :param minimal_logging: (optional) Set to True to suppress debug logging of headers, bodies, etc.
    :param \*\*kwargs: Any further arguments understood by ``requests``.
    :return: The response to the request.
    """
    kwargs.update(session=session)
    interface = Q2RequestInterface(logger, minimal_logging)
    return await interface.call(url, "head", **kwargs)
@verify_whitelist
async def put(
    logger, url, data=None, session=None, minimal_logging: bool = False, **kwargs
) -> requests.Response:
    r"""Issue an asynchronous HTTP PUT.

    :param logger: Logger of the calling request (``self.logger`` in your extension).
    :param url: Address the request is sent to.
    :param data: (optional) Dictionary (form-encoded), bytes, or file-like object sent as the body.
    :param session: (optional) :class:`requests.Session` to send the request through.
    :param minimal_logging: (optional) Set to True to suppress debug logging of headers, bodies, etc.
    :param \*\*kwargs: Any further arguments understood by ``requests``; a ``json``
        body, for instance, is passed this way since it is not a named parameter here.
    :return: The response to the request.
    """
    kwargs.update(session=session, data=data)
    interface = Q2RequestInterface(logger, minimal_logging)
    return await interface.call(url, "put", **kwargs)
@verify_whitelist
async def patch(
    logger, url, data=None, session=None, minimal_logging: bool = False, **kwargs
) -> requests.Response:
    r"""Issue an asynchronous HTTP PATCH.

    :param logger: Logger of the calling request (``self.logger`` in your extension).
    :param url: Address the request is sent to.
    :param data: (optional) Dictionary (form-encoded), bytes, or file-like object sent as the body.
    :param session: (optional) :class:`requests.Session` to send the request through.
    :param minimal_logging: (optional) Set to True to suppress debug logging of headers, bodies, etc.
    :param \*\*kwargs: Any further arguments understood by ``requests``; a ``json``
        body, for instance, is passed this way since it is not a named parameter here.
    :return: The response to the request.
    """
    kwargs.update(session=session, data=data)
    interface = Q2RequestInterface(logger, minimal_logging)
    return await interface.call(url, "patch", **kwargs)
@verify_whitelist
async def delete(
    logger, url, session=None, minimal_logging: bool = False, **kwargs
) -> requests.Response:
    r"""Issue an asynchronous HTTP DELETE.

    :param logger: Logger of the calling request (``self.logger`` in your extension).
    :param url: Address the request is sent to.
    :param session: (optional) :class:`requests.Session` to send the request through.
    :param minimal_logging: (optional) Set to True to suppress debug logging of headers, bodies, etc.
    :param \*\*kwargs: Any further arguments understood by ``requests``.
    :return: The response to the request.
    """
    kwargs.update(session=session)
    interface = Q2RequestInterface(logger, minimal_logging)
    return await interface.call(url, "delete", **kwargs)