import json
import logging
import os
import re
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Literal, Optional, Union, overload
import hvac
from hvac.exceptions import Forbidden
from requests.exceptions import ConnectionError as ReqConnectionError
from q2_sdk.core import contexts, initial_settings, cache
from q2_sdk.core.configuration import settings
from q2_sdk.core.exceptions import BadParameterError, MockDataError, VaultError
from q2_sdk.core.non_http_handlers.udp import writer as udp_writer
from q2_sdk.core.opentelemetry.span import Q2Span, Q2SpanAttributes
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.models.recursive_encoder import JsonSerializable
from q2_sdk.models.unique_stack import ForkedUniqueStack
from q2_sdk.tools.sentinel import DEFAULT
# Directory of Nomad-provided secrets (vault_token file, cached certificate
# files). May be None; the cert-path helper checks for that explicitly.
NOMAD_SECRETS_DIR = initial_settings.NOMAD_SECRETS_DIR
# Base Vault path used when callers don't pass an explicit ``prefix``
DEFAULT_PREFIX = initial_settings.VAULT_DEFAULT_PREFIX
# When set (and a client is created with allow_local=True), Q2Vault reads and
# writes JSON files under this directory instead of talking to a Vault server
LOCAL_VAULT_DIR = initial_settings.VAULT_LOCAL_DIR
CERT_CACHE_SECONDS = 300  # Number of seconds a certificate will be cached on local disk
# NOTE(review): getLogger() with no name returns the root logger, so the
# setLevel below changes the process-wide logging level at import time
LOGGER = logging.getLogger()
LOGGER.setLevel(initial_settings.LOGGING_LEVEL)
# Passed to cache.get_cache() for cached Vault reads (presumably encrypts
# cache entries — confirm in q2_sdk.core.cache)
VAULT_ENCRYPTION_KEY = settings.VAULT_ENCRYPTION_KEY
[docs]
class StorageLevel(Enum):
    """
    Scope at which a value is stored in Vault.

    Used by ``Q2Vault.read`` to build the path below the base prefix:
    ``Institution`` -> ``{prefix}/{institution_id}``,
    ``Stack`` -> ``{prefix}/{deploy_env}_{stack_id}``.
    """

    Institution = "institution"
    Stack = "stack"
[docs]
@dataclass
class SearchReplacements:
    """
    Identifiers substituted into scoped Vault paths by ``Q2Vault.read``.

    Any field left empty/None is filled in by ``__post_init__``:

    - ``institution_id``: ``settings.COMPANY``, else ``settings.HQ_CREDENTIALS.aba``
    - ``stack_id``: the current request's HQ customer key, or
      ``settings.VAULT_KEY`` when there is no current request or the customer
      key is the ``"CHANGEME"`` placeholder
    """

    institution_id: Optional[str] = None
    stack_id: Optional[str] = None

    def __post_init__(self):
        # Re-imported locally (shadowing the module-level name), presumably so
        # a rebound settings object is picked up at instantiation time
        from q2_sdk.core.configuration import settings

        institution = self.institution_id or settings.COMPANY
        if not institution:
            institution = settings.HQ_CREDENTIALS.aba
        self.institution_id: str = institution

        if self.stack_id:
            return
        request = contexts.get_current_request(raise_if_none=False)
        if not request:
            self.stack_id = settings.VAULT_KEY
            return
        customer_key = request.request_handler.hq_credentials.customer_key
        if customer_key == "CHANGEME":
            self.stack_id = settings.VAULT_KEY
        else:
            self.stack_id = customer_key
