import json
import logging
import os
import re
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Literal, Optional, Union, overload
import hvac
from hvac.exceptions import Forbidden
from requests.exceptions import ConnectionError as ReqConnectionError
from q2_sdk.core import contexts, initial_settings, cache
from q2_sdk.core.configuration import settings
from q2_sdk.core.exceptions import BadParameterError, MockDataError, VaultError
from q2_sdk.core.non_http_handlers.udp import writer as udp_writer
from q2_sdk.core.opentelemetry.span import Q2Span, Q2SpanAttributes
from q2_sdk.hq.models.hq_credentials import HqCredentials
from q2_sdk.models.recursive_encoder import JsonSerializable
from q2_sdk.models.unique_stack import ForkedUniqueStack
from q2_sdk.tools.sentinel import DEFAULT
# Locations and defaults resolved once from initial_settings at import time.
NOMAD_SECRETS_DIR = initial_settings.NOMAD_SECRETS_DIR
DEFAULT_PREFIX = initial_settings.VAULT_DEFAULT_PREFIX
LOCAL_VAULT_DIR = initial_settings.VAULT_LOCAL_DIR
# A certificate written to local disk is reused for this long: 5 minutes.
CERT_CACHE_SECONDS = 5 * 60
# getLogger() with no name is the *root* logger, so the level below is set
# on the root logger itself rather than on a module-specific logger.
LOGGER = logging.getLogger()
LOGGER.setLevel(initial_settings.LOGGING_LEVEL)
# Handed to cache.get_cache() when Vault reads are cached.
VAULT_ENCRYPTION_KEY = settings.VAULT_ENCRYPTION_KEY
[docs]
class StorageLevel(Enum):
    """Which level of the Vault hierarchy a scoped read targets.

    ``Q2Vault._get_path`` turns each member into a concrete path segment.
    """

    # Resolved from SearchReplacements.institution_id
    Institution = "institution"
    # Resolved from "{deploy_env}_{SearchReplacements.stack_id}"
    Stack = "stack"
[docs]
@dataclass
class SearchReplacements:
    """Identifiers substituted into scoped Vault paths.

    Empty fields are filled in when the object is constructed:

    - ``institution_id`` falls back to ``settings.COMPANY``, then to
      ``settings.HQ_CREDENTIALS.aba``.
    - ``stack_id`` falls back to the customer key of the HQ credentials on the
      current request (unless that key is the ``"CHANGEME"`` placeholder), and
      to ``settings.VAULT_KEY`` otherwise or when there is no current request.
    """

    institution_id: Optional[str] = None
    stack_id: Optional[str] = None

    def __post_init__(self):
        # Re-imported locally, shadowing the module-level ``settings``;
        # presumably so the configured settings are picked up at call time.
        from q2_sdk.core.configuration import settings

        self.institution_id: str = (
            self.institution_id or settings.COMPANY or settings.HQ_CREDENTIALS.aba
        )
        if self.stack_id:
            return

        request = contexts.get_current_request(raise_if_none=False)
        if not request:
            self.stack_id = settings.VAULT_KEY
            return

        customer_key = request.request_handler.hq_credentials.customer_key
        # "CHANGEME" is treated as "no real key configured"
        self.stack_id = (
            customer_key if customer_key != "CHANGEME" else settings.VAULT_KEY
        )
[docs]
@dataclass
class Certificate:
    """Lazy view over a PEM file holding a certificate and/or a private key.

    The file at ``path`` is read at most once. ``cert`` and ``key`` each return
    the first matching PEM block, or ``""`` when the file has no such block,
    and cache the result.
    """

    CERT_STR = r"(-----BEGIN CERTIFICATE-----.+?-----END CERTIFICATE-----)"
    KEY_STR = r"(-----BEGIN \w*\s?PRIVATE KEY-----.+?-----END \w*\s?PRIVATE KEY-----)"
    path: str

    def __post_init__(self):
        # Lazily-populated caches; None means "not computed yet".
        self._body = None
        self._cert = None
        self._key = None

    def _get_body(self):
        """Return the file's contents, reading the file only on first use."""
        if self._body is None:
            with open(self.path) as handle:
                self._body = handle.read()
        return self._body

    def _extract(self, pattern: str) -> str:
        """Return the first block matching ``pattern``, or "" if there is none."""
        match = re.search(pattern, self._get_body(), flags=re.DOTALL)
        # Q2Vault.get_certificate writes the cert and the key only when each is
        # present, so a file may legitimately lack one of them. Previously that
        # crashed with AttributeError on ``None.groups()``.
        return match.group(1) if match else ""

    @property
    def cert(self):
        """First ``BEGIN/END CERTIFICATE`` block in the file ("" if absent)."""
        if self._cert is None:
            self._cert = self._extract(self.CERT_STR)
        return self._cert

    @property
    def key(self):
        """First ``BEGIN/END ... PRIVATE KEY`` block in the file ("" if absent)."""
        if self._key is None:
            self._key = self._extract(self.KEY_STR)
        return self._key
