"""
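HealthCheck extension: answers GET with HTTP 200 and a health report (text or
JSON) covering HQ, Vault, Memcached and any installed custom health checks.
Use it to confirm the service is up and responding.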
"""
from datetime import datetime, timedelta
import importlib
from typing import List
from lxml.etree import XMLSyntaxError
from tornado.web import HTTPError
import requests.exceptions
from q2_sdk.core import cache
from q2_sdk.core.non_http_handlers.udp import writer as udp_writer
from q2_sdk.core.http_handlers.base_handler import Q2BaseRequestHandler
from q2_sdk.core.exceptions import VaultError
from q2_sdk.hq.exceptions import MissingStoredProcError
from q2_sdk.core.configuration import settings
from q2_sdk.hq.hq_api.q2_api import ExecuteStoredProcedure
from q2_sdk.hq import http
from q2_sdk.tools import utils
from q2_sdk.hq.exceptions import HQConnectionError
from q2_sdk.core.non_http_handlers.custom_health_check_handler import (
    CustomHealthCheckResponse,
)
from . import filters
class HealthCheckHandler(Q2BaseRequestHandler):
    """Report the health of the services this SDK instance depends on.

    ``GET`` checks Vault, HQ, Memcached and any installed custom health
    checks, then writes the results as text or JSON. ``POST`` is rejected
    with 405.

    Results are cached as class attributes on ``HealthCheckHandler`` itself,
    so they are shared by every request handled in this process. A service
    that last passed less than ``RECHECK_INTERVAL`` seconds ago is not
    re-checked; a failing service is re-checked on every request because its
    ``LAST_*_SUCCESS`` timestamp is only updated on success.

    Each ``*_HEALTHY`` value is ``"N/A"`` until that check has run (and stays
    ``"N/A"`` while the check is disabled), then ``True`` or ``False``.
    """

    RECHECK_INTERVAL = 60  # In seconds
    HQ_HEALTHY = "N/A"
    LAST_HQ_SUCCESS = None  # datetime of the last passing HQ check
    VAULT_HEALTHY = "N/A"
    LAST_VAULT_SUCCESS = None  # datetime of the last passing Vault check
    MEMCACHED_HEALTHY = "N/A"
    LAST_MEMCACHED_SUCCESS = None  # datetime of the last passing Memcached check
    CUSTOM_CHECKS_HEALTHY = "N/A"
    LAST_CUSTOM_CHECKS_SUCCESS = None  # datetime of the last all-passing custom run
    # vars() of the custom check responses included in the health report
    CUSTOM_CHECK_DETAILS = []
    # module paths of custom checks that failed; only these re-run next time
    CUSTOM_HC_FAILED_EXTENSIONS = []
    # vars() of healthy custom check responses, cleared once every check passes
    CUSTOM_HC_HEALTHY_EXTENSIONS = []

    @property
    def LOGGING_FILTERS(self):
        """Outside DEBUG, attach ``filters.FilterAll`` to this handler.

        Presumably this silences logging for the frequently polled health
        endpoint; see the ``filters`` module for what it actually drops.
        """
        return [filters.FilterAll()] if not settings.DEBUG else []

    def __init__(self, application, request, **kwargs):
        super().__init__(application, request, **kwargs)
        # Per-request failure explanations; surfaced in the health report.
        self.vault_failure_reason = None
        self.hq_failure_reason = None
        self.memcached_failure_reason = None
        # Single timestamp for this request, used for staleness and successes.
        self.now = datetime.now()
        self.custom_status_endpoints = None
        self.custom_status_response = None
        self.enable_summary_log = False

    def reset_health_stats(self):
        """Reset the cached health results so every check runs again.

        The request path reads and writes these values on
        ``HealthCheckHandler`` directly, so that class must be reset. Resetting
        only a subclass would leave the shared results untouched. The concrete
        class is reset too, in case a subclass reads them through itself.
        """
        for klass in {HealthCheckHandler, type(self)}:
            klass.HQ_HEALTHY = "N/A"
            klass.LAST_HQ_SUCCESS = None
            klass.VAULT_HEALTHY = "N/A"
            klass.LAST_VAULT_SUCCESS = None
            klass.MEMCACHED_HEALTHY = "N/A"
            klass.LAST_MEMCACHED_SUCCESS = None
            klass.CUSTOM_CHECKS_HEALTHY = "N/A"
            klass.LAST_CUSTOM_CHECKS_SUCCESS = None
            klass.CUSTOM_CHECK_DETAILS = []
            klass.CUSTOM_HC_FAILED_EXTENSIONS = []
            klass.CUSTOM_HC_HEALTHY_EXTENSIONS = []

    def post(self, *args, **kwargs):
        """Health checks are read-only; reject POST with 405."""
        raise HTTPError(405)

    async def get(self, *args, **kwargs):
        """Run the due health checks and write the health report.

        Query arguments:

        * ``no_cache=true`` discards cached results so every check runs now.
        * ``format`` overrides the response type: ``json`` produces JSON, any
          other non-empty value renders the text template. Without it, an
          ``Accept`` header starting with ``application/json`` selects JSON.

        The health report is carried in the body; this method does not change
        the HTTP status. When ``settings.FORK_REQUESTS`` is set, the results are
        also passed to ``udp_writer.send_msg``.
        """
        if self.get_argument("no_cache", "").lower() == "true":
            self.logger.info("Resetting health stats")
            self.reset_health_stats()

        template_type = "txt"
        if self.request.headers.get("Accept", "").startswith("application/json"):
            template_type = "json"
        if query_format := self.get_argument("format", None):
            template_type = query_format

        overrides = settings.HEALTH_CHECK_OVERRIDES
        if settings.USE_VAULT_FOR_CREDS and not overrides.skip_vault:
            await self._update_service_health(
                "VAULT_HEALTHY", "LAST_VAULT_SUCCESS", self.get_vault_health
            )
        if settings.HQ_CREDENTIALS.hq_url and not overrides.skip_hq:
            await self._update_service_health(
                "HQ_HEALTHY", "LAST_HQ_SUCCESS", self.get_hq_health
            )
        if not overrides.skip_memcache:
            await self._update_service_health(
                "MEMCACHED_HEALTHY",
                "LAST_MEMCACHED_SUCCESS",
                self.get_memcached_health,
            )
        await self._update_custom_health()

        if template_type == "json":
            response = self.serialize_as_json()
        else:
            response = self.serialize_as_html()
        self.write(response)

        if settings.FORK_REQUESTS:
            udp_writer.send_msg(
                udp_writer.MsgType.HealthCheck, self._build_health_check_msg()
            )

    async def _update_service_health(
        self, healthy_attr: str, last_success_attr: str, service_func
    ):
        """Run one service check if due and record its result on the class.

        :param healthy_attr: class attribute holding the health flag.
        :param last_success_attr: class attribute holding the last success time.
        :param service_func: coroutine function performing the check.
        """
        service_health = await self.get_service_health(
            getattr(HealthCheckHandler, last_success_attr), self.now, service_func
        )
        if service_health is not None:  # None means the check was not due
            setattr(HealthCheckHandler, healthy_attr, service_health)
        if service_health is True:
            setattr(HealthCheckHandler, last_success_attr, self.now)

    async def _update_custom_health(self):
        """Run the installed custom health checks if due and record the outcome."""
        self.custom_status_endpoints = utils.get_installed_custom_status_endpoints()
        if not self.custom_status_endpoints:
            return
        self.custom_status_response = await self.get_service_health(
            HealthCheckHandler.LAST_CUSTOM_CHECKS_SUCCESS,
            self.now,
            self.get_custom_health,
        )
        if self.custom_status_response is None:
            # Not due: keep the details from the last run as they are.
            return
        # A run just happened and rebuilt CUSTOM_CHECK_DETAILS; merge in the
        # healthy results. Doing this only after a run prevents the details
        # from accumulating duplicates on cached requests.
        HealthCheckHandler.CUSTOM_CHECK_DETAILS.extend(
            HealthCheckHandler.CUSTOM_HC_HEALTHY_EXTENSIONS
        )
        # False means get_custom_health itself raised part-way through; that
        # must never count as healthy, even if no extension was marked failed.
        all_passed = (
            self.custom_status_response is not False
            and not HealthCheckHandler.CUSTOM_HC_FAILED_EXTENSIONS
        )
        HealthCheckHandler.CUSTOM_CHECKS_HEALTHY = all_passed
        if all_passed:
            HealthCheckHandler.LAST_CUSTOM_CHECKS_SUCCESS = self.now
            HealthCheckHandler.CUSTOM_HC_HEALTHY_EXTENSIONS = []

    def _build_health_check_msg(self) -> dict:
        """Build the payload handed to ``udp_writer`` when forking requests."""
        msg = {
            "HQ_HEALTHY": HealthCheckHandler.HQ_HEALTHY,
            "MEMCACHED_HEALTHY": HealthCheckHandler.MEMCACHED_HEALTHY,
            "VAULT_HEALTHY": HealthCheckHandler.VAULT_HEALTHY,
            "CUSTOM_HEALTH": HealthCheckHandler.CUSTOM_CHECKS_HEALTHY,
            "CUSTOM_DETAILS": HealthCheckHandler.CUSTOM_CHECK_DETAILS,
            "CUSTOM_FAILED": HealthCheckHandler.CUSTOM_HC_FAILED_EXTENSIONS,
            "CUSTOM_HEALTHY": HealthCheckHandler.CUSTOM_HC_HEALTHY_EXTENSIONS,
        }
        # Timestamps are included only once the service has passed at least once.
        for key in (
            "LAST_HQ_SUCCESS",
            "LAST_VAULT_SUCCESS",
            "LAST_MEMCACHED_SUCCESS",
            "LAST_CUSTOM_CHECKS_SUCCESS",
        ):
            last_success = getattr(HealthCheckHandler, key)
            if last_success:
                msg[key] = last_success.isoformat()
        return msg

    def serialize_as_json(self):
        """Return the health report as a dict (written out as JSON)."""
        return self._get_health_as_dict()

    def serialize_as_html(self):
        """Return the health report rendered from the text template."""
        return self._serialize_from_template("healthcheck.txt.jinja2")

    def _get_health_as_dict(self):
        """Collect the overall and per-service health into one dict."""
        return {
            "healthy": self.is_healthy(),
            "hq_healthy": HealthCheckHandler.HQ_HEALTHY,
            "hq_failure_reason": self.hq_failure_reason,
            "vault_healthy": HealthCheckHandler.VAULT_HEALTHY,
            "vault_failure_reason": self.vault_failure_reason,
            "memcached_healthy": HealthCheckHandler.MEMCACHED_HEALTHY,
            "memcached_failure_reason": self.memcached_failure_reason,
            "custom_checks_healthy": HealthCheckHandler.CUSTOM_CHECKS_HEALTHY,
            "custom_checks_details": HealthCheckHandler.CUSTOM_CHECK_DETAILS,
        }

    def _serialize_from_template(self, template_name):
        """Render ``template_name`` with the health dict as its context."""
        template = self.get_template(template_name, self._get_health_as_dict())
        return template

    async def get_service_health(
        self, last_success: datetime, now: datetime, service_func
    ):
        """Run ``service_func`` if its last success is missing or stale.

        :param last_success: when the service last passed, or ``None``.
        :param now: timestamp of the current request.
        :param service_func: coroutine function performing the check.
        :return: ``service_func``'s result, ``False`` if it raised, or ``None``
            when ``last_success`` is within ``RECHECK_INTERVAL`` (no check run).
        """
        # Read through self so subclasses can override RECHECK_INTERVAL.
        recheck_after = timedelta(seconds=self.RECHECK_INTERVAL)
        if last_success is not None and (now - last_success) <= recheck_after:
            return None
        try:
            return await service_func()
        except Exception as error:
            self.logger.error("Exception occurred while trying to call service")
            self.logger.error(str(error))
            return False

    @staticmethod
    def is_healthy():
        """Return False if any check failed; unchecked ("N/A") counts as healthy."""
        return all(
            check is not False
            for check in (
                HealthCheckHandler.HQ_HEALTHY,
                HealthCheckHandler.VAULT_HEALTHY,
                HealthCheckHandler.MEMCACHED_HEALTHY,
                HealthCheckHandler.CUSTOM_CHECKS_HEALTHY,
            )
        )

    async def get_vault_health(self):
        """Refresh HQ credentials from Vault to prove Vault is reachable.

        On failure ``self.vault_failure_reason`` is set and logged.

        :return: ``True`` on success, ``False`` on a Vault or HTTP error.
        """
        try:
            http.refresh_hq_creds(logger=self.logger, timeout=1)
        except (VaultError, requests.exceptions.RequestException) as err:
            # requests wraps the underlying urllib3 error *object* as args[0],
            # which is not JSON-serializable; args may also be empty. Always
            # store a string.
            reason = err.args[0] if err.args else err
            self.vault_failure_reason = str(reason)
            self.logger.error(self.vault_failure_reason)
            return False
        return True

    async def get_hq_health(self):
        """Call the ``sdk_HealthCheck`` stored procedure through HQ.

        On failure ``self.hq_failure_reason`` is set and logged.

        :return: ``True`` if HQ answered successfully, otherwise ``False``.
        """
        timeout = 3  # seconds, passed to ExecuteStoredProcedure.execute
        hq_url = settings.HQ_CREDENTIALS.hq_url
        # The HQ-specific errors are caught before the broad IOError clause so
        # their messages still apply if they derive from IOError/OSError.
        try:
            hq_response = await ExecuteStoredProcedure.execute(
                ExecuteStoredProcedure.ParamsObj(
                    self.logger, "sdk_HealthCheck", hq_credentials=self.hq_credentials
                ),
                timeout=timeout,
            )
        except MissingStoredProcError:
            self.hq_failure_reason = (
                "Missing sdk_HealthCheck stored procedure in database"
            )
        except HQConnectionError:
            self.hq_failure_reason = f"{hq_url} did not respond in {timeout} seconds"
        except (IOError, XMLSyntaxError):
            self.hq_failure_reason = f"No response from HQ: {hq_url}"
        else:
            if hq_response.success:
                return True
            # Guard against a failed response that carries no error_message.
            error_message = (hq_response.error_message or "").lower()
            if error_message == "logon failure":
                self.hq_failure_reason = f"Bad Credentials for HQ: {hq_url}"
            else:
                self.hq_failure_reason = f"Malformed response from HQ: {hq_url}"
        self.logger.error(self.hq_failure_reason)
        return False

    async def get_memcached_health(self):
        """Ask Memcached for its version to prove it is reachable.

        On failure ``self.memcached_failure_reason`` is set and logged.

        :return: ``True`` on success, otherwise ``False``.
        """
        cache_id = "{}:{}".format(settings.CACHE["HOST"], settings.CACHE["PORT"])
        if HealthCheckHandler.MEMCACHED_HEALTHY is False:
            # NOTE(review): result is discarded; presumably get_cache() rebuilds
            # the client after a failure - confirm against q2_sdk.core.cache.
            cache.get_cache()
        try:
            cache.get_cache().version()
        except ConnectionRefusedError:
            self.memcached_failure_reason = (
                f"Unable to connect to Memcached: {cache_id}"
            )
        except Exception:
            self.memcached_failure_reason = f"Unknown Memcached Failure: {cache_id}"
        else:
            return True
        self.logger.error(self.memcached_failure_reason)
        return False

    async def get_custom_health(self) -> List[CustomHealthCheckResponse]:
        """Run the installed custom health checks and record their results.

        If the previous run had failures, only those extensions are re-run;
        otherwise every entry of ``self.custom_status_endpoints`` is run. For
        each module path ``status``, ``<status>.status.<Name>Handler`` is
        imported, instantiated with the logger and its ``run()`` awaited.

        Healthy responses are appended (as ``vars``) to
        ``CUSTOM_HC_HEALTHY_EXTENSIONS``. Failing ones go to
        ``CUSTOM_HC_FAILED_EXTENSIONS`` (module path) and
        ``CUSTOM_CHECK_DETAILS`` (``vars``). An extension that cannot be
        imported, or whose ``run()`` raises, is logged and recorded as failed
        instead of aborting the remaining checks.

        :return: the response objects of the extensions that ran.
        """
        self.logger.info("Found custom health checks")
        HealthCheckHandler.CUSTOM_CHECK_DETAILS = []
        custom_response_objects = []
        if HealthCheckHandler.CUSTOM_HC_FAILED_EXTENSIONS:
            self.logger.info(
                f"Calling the last failed endpoints: {HealthCheckHandler.CUSTOM_HC_FAILED_EXTENSIONS}"
            )
            self.custom_status_endpoints = (
                HealthCheckHandler.CUSTOM_HC_FAILED_EXTENSIONS
            )
            HealthCheckHandler.CUSTOM_HC_FAILED_EXTENSIONS = []
        for status in self.custom_status_endpoints:
            extension_class = f"{utils.pascalize(status.split('.')[-1])}Handler"
            try:
                endpoint_class = getattr(
                    importlib.import_module(f"{status}.status"), extension_class
                )
                status_response = await endpoint_class(logger=self.logger).run()
            except Exception as error:
                # Keep checking the other extensions; this one counts as failed.
                self.logger.error(f"Custom health check {status} raised: {error}")
                HealthCheckHandler.CUSTOM_HC_FAILED_EXTENSIONS.append(status)
                continue
            custom_response_objects.append(status_response)
            if status_response.is_healthy:
                HealthCheckHandler.CUSTOM_HC_HEALTHY_EXTENSIONS.append(
                    vars(status_response)
                )
            else:
                HealthCheckHandler.CUSTOM_HC_FAILED_EXTENSIONS.append(status)
                HealthCheckHandler.CUSTOM_CHECK_DETAILS.append(vars(status_response))
        return custom_response_objects