[docs]
@dataclass
class Certificate:
    """
    A PEM file on local disk, exposing its certificate and private key.

    The file at ``path`` is read lazily on first access to ``cert`` or ``key``;
    the file body and each extracted block are cached on the instance.
    """

    CERT_STR = r"(-----BEGIN CERTIFICATE-----.+?-----END CERTIFICATE-----)"
    KEY_STR = r"(-----BEGIN \w*\s?PRIVATE KEY-----.+?-----END \w*\s?PRIVATE KEY-----)"
    path: str

    def __post_init__(self):
        self._body = None
        self._cert = None
        self._key = None

    def _get_body(self):
        """Return the file contents, reading ``self.path`` only once."""
        if self._body is None:
            with open(self.path) as handle:
                self._body = handle.read()
        return self._body

    def _first_block(self, pattern: str) -> str:
        """Return the first PEM block matching ``pattern``, or "" if there is none."""
        match = re.search(pattern, self._get_body(), flags=re.DOTALL)
        # A file may legitimately hold only a cert or only a key; without this
        # check a missing block crashed with AttributeError on None.groups().
        return match.group(1) if match else ""

    @property
    def cert(self):
        """First ``CERTIFICATE`` PEM block in the file ("" if absent)."""
        if self._cert is None:
            self._cert = self._first_block(self.CERT_STR)
        return self._cert

    @property
    def key(self):
        """First ``... PRIVATE KEY`` PEM block in the file ("" if absent)."""
        if self._key is None:
            self._key = self._first_block(self.KEY_STR)
        return self._key
[docs]
@dataclass
class RecentVaultKey(JsonSerializable):
    """Record of one Vault lookup: where it was read from and whether data came back."""

    prefix: str
    key: str
    success: bool

    def __eq__(self, other: object) -> bool:
        # Two records are equal when they point at the same Vault location;
        # ``success`` is deliberately not compared.
        try:
            other_location = (other.prefix, other.key)
        except AttributeError:
            # Let Python fall back to the other operand / identity instead of crashing
            return NotImplemented
        # Compare as a tuple: concatenating made ("a", "bc") equal ("ab", "c")
        return (self.prefix, self.key) == other_location
[docs]
class RecentKeysStack(ForkedUniqueStack):
    """
    Configure RecentKeys to work with forked mode

    Q2Vault holds one instance as ``Q2Vault.RECENT_KEYS``; ``read_raw`` appends
    a JSON-serialized RecentVaultKey to it for every lookup.
    """

    # Message type tag, presumably used by ForkedUniqueStack when propagating
    # entries between forked processes over UDP — confirm in unique_stack
    UDP_MESSAGE_TYPE = udp_writer.MsgType.Vault
[docs]
@dataclass
class LocalPathConfig:
    """
    Parameters identifying a certificate cached on local disk.

    ``file_name()`` renders them as ``{institution}_{aba}_{env}_{cert_name}.pem``.
    """

    cert_name: str
    institution: str
    hq_credentials: HqCredentials
    env: str

    def file_name(self):
        name_parts = (
            self.institution,
            self.hq_credentials.aba,
            self.env,
            self.cert_name,
        )
        return "_".join(f"{part}" for part in name_parts) + ".pem"
[docs]
@dataclass
class VaultPath:
    """A Vault key paired with the prefix (base path) it lives under."""

    key: str
    prefix: str

    @property
    def cache_key(self):
        """Identifier used for read caching: ``"{prefix}/{key}"``."""
        return "{}/{}".format(self.prefix, self.key)
[docs]
class Q2Vault:
    """
    Class for interacting with Hashicorp's Vault (https://www.vaultproject.io)

    When LOCAL_VAULT_DIR is configured and the client is created with
    ``allow_local=True``, reads and writes go to JSON files under that
    directory instead of a Vault server.
    """

    # Class-level, so every Q2Vault instance appends to the same record of
    # recent lookups (see read_raw)
    RECENT_KEYS = RecentKeysStack(100)

    def __init__(
        self,
        addr,
        token,
        allow_local: bool = True,
        logger=LOGGER,
        timeout: int = 3,
        namespace=os.getenv("VAULT_NAMESPACE"),
    ):
        """
        :param addr: Vault server URL (passed to hvac.Client as ``url``)
        :param token: Vault token
        :param allow_local: Use LOCAL_VAULT_DIR (when it is set) instead of the server
        :param logger: Logger used by this client
        :param timeout: Request timeout handed to hvac.Client
        :param namespace: Vault namespace. NOTE: the default is read from the
            VAULT_NAMESPACE environment variable once, when this module is imported
        """
        self.addr = addr
        self.logger = logger
        self.deploy_env = os.environ.get("DEPLOY_ENV", "DEV")
        # Populated lazily by read() / get_certificate()
        self._settings_module = None
        # Truthy only when a local vault dir is configured AND allow_local is exactly True
        self.is_local_client = LOCAL_VAULT_DIR and (allow_local is True)
        client_options = {
            "url": addr,
            "token": token,
            "timeout": timeout,
            "namespace": namespace,
        }
        self.client = hvac.Client(**client_options)