[docs]
@dataclass
class RecentVaultKey(JsonSerializable):
    """Record of one Vault lookup: the prefix and key requested, and whether data was found."""

    prefix: str
    key: str
    success: bool

    def __eq__(self, other: object) -> bool:
        """Records are equal when they refer to the same Vault path.

        ``success`` is not considered. The path is compared in the form
        ``Q2Vault.read_raw`` requests it (``"{prefix}/{key}"``). Plain
        concatenation without a separator made e.g. ("ab", "c") equal to
        ("a", "bc").
        """
        if not isinstance(other, RecentVaultKey):
            # Let Python fall back to identity comparison instead of raising
            # AttributeError on objects without prefix/key.
            return NotImplemented
        return f"{self.prefix}/{self.key}" == f"{other.prefix}/{other.key}"
[docs]
class RecentKeysStack(ForkedUniqueStack):
    """ForkedUniqueStack of recently requested Vault keys, usable in forked mode.

    Only sets ``UDP_MESSAGE_TYPE``; presumably the parent class uses it to tag
    the UDP messages it sends (confirm in ForkedUniqueStack).
    """

    UDP_MESSAGE_TYPE = udp_writer.MsgType.Vault
[docs]
@dataclass
class LocalPathConfig:
    """Inputs that determine the on-disk file name of a cached certificate."""

    cert_name: str
    institution: str
    hq_credentials: HqCredentials
    env: str

    def file_name(self):
        """Return ``{institution}_{aba}_{env}_{cert_name}.pem``."""
        return "{}_{}_{}_{}.pem".format(
            self.institution,
            self.hq_credentials.aba,
            self.env,
            self.cert_name,
        )
[docs]
@dataclass
class VaultPath:
    """A Vault key together with the prefix it is stored under."""

    key: str
    prefix: str

    @property
    def cache_key(self):
        """Identifier used for this location in the read cache: ``"{prefix}/{key}"``."""
        return "{}/{}".format(self.prefix, self.key)