[docs]
    @Q2Span.instrument(name="vault.read", skip=["self"])
    def read(
        self,
        key,
        *,
        level: StorageLevel = StorageLevel.Stack,
        replacements: Optional[SearchReplacements] = None,
        default=None,
        **kwargs,
    ) -> dict:
        """
        Vault stores data at paths like a filesystem. We utilize the base path
        secret/ABA for historical reasons, then your data can be stored below
        that at various locations depending on your needs.
        In order of most to least specific data storage location, we have:
        - Stack
            - Each Database in Q2 gets a unique id assigned to it. This is
              available in settings.VAULT_KEY
        - Institution
            - This will tie to the value in settings.COMPANY if it exists,
              or settings.HQ_CREDENTIALS.aba if it does not
        :param key: Keyname in Vault
        :param level: To specify Stack or Institution storage location
        :param replacements: Provided as a way to query a different FI than
          the one configured for this service.
          Useful in multitenant scenarios
        :type replacements: SearchReplacements
        :param default: Value to return if data does not exist at the specified location
        :param kwargs: ``prefix`` overrides the base Vault path (DEFAULT_PREFIX otherwise)
        :returns: The ``data`` payload stored at the scoped path, or ``default``.
          With VAULT_SCOPED_READ=False the raw ``read_raw`` response is returned.
        :raises BadParameterError: ``replacements`` is not a SearchReplacements,
          or the path cannot be built for ``level``
        """
        if self._settings_module is None:
            from q2_sdk.core.configuration import settings

            self._settings_module = settings
        prefix = kwargs.get("prefix", DEFAULT_PREFIX)
        if self._settings_module.VAULT_SCOPED_READ is False:
            self.logger.warning(
                "Opting in to old Vault read style. This is deprecated. Please set VAULT_SCOPED_READ=True in your settings file."
            )
            legacy_data = self.read_raw(key, prefix=prefix)
            # Honor ``default`` here too; an explicit default=None keeps the
            # previous "return whatever read_raw gave" behavior.
            if not legacy_data and default is not None:
                return default
            return legacy_data
        if replacements is None:
            replacements = SearchReplacements()
        if not isinstance(replacements, SearchReplacements):
            raise BadParameterError(
                "replacements parameter must be an instance of SearchReplacements"
            )
        vault_path = self._get_path(key, level, replacements, default_prefix=prefix)

        data = None
        cache_obj = None
        if self._settings_module.VAULT_CACHE_READS:
            cache_obj = cache.get_cache(encryption_key=VAULT_ENCRYPTION_KEY)
            data = cache_obj.get(vault_path.cache_key)
            if isinstance(data, str):
                try:
                    data = json.loads(data)
                except json.JSONDecodeError:
                    self.logger.error("failed to json serialize string from vault")
                    # Treat an undecodable cache entry as a miss rather than
                    # handing the raw string back as Vault data
                    data = None
        if data is None:
            data = self._get_vault_data(vault_path, default)
            # _get_vault_data hands back the caller's ``default`` object when
            # nothing is stored. Caching it would serve that default (possibly
            # to callers passing a different one) until VAULT_EXPIRY elapses,
            # so only real Vault data is cached. ``is not None`` rather than
            # truthiness: a cache object defining __len__ is falsy when empty.
            if cache_obj is not None and data is not default:
                cache_obj.set(
                    vault_path.cache_key,
                    data,
                    expire=self._settings_module.VAULT_EXPIRY,
                )
        return data
    def _get_path(
        self,
        key,
        level: StorageLevel = StorageLevel.Stack,
        replacements: Optional[SearchReplacements] = None,
        default_prefix=DEFAULT_PREFIX,
    ) -> VaultPath:
        match level:
            case StorageLevel.Stack:
                if not replacements.stack_id:
                    Q2Span.set_attribute(
                        Q2SpanAttributes.SDK_FRAMEWORK_ERROR, "Vault Config"
                    )
                    raise BadParameterError(
                        "replacements.stack_id is empty, but vault request was made at the Stack level"
                    )
                path = f"{self.deploy_env}_{replacements.stack_id}"
            case StorageLevel.Institution:
                path = replacements.institution_id
            case _:
                Q2Span.set_attribute(
                    Q2SpanAttributes.SDK_FRAMEWORK_ERROR, "Vault Config"
                )
                raise BadParameterError(
                    "level parameter must be an instance of StorageLevel"
                )
        return VaultPath(key, f"{default_prefix}/{path}")
    def _get_vault_data(self, vault_path: VaultPath, default=None):
        """
        Read ``vault_path`` and unwrap its payload.

        Returns the response's ``"data"`` member when present (otherwise the
        response itself), or ``default`` when nothing is stored. A miss with a
        falsy ``default`` is flagged on the current span.
        """
        response = self.read_raw(vault_path.key, prefix=vault_path.prefix)
        if response:
            return response.get("data", response)
        if not default:
            Q2Span.set_attribute(Q2SpanAttributes.SDK_FRAMEWORK_ERROR, "Vault Data")
        return default
[docs]
    def read_raw(self, key, *, prefix=DEFAULT_PREFIX):
        """
        Catchall way of interacting with Vault directly if one of the more
        specific helper functions doesn't do the trick.

        Reads ``{prefix}/{key}`` from the local mock directory for a local
        client, otherwise from Vault. On ``Forbidden`` the token is re-fetched
        with ``get_token()`` and the read retried with it; if ``get_token()``
        returns the token already in use, ``Forbidden`` is raised.
        Every lookup is recorded in RECENT_KEYS.

        :param prefix: Base path; None is treated as ""
        :returns: The raw response (local mock JSON or hvac response); falsy
            when nothing is stored at the path
        :raises ReqConnectionError: Vault could not be reached
        :raises Forbidden: Vault rejected the token and no different token is available
        """
        if prefix is None:
            prefix = ""
        if self.is_local_client:
            data = self._get_local_key(key, prefix)
        else:
            try:
                self.logger.debug("Searching %s/%s", prefix, key)
                data = self.client.read("{}/{}".format(prefix, key))
            except ReqConnectionError as err:
                raise ReqConnectionError(
                    "Unable to connect to vault_addr: {}".format(self.addr)
                ) from err
            except Forbidden as err:
                new_token = get_token()
                if new_token == self.client.token:
                    raise Forbidden("Bad Vault Token") from err
                self.client.token = new_token
                # The retry records its own outcome (RECENT_KEYS, logs, span
                # event); return it directly so the lookup isn't recorded twice.
                return self.read_raw(key, prefix=prefix)
        if data:
            self.logger.debug("Vault data read successfully")
            self.RECENT_KEYS.append(RecentVaultKey(prefix, key, True).to_json())
        else:
            self.logger.debug("No vault data")
            Q2Span.add_event(
                "Vault Data", {"requested_path": "{}/{}".format(prefix, key)}
            )
            self.RECENT_KEYS.append(RecentVaultKey(prefix, key, False).to_json())
        return data