[docs]
class Q2Vault:
"""Class for interacting with Hashicorp's Vault (https://www.vaultproject.io)"""
RECENT_KEYS = RecentKeysStack(100)
def __init__(
self, addr, token, allow_local: bool = True, logger=LOGGER, timeout: int = 3
):
self.addr = addr
self.client = hvac.Client(url=addr, token=token, timeout=timeout)
self.deploy_env = os.environ.get("DEPLOY_ENV", "DEV")
self.logger = logger
self._settings_module = None
self.is_local_client = LOCAL_VAULT_DIR and allow_local is True
[docs]
@Q2Span.instrument(name="vault.read", skip=["self"])
def read(
self,
key,
*,
level: StorageLevel = StorageLevel.Stack,
replacements: Optional[SearchReplacements] = None,
default=None,
**kwargs,
) -> dict:
"""
Vault stores data at paths like a filesystem. We utilize the base path
secret/ABA for historical reasons, then your data can be stored below
that at various locations depending on your needs.
In order of most to least specific data storage location, we have:
- Stack
- Each Database in Q2 gets a unique id assigned to it. This is
available in settings.VAULT_KEY
- Institution
- This will tie to the value in settings.COMPANY if it exists,
or settings.HQ_CREDENTIALS.aba if it does not
:param key: Keyname in Vault
:param level: To specify Stack or Institution storage location
:param replacements: Provided as a way to query a different FI than
the one configured for this service.
Useful in multitenant scenarios
:type replacements: SearchReplacements
:param default: Value to return if data does not exist at the specified location
"""
if self._settings_module is None:
from q2_sdk.core.configuration import settings
self._settings_module = settings
if self._settings_module.VAULT_SCOPED_READ is False:
self.logger.warning(
"Opting in to old Vault read style. This is deprecated. Please set VAULT_SCOPED_READ=True in your settings file."
)
return self.read_raw(key, prefix=kwargs.get("prefix", DEFAULT_PREFIX))
if replacements is None:
replacements = SearchReplacements()
if not isinstance(replacements, SearchReplacements):
raise BadParameterError(
"replacements parameter must be an instance of SearchReplacements"
)
vault_path = self._get_path(
key,
level,
replacements,
default_prefix=kwargs.get("prefix", DEFAULT_PREFIX),
)
data = None
cache_obj = None
if self._settings_module.VAULT_CACHE_READS:
cache_obj = cache.get_cache(encryption_key=VAULT_ENCRYPTION_KEY)
data = cache_obj.get(vault_path.cache_key)
if isinstance(data, str):
try:
data = json.loads(data)
except json.JSONDecodeError:
self.logger.error("failed to json serialize string from vault")
if data is None:
vault_data = self._get_vault_data(vault_path, default)
vault_expiry = self._settings_module.VAULT_EXPIRY
if cache_obj:
cache_obj.set(vault_path.cache_key, vault_data, expire=vault_expiry)
data = vault_data
return data
def _get_path(
self,
key,
level: StorageLevel = StorageLevel.Stack,
replacements: Optional[SearchReplacements] = None,
default_prefix=DEFAULT_PREFIX,
) -> VaultPath:
match level:
case StorageLevel.Stack:
if not replacements.stack_id:
Q2Span.set_attribute(
Q2SpanAttributes.SDK_FRAMEWORK_ERROR, "Vault Config"
)
raise BadParameterError(
"replacements.stack_id is empty, but vault request was made at the Stack level"
)
path = f"{self.deploy_env}_{replacements.stack_id}"
case StorageLevel.Institution:
path = replacements.institution_id
case _:
Q2Span.set_attribute(
Q2SpanAttributes.SDK_FRAMEWORK_ERROR, "Vault Config"
)
raise BadParameterError(
"level parameter must be an instance of StorageLevel"
)
return VaultPath(key, f"{default_prefix}/{path}")
def _get_vault_data(self, vault_path: VaultPath, default=None):
data = self.read_raw(vault_path.key, prefix=vault_path.prefix)
if not data:
if not default:
Q2Span.set_attribute(Q2SpanAttributes.SDK_FRAMEWORK_ERROR, "Vault Data")
return default
return data.get("data", data)
[docs]
def read_raw(self, key, *, prefix=DEFAULT_PREFIX):
"""Catchall way of interacting with Vault directly if one of the more specific helper functions doesn't do the trick"""
if prefix is None:
prefix = ""
if self.is_local_client:
data = self._get_local_key(key, prefix)
else:
try:
self.logger.debug("Searching %s/%s", prefix, key)
data = self.client.read("{}/{}".format(prefix, key))
except ReqConnectionError as err:
raise ReqConnectionError(
"Unable to connect to vault_addr: {}".format(self.addr)
) from err
except Forbidden as err:
new_token = get_token()
if new_token == self.client.token:
raise Forbidden("Bad Vault Token") from err
self.client.token = new_token
data = self.read_raw(key, prefix=prefix)
if data:
self.logger.debug("Vault data read successfully")
self.RECENT_KEYS.append(RecentVaultKey(prefix, key, True).to_json())
else:
self.logger.debug("No vault data")
Q2Span.add_event(
"Vault Data", {"requested_path": "{}/{}".format(prefix, key)}
)
self.RECENT_KEYS.append(RecentVaultKey(prefix, key, False).to_json())
return data
[docs]
def write(self, path: str, body=dict, prefix=DEFAULT_PREFIX):
"""
This is only possible if appropriate authorization (Vault Policies) are
tied to the active vault token.
If configured with a local vault directory, writes will always be allowed.
"""
if not prefix:
prefix = ""
path = Path(prefix) / path
if self.is_local_client:
local_vault_path = Path(LOCAL_VAULT_DIR)
local_vault_path.mkdir(exist_ok=True)
full_path = local_vault_path / path
full_path.parent.mkdir(parents=True, exist_ok=True)
full_path.write_text(json.dumps(body))
self.logger.info("Writing vault data to local path: %s", full_path)
else:
self.logger.info("Writing vault data to vault at %s", path)
return self.client.write(path, **body)
def _get_local_key(self, key, prefix):
error_msg = 'Mock vault data must be a JSON object with a "data" key'
path = Path(LOCAL_VAULT_DIR) / prefix / key
path = path.expanduser().resolve()
if not path.exists():
self.logger.error("Path does not exist: %s", path)
return None
with open(path) as handle:
self.logger.debug("Searching %s", path)
try:
response = json.loads(handle.read())
except json.decoder.JSONDecodeError as err:
raise MockDataError(error_msg) from err
if not response.get("data"):
raise MockDataError(error_msg)
return response
@overload
def get_certificate(
self, cert_name: str, hq_credentials: HqCredentials, as_obj: Literal[True]
) -> Certificate: ...
@overload
def get_certificate(
self,
cert_name: str,
hq_credentials: HqCredentials,
env: str,
as_obj: Literal[True],
) -> Certificate: ...
@overload
def get_certificate(
self,
cert_name: str,
hq_credentials: HqCredentials,
env: str = "default",
as_obj: bool = False,
) -> str: ...
[docs]
def get_certificate(
self,
cert_name: str,
hq_credentials: HqCredentials,
env: str = "default",
as_obj: bool = False,
) -> Union[str, Certificate]:
"""
Gets a certificate from vault or local disk cache if available. Local certificate cache is good for 300 seconds.
Looks up one of two paths:
- If VAULT_SCOPED_READ is True in settings: ``{institution}/certs/{cert_name}``
where ``{institution}`` is either settings.COMPANY or hq_credentials.ABA
- If VAULT_SCOPED_READ is False in settings: ``{ABA}_certs/{env}/{cert_name}``
If VAULT_SCOPED_READ is True, it will still fall through to old behavior if cert is not found at new location.
:param cert_name: Last part of the path in Vault
:param hq_credentials: Used to specify the first part of the path in Vault
:param env: Middle part of the path in Vault (ignored if VAULT_SCOPED_READ is True)
:param as_obj: If True, will return a Certificate object, with .key and .cert properties as well as .path
"""
if self._settings_module is None:
from q2_sdk.core.configuration import settings
self._settings_module = settings
recalculate = True
cert_path = self._get_local_secret_cert_path(
LocalPathConfig(
cert_name, self._settings_module.COMPANY, hq_credentials, env
)
)
if (
cert_path
and cert_path.is_file()
and file_age(cert_path) < CERT_CACHE_SECONDS
):
recalculate = False
self.logger.debug(
f"Certificate file less than {CERT_CACHE_SECONDS} seconds old. Skipping vault read..."
)
if recalculate:
vault_response = None
error = None
if self._settings_module.VAULT_SCOPED_READ:
key = f"certs/{cert_name}"
vault_response = self.read(key, level=StorageLevel.Institution)
error = f"KEY {key} not present in Vault"
if not vault_response:
# Fall back to older storage location using {aba}_certs pattern
key = f"{hq_credentials.aba}_certs/{env}/{cert_name}"
vault_response = self.read_raw(key)
if not error:
error = f"KEY {key} not present in Vault"
if not vault_response:
raise VaultError(error)
with open(cert_path, "w") as cert_file:
if "data" in vault_response:
vault_response = vault_response["data"]
cert = vault_response.get("cert")
if cert:
cert_file.write(cert)
cert_file.write("\n")
key = vault_response.get("key")
if key:
cert_file.write(key)
cert_file.write("\n")
response = str(cert_path)
if as_obj is True:
response = Certificate(response)
return response
@staticmethod
def _get_local_secret_cert_path(local_path_config: LocalPathConfig) -> Path:
if NOMAD_SECRETS_DIR is not None:
key = local_path_config.file_name()
return Path(NOMAD_SECRETS_DIR).joinpath(key).expanduser().resolve()
[docs]
def get_smart_token(self, key) -> str:
"""Helper for getting Q2Smart token for this SDK instance"""
key = f"{self.deploy_env}_{key}"
data = self.read_raw(key)
if not data:
raise VaultError(f"KEY {key} not present in Vault")
data_node = data["data"]
try:
smart_token = data_node["SMART_TOKEN"]
except KeyError as err:
raise VaultError(f"SMART_TOKEN is not set for KEY {key}") from err
return smart_token
[docs]
def get_hq_creds(self, key, prefix_with_deploy_env=True) -> HqCredentials:
"""Helper for getting HQCredentials for this SDK instance
:param key: vault key
:param prefix_with_deploy_env: if true, will search {deploy_env}_{key}
"""
customer_key = key
if prefix_with_deploy_env:
key = f"{self.deploy_env}_{key}"
orig_key = key
if self.is_local_client:
key = f"{key}/hq_creds"
data = self.read_raw(key)
if self.is_local_client and not data:
self.logger.error(
f"Data not generated for local vault mock. Try running `q2 vault add_hq_creds {orig_key}`"
)
if not data:
raise VaultError(f"KEY {key} not present in Vault")
data_node = data["data"]
try:
# We've gone through a few versions of the hq_url variable name in
# Vault, hence all the fall through options
hq_url = data_node.get(
"URL_LIST", data_node.get("URL", data_node.get("HQ_URL"))
)
if not hq_url:
raise KeyError
csr_user = data_node["CSR"]
csr_pwd = data_node["CSR_pwd"]
aba = data_node["ABA"]
database_name = data_node.get("DB_NAME")
server_name = data_node.get("DBServerName")
env_stack = data_node.get("env_stack")
except KeyError as err:
raise VaultError(
"Following keys must all be populated in Vault for KEY {}: (URL, CSR, CSR_pwd, ABA)".format(
key
)
) from err
return HqCredentials(
hq_url,
csr_user,
csr_pwd,
aba,
customer_key=customer_key,
database_name=database_name,
db_server_name=server_name,
env_stack=env_stack,
)
[docs]
def is_authenticated(self) -> bool:
"""Identifies if the client is authenticated with the vault service. Essentially validating the token used is
a valid token. Passthru to the underlying client method."""
if self.is_local_client is True:
return True
return self.client.is_authenticated()
[docs]
def get_token() -> Optional[str]:
    """Gets vault_token from either the filesystem or an environment variable

    Checks ``{NOMAD_SECRETS_DIR}/vault_token`` first, but only when
    NOMAD_SECRETS_DIR is configured. Otherwise, or if that file is missing or
    empty, falls back to ``hvac.utils.get_token_from_env()``, which may
    return None.
    """
    vault_token = None
    # Without this guard an unset dir produced the cwd-relative path
    # "None/vault_token", which could pick up an unrelated file.
    if NOMAD_SECRETS_DIR:
        vault_token = _get_token_from_path(
            os.path.join(NOMAD_SECRETS_DIR, "vault_token")
        )
    if not vault_token:
        vault_token = hvac.utils.get_token_from_env()
    return vault_token