[docs]
    def write(self, path: str, body: Optional[dict] = None, prefix=DEFAULT_PREFIX):
        """
        Write ``body`` to ``{prefix}/{path}``.

        This is only possible if appropriate authorization (Vault Policies) are
        tied to the active vault token.
        If configured with a local vault directory, writes will always be allowed;
        ``body`` is then stored as JSON at ``LOCAL_VAULT_DIR/{prefix}/{path}``.

        :param path: Key path below ``prefix``
        :param body: Mapping to store. For remote writes each item is passed as a
            keyword argument to the hvac client. Defaults to an empty dict
            (the old default was the ``dict`` type itself, which always crashed).
        :param prefix: Base path; a falsy value means no prefix
        :returns: The hvac client's response for remote writes; None for local writes
        """
        if body is None:
            body = {}
        if not prefix:
            prefix = ""
        path = Path(prefix) / path
        if self.is_local_client:
            # expanduser() matches _get_local_key, so writes land where reads look
            local_vault_path = Path(LOCAL_VAULT_DIR).expanduser()
            local_vault_path.mkdir(parents=True, exist_ok=True)
            full_path = local_vault_path / path
            full_path.parent.mkdir(parents=True, exist_ok=True)
            full_path.write_text(json.dumps(body))
            self.logger.info("Writing vault data to local path: %s", full_path)
            return None
        self.logger.info("Writing vault data to vault at %s", path)
        # Vault paths always use "/", regardless of the host OS path separator
        return self.client.write(path.as_posix(), **body)
    def _get_local_key(self, key, prefix):
        """
        Read a mocked Vault entry from ``LOCAL_VAULT_DIR/{prefix}/{key}``.

        :returns: The parsed JSON object, or None if the file does not exist
        :raises MockDataError: the file is not valid JSON, is not a JSON object,
            or has no truthy ``"data"`` member
        """
        error_msg = 'Mock vault data must be a JSON object with a "data" key'
        path = (Path(LOCAL_VAULT_DIR) / prefix / key).expanduser().resolve()
        if not path.exists():
            self.logger.error("Path does not exist: %s", path)
            return None
        with open(path) as handle:
            self.logger.debug("Searching %s", path)
            try:
                response = json.load(handle)
            except json.JSONDecodeError as err:
                raise MockDataError(error_msg) from err
        # Valid JSON that isn't an object (list, string, number) used to fail
        # with AttributeError on .get; report it as bad mock data instead.
        if not isinstance(response, dict) or not response.get("data"):
            raise MockDataError(error_msg)
        return response
    # Typing-only signatures for get_certificate. as_obj=True yields a
    # Certificate; as_obj passed positionally must come after env, because a
    # bare third positional argument binds to ``env`` at runtime.
    @overload
    def get_certificate(
        self,
        cert_name: str,
        hq_credentials: HqCredentials,
        env: str = "default",
        *,
        as_obj: Literal[True],
    ) -> Certificate: ...
    @overload
    def get_certificate(
        self,
        cert_name: str,
        hq_credentials: HqCredentials,
        env: str,
        as_obj: Literal[True],
    ) -> Certificate: ...
    @overload
    def get_certificate(
        self,
        cert_name: str,
        hq_credentials: HqCredentials,
        env: str = "default",
        as_obj: Literal[False] = False,
    ) -> str: ...
    @overload
    def get_certificate(
        self,
        cert_name: str,
        hq_credentials: HqCredentials,
        env: str = "default",
        as_obj: bool = False,
    ) -> Union[str, Certificate]: ...
[docs]
    def get_certificate(
        self,
        cert_name: str,
        hq_credentials: HqCredentials,
        env: str = "default",
        as_obj: bool = False,
    ) -> Union[str, Certificate]:
        """
        Gets a certificate from vault or local disk cache if available. Local certificate cache is good for 300 seconds.
        Looks up one of two paths:
        - If VAULT_SCOPED_READ is True in settings: ``{institution}/certs/{cert_name}``
          where ``{institution}`` is either settings.COMPANY or hq_credentials.ABA
        - If VAULT_SCOPED_READ is False in settings: ``{ABA}_certs/{env}/{cert_name}``
        If VAULT_SCOPED_READ is True, it will still fall through to old behavior if cert is not found at new location.
        :param cert_name: Last part of the path in Vault
        :param hq_credentials: Used to specify the first part of the path in Vault
        :param env: Middle part of the path in Vault (ignored if VAULT_SCOPED_READ is True)
        :param as_obj: If True, will return a Certificate object, with .key and .cert properties as well as .path
        :returns: Path of the local PEM file (cert then key), or a Certificate
          wrapping that path when ``as_obj`` is True
        :raises VaultError: NOMAD_SECRETS_DIR is not configured, or the
          certificate was not found at any path tried
        """
        if self._settings_module is None:
            from q2_sdk.core.configuration import settings

            self._settings_module = settings
        cert_path = self._get_local_secret_cert_path(
            LocalPathConfig(
                cert_name, self._settings_module.COMPANY, hq_credentials, env
            )
        )
        if cert_path is None:
            # The path helper yields None when NOMAD_SECRETS_DIR is unset; fail
            # clearly up front instead of a TypeError from open(None) after a
            # Vault round trip.
            raise VaultError(
                "NOMAD_SECRETS_DIR is not set; cannot cache certificate on local disk"
            )
        if cert_path.is_file() and file_age(cert_path) < CERT_CACHE_SECONDS:
            self.logger.debug(
                "Certificate file less than %s seconds old. Skipping vault read...",
                CERT_CACHE_SECONDS,
            )
        else:
            vault_response = None
            error = None
            if self._settings_module.VAULT_SCOPED_READ:
                vault_key = f"certs/{cert_name}"
                vault_response = self.read(vault_key, level=StorageLevel.Institution)
                error = f"KEY {vault_key} not present in Vault"
            if not vault_response:
                # Fall back to older storage location using {aba}_certs pattern
                vault_key = f"{hq_credentials.aba}_certs/{env}/{cert_name}"
                vault_response = self.read_raw(vault_key)
                if not error:
                    error = f"KEY {vault_key} not present in Vault"
            if not vault_response:
                raise VaultError(error)
            # read() already unwraps "data"; read_raw() responses still carry it
            if "data" in vault_response:
                vault_response = vault_response["data"]
            # NOTE(review): the file holds the private key and is created with
            # default permissions; consider restricting it (e.g. 0o600).
            with open(cert_path, "w") as cert_file:
                for pem_field in ("cert", "key"):
                    pem_text = vault_response.get(pem_field)
                    if pem_text:
                        cert_file.write(pem_text)
                        cert_file.write("\n")
        response = str(cert_path)
        if as_obj is True:
            return Certificate(response)
        return response
    @staticmethod
    def _get_local_secret_cert_path(
        local_path_config: LocalPathConfig,
    ) -> Optional[Path]:
        """
        Location of the on-disk certificate cache file.

        :returns: Absolute path of ``local_path_config.file_name()`` inside
            NOMAD_SECRETS_DIR, or None when NOMAD_SECRETS_DIR is not set
        """
        if NOMAD_SECRETS_DIR is None:
            return None
        file_name = local_path_config.file_name()
        return Path(NOMAD_SECRETS_DIR).joinpath(file_name).expanduser().resolve()
[docs]
    def get_smart_token(self, key) -> str:
        """
        Helper for getting Q2Smart token for this SDK instance

        Looks up ``{deploy_env}_{key}`` and returns the ``SMART_TOKEN`` entry of
        its ``data`` payload.

        :param key: Vault key, without the deploy-env prefix
        :raises VaultError: nothing is stored at the key, or SMART_TOKEN is missing
        """
        vault_key = f"{self.deploy_env}_{key}"
        response = self.read_raw(vault_key)
        if not response:
            raise VaultError(f"KEY {vault_key} not present in Vault")
        payload = response["data"]
        try:
            return payload["SMART_TOKEN"]
        except KeyError as err:
            raise VaultError(f"SMART_TOKEN is not set for KEY {vault_key}") from err