def _get_token_from_path(path: str):
vault_token = None
if os.path.exists(path):
with open(path) as handle:
vault_token = handle.read().strip()
return vault_token
[docs]
def get_client(allow_local=True, logger=None, timeout=DEFAULT) -> Q2Vault:
    """Return a Q2Vault configured for the current environment.

    :param allow_local: Use the local vault directory when LOCAL_VAULT_DIR is set
    :param logger: Logger for the client; defaults to the current request's
        logger, or LOGGER when there is no current request
    :param timeout: Request timeout; when left as DEFAULT, Q2Vault's own
        default is used
    :return: A local client when allowed and LOCAL_VAULT_DIR is set; otherwise
        a client for $VAULT_ADDR when both it and a token are available;
        otherwise None
    """
    client_kwargs = {"timeout": timeout} if timeout != DEFAULT else {}

    if logger is None:
        request = contexts.get_current_request(raise_if_none=False)
        logger = request.request_handler.logger if request else LOGGER

    if allow_local and LOCAL_VAULT_DIR:
        return Q2Vault(
            "localhost", "", allow_local=True, logger=logger, **client_kwargs
        )

    vault_addr = os.environ.get("VAULT_ADDR")
    vault_token = get_token()
    if not (vault_addr and vault_token):
        return None
    return Q2Vault(  # pragma: no cover
        vault_addr, vault_token, allow_local=False, logger=logger, **client_kwargs
    )
[docs]
def file_age(filepath) -> int:
    """Return how many whole seconds have passed since ``filepath`` was last modified."""
    now = time.time()
    modified_at = os.stat(filepath).st_mtime
    return int(now - modified_at)