[docs]
    def get_hq_creds(self, key, prefix_with_deploy_env=True) -> HqCredentials:
        """Helper for getting HQCredentials for this SDK instance
        :param key: vault key
        :param prefix_with_deploy_env: if true, will search {deploy_env}_{key}
        :returns: HqCredentials built from the stored values, with
            ``customer_key`` set to the ``key`` argument as passed in
        :raises VaultError: nothing is stored at the key, or a required field
            (URL variant, CSR, CSR_pwd, ABA) is missing
        """
        customer_key = key
        vault_key = f"{self.deploy_env}_{key}" if prefix_with_deploy_env else key
        # Local mocks keep the credentials one level deeper, under .../hq_creds
        lookup_key = f"{vault_key}/hq_creds" if self.is_local_client else vault_key
        data = self.read_raw(lookup_key)
        if not data:
            if self.is_local_client:
                self.logger.error(
                    f"Data not generated for local vault mock. Try running `q2 vault add_hq_creds {vault_key}`"
                )
            raise VaultError(f"KEY {lookup_key} not present in Vault")
        creds = data["data"]
        try:
            # We've gone through a few versions of the hq_url variable name in
            # Vault, hence all the fall through options
            hq_url = creds.get("URL_LIST", creds.get("URL", creds.get("HQ_URL")))
            if not hq_url:
                raise KeyError
            csr_user, csr_pwd, aba = creds["CSR"], creds["CSR_pwd"], creds["ABA"]
            database_name = creds.get("DB_NAME")
            server_name = creds.get("DBServerName")
            env_stack = creds.get("env_stack")
        except KeyError as err:
            raise VaultError(
                "Following keys must all be populated in Vault for KEY {}: (URL, CSR, CSR_pwd, ABA)".format(
                    lookup_key
                )
            ) from err
        return HqCredentials(
            hq_url,
            csr_user,
            csr_pwd,
            aba,
            customer_key=customer_key,
            database_name=database_name,
            db_server_name=server_name,
            env_stack=env_stack,
        )
[docs]
    def is_authenticated(self) -> bool:
        """
        Identifies if the client is authenticated with the vault service,
        i.e. whether the token in use is valid.

        Local (file-backed) clients always report True; otherwise this is a
        passthru to the underlying hvac client's ``is_authenticated()``.
        """
        return self.is_local_client is True or self.client.is_authenticated()
 
[docs]
def get_token() -> Optional[str]:
    """
    Gets vault_token from either the filesystem or an environment variable.

    Reads ``{NOMAD_SECRETS_DIR}/vault_token`` first (only when
    NOMAD_SECRETS_DIR is set). If that yields nothing, falls back to
    ``hvac.utils.get_token_from_env()``.

    :returns: The token, or whatever the hvac fallback returns when none is found
    """
    vault_token = None
    # Without this guard an unset dir produced the relative path
    # "None/vault_token", which would be resolved against the CWD
    if NOMAD_SECRETS_DIR:
        vault_token = _get_token_from_path(f"{NOMAD_SECRETS_DIR}/vault_token")
    if not vault_token:
        vault_token = hvac.utils.get_token_from_env()
    return vault_token
def _get_token_from_path(path: str):
    """Return the whitespace-stripped contents of ``path``, or None if it doesn't exist."""
    if not os.path.exists(path):
        return None
    with open(path) as token_file:
        return token_file.read().strip()
[docs]
def get_client(allow_local=True, logger=None, timeout=DEFAULT) -> Optional[Q2Vault]:
    """
    Returns a Q2Vault instance.

    - If ``allow_local`` is truthy and LOCAL_VAULT_DIR is set, returns a local
      (file-backed) client.
    - Otherwise builds a client from the VAULT_ADDR environment variable and
      ``get_token()``, returning None if either is missing.

    :param allow_local: Permit the local mock client when LOCAL_VAULT_DIR is set
    :param logger: Logger for the client; defaults to the current request's
        handler logger, or the module LOGGER outside a request
    :param timeout: Forwarded to Q2Vault only when explicitly given
    """
    kwargs = {}
    # DEFAULT is a sentinel: compare by identity, not equality
    if timeout is not DEFAULT:
        kwargs["timeout"] = timeout
    if logger is None:
        current_request = contexts.get_current_request(raise_if_none=False)
        if current_request:
            logger = current_request.request_handler.logger
        else:
            logger = LOGGER
    if allow_local and LOCAL_VAULT_DIR:
        return Q2Vault("localhost", "", allow_local=True, logger=logger, **kwargs)
    client = None
    vault_addr = os.environ.get("VAULT_ADDR")
    vault_token = get_token()
    if vault_addr and vault_token:  # pragma: no cover
        client = Q2Vault(
            vault_addr, vault_token, allow_local=False, logger=logger, **kwargs
        )
    return client
[docs]
def file_age(filepath) -> int:
    """
    Return how long ago ``filepath`` was last modified, in whole seconds
    (fractional part truncated).
    """
    now = time.time()
    modified_at = os.stat(filepath).st_mtime
    return int(now - modified_at)