from argparse import _SubParsersAction
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from io import BytesIO
from typing import Optional, Union
from uuid import uuid4
from q2_sdk.tools.utils import pascalize
from q2_sdk.core.cli.textui import colored, puts
from q2_sdk.core import q2_requests
from q2_sdk.core.configuration import settings
from q2_sdk.core.exceptions import ConfigurationError
# Default link expiry / request timeout, in seconds (one minute).
DEFAULT_EXPIRY = 60
if settings.DEBUG:
    # Dev environments get 120x the default (120 minutes) due to timezone
    # differences in various environments.
    DEFAULT_EXPIRY = DEFAULT_EXPIRY * 120
@dataclass
class PublicDownloadUrl:
    """
    Parsed response of the ArdentFS ``singleUseUrl`` endpoint.

    ``aws_url`` combined with ``headers`` downloads directly from Amazon S3;
    ``proxy_url`` can be used as a standalone url with a hop in between.
    """

    raw: dict  # the full decoded JSON payload as received
    headers: dict  # headers to send along with a request to ``aws_url``
    aws_url: str  # direct Amazon S3 url
    proxy_url: str  # standalone ArdentFS proxy url

    # NOTE: @dataclass does not replace an ``__init__`` defined in the class
    # body; it still generates ``__repr__``/``__eq__`` from the fields above.
    def __init__(self, raw: dict):
        """
        .. code-block:: javascript

            // Sample raw input
            {
                "data": {
                    "headers": {
                        "x-amz-server-side-encryption-customer-algorithm":"AES256",
                        "x-amz-server-side-encryption-customer-key":"abc",
                        "x-amz-server-side-encryption-customer-key-MD5":"hash"
                    },
                    "url": "https://awsurl",
                    "proxyUrl": "http://ardentfsproxyurl"
                }
            }

        :param raw: Decoded JSON response; must contain ``data.headers``,
                    ``data.url`` and ``data.proxyUrl``
        :raises KeyError: If any of those keys is missing
        """
        self.raw = raw
        data = raw["data"]
        self.headers = data["headers"]
        self.aws_url = data["url"]
        self.proxy_url = data["proxyUrl"]
@dataclass
class AwsFile:
    """
    Parsed response of the ArdentFS ``temporaryUploadUrl`` endpoint: where to
    POST a file and the authorization form fields to send with it.
    """

    raw: dict  # the full decoded JSON payload as received
    id: str  # ArdentFS file id (also the suffix of the S3 ``key`` field)
    url: str  # url to POST the upload to
    bucket: str  # value of ``data.bucket`` in the response
    authorization_fields: dict  # form fields; see File.build_upload_body

    # NOTE: @dataclass does not replace an ``__init__`` defined in the class
    # body; it still generates ``__repr__``/``__eq__`` from the fields above.
    def __init__(self, raw: dict):
        """
        Transforms an ArdentFS JSON blob into a usable Python object

        .. code-block:: javascript

            // Sample raw input
            {
              "data": {
                "url": "https://s3.us-east-2.amazonaws.com/ardent-fs",
                "fields": {
                  "bucket": "ardent-fs",
                  "X-Amz-Algorithm": "AWS4-HMAC-SHA256",
                  "X-Amz-Credential": "AAAAAB4WV3347YTY2G7Z/20200820/us-east-2/s3/aws4_request",
                  "X-Amz-Date": "20200820T195929Z",
                  "X-Amz-Security-Token": "FwoGZX...",
                  "Policy": "eyJleHBpcmF...",
                  "X-Amz-Signature": "f5e87a...",
                  "key": "dev/c35a08ed-6e17-4bd3-a021-41cd8ef5298f"
                },
                "bucket": "dev",
                "id": "c35a08ed-6e17-4bd3-a021-41cd8ef5298f"
              }
            }

        :param raw: Decoded JSON response; must contain ``data.id``,
                    ``data.url``, ``data.bucket`` and ``data.fields``
        :raises KeyError: If any of those keys is missing
        """
        self.raw = raw
        data = raw["data"]
        self.id = data["id"]
        self.url = data["url"]
        self.bucket = data["bucket"]
        self.authorization_fields = data["fields"]
[docs]
class TTLType(Enum):
    Days = auto()
    Years = auto()
    def as_query_param(self):
        match self:
            case TTLType.Days:
                return "ttlDays"
            case TTLType.Years:
                return "ttlYears"
    @staticmethod
    def from_str(s):
        try:
            return TTLType[pascalize(s)]
        except KeyError:
            raise ValueError("Options are Days or Years") 
class File:
    """
    Uploads and downloads files from Amazon S3 by way of Q2's
    ArdentFS endpoints. This allows for much larger file sizes over the network
    than would be possible by POSTing those files to the SDK directly.

    Uploaded files can be downloaded for a limited time before their
    TTL expires (24 hours typically).
    """

    def __init__(self, logger):
        """
        :param logger: Logger handed to every ``q2_requests`` call this class makes
        """
        self.logger = logger

    @staticmethod
    def _base_url() -> str:
        """Return ``<ARDENTFS_URL>/<ARDENTFS_BUCKET>``, raising if ArdentFS is not configured."""
        if not settings.ARDENTFS_URL:
            raise ConfigurationError("No ArdentFsUrl set in settings file")
        return f"{settings.ARDENTFS_URL}/{settings.ARDENTFS_BUCKET}"

    @staticmethod
    def _split_file_key(file_key: str) -> tuple[str, Optional[dict]]:
        """
        Split a ``"<file_id>:<encryption_key>"`` string (as returned by
        ``upload_from_file``) into the file id and the request headers.

        :return: ``(file_id, {"key": encryption_key})``, or ``(file_key, None)``
                 when ``file_key`` contains no ``:``
        """
        file_id, separator, encryption_key = file_key.partition(":")
        if not separator:
            return file_key, None
        # partition (unlike split + unpack) tolerates further ':' characters,
        # keeping them as part of the encryption key instead of raising.
        return file_id, {"key": encryption_key}

    def add_arguments(self, parser: _SubParsersAction):
        """
        Register the ``file_upload``, ``file_download`` and
        ``get_public_download_url`` CLI subcommands on ``parser``.
        """
        subparser = parser.add_parser("file_upload")
        subparser.set_defaults(parser="file_upload")
        subparser.set_defaults(
            func=partial(self.upload_from_file, serialize_for_cli=True)
        )
        subparser.add_argument("data", help="File path to upload")
        subparser.add_argument(
            "--timeout",
            default=DEFAULT_EXPIRY,
            type=int,
            help="Override the default timeout",
        )
        subparser.add_argument(
            "--ttl",
            help="File's time to live in AWS",
            choices=[1, 2, 3],
            type=int,
            default=1,
        )
        subparser.add_argument(
            "--ttl-type",
            choices=list(TTLType),
            type=TTLType.from_str,
            default=TTLType.Days,
        )
        subparser.add_argument(
            "--content-type",
            help="If set, browsers can take advantage of this metadata to render appropriately",
        )

        subparser = parser.add_parser("file_download")
        subparser.set_defaults(parser="file_download")
        subparser.set_defaults(func=partial(self.download, serialize_for_cli=True))
        subparser.add_argument("file_key", help="Corresponds to AwsFile.id")
        subparser.add_argument(
            "-o",
            "--output",
            dest="output_path",
            help="Write to <file> instead of stdout",
        )
        subparser.add_argument(
            "--timeout",
            default=DEFAULT_EXPIRY,
            type=int,
            help="Override the default timeout",
        )

        subparser = parser.add_parser("get_public_download_url")
        subparser.set_defaults(parser="get_public_download_url")
        subparser.set_defaults(
            func=partial(self.get_public_download_url, serialize_for_cli=True)
        )
        subparser.add_argument("file_key", help="Corresponds to AwsFile.id")
        subparser.add_argument(
            "-e",
            "--expiry",
            default=DEFAULT_EXPIRY,
            type=int,
            help="How long the link should be available",
        )
        subparser.add_argument(
            "--render-in-browser",
            default=False,
            action="store_true",
            help="If True, browser will render inline rather than auto download",
        )

    async def get_upload_url(self, name: Optional[str] = None) -> AwsFile:
        """
        Query ArdentFS for a url to POST a file to.

        This is useful for loading a file to cloud storage
        straight from the JavaScript layer, bypassing the
        Q2 infrastructure.

        :param name: Name in AWS is always a guid, but this adds a metadata field with the name
        :return: AwsFile describing where and how to POST the upload
        :raises ConfigurationError: If ARDENTFS_URL is not set
        """
        url = f"{self._base_url()}/files/temporaryUploadUrl"
        params = {"filename": name} if name else {}
        response = await q2_requests.get(
            self.logger, url, params=params, verify_whitelist=False
        )
        if not response.ok:
            self.logger.error("Unable to get upload url")
            response.raise_for_status()
        return AwsFile(response.json())

    async def upload_from_file(
        self,
        data: Union[str, BytesIO],
        name: Optional[str] = None,
        serialize_for_cli=False,
        timeout=30,
        ttl=1,
        ttl_type: TTLType = TTLType.Days,
        content_type: Optional[str] = None,
        minimal_logging=False,
        **kwargs,
    ) -> str:
        """
        Uploads file to AWS using temporary upload URL gathered from ArdentFS

        :param data: Either the path to a file or an already open BytesIO object.
                     A file opened from a path is always closed; a caller-supplied
                     object is closed once the upload request has been sent.
        :param name: Name in AWS is always a guid, but this adds a metadata field with the name
        :param serialize_for_cli: Used when running from the command line; returns a
                                  human readable message instead of the file key
        :param timeout: Max upload time. Defaults to 30 seconds
        :param ttl: Time to Live for file in AWS S3. Allows [1, 2, 3]
        :param ttl_type: Can be either days for short term or years for long term storage
        :param content_type: Corresponds to the Content-Type header of the request
        :param minimal_logging: (optional) flag to turn off debug logging of request header/body/etc.
        :param kwargs: Accepts the deprecated ``ttl_days`` (overrides ``ttl`` and forces
                       ``TTLType.Days``); any other keyword is ignored
        :return: ``"<file_id>:<encryption_key>"``, the key accepted by the download methods
        :raises ConfigurationError: If ARDENTFS_URL is not set
        """
        assert isinstance(ttl_type, TTLType), (
            "ttl_type must be an instance of the TTLType Enum"
        )
        # Backwards compatibility. Remove in 3.0
        ttl_days = kwargs.pop("ttl_days", None)
        if ttl_days is not None:
            ttl_type = TTLType.Days
            ttl = ttl_days
        assert ttl in [1, 2, 3], "ttl must be between 1 and 3"
        base_url = self._base_url()
        if not name:
            name = uuid4().hex

        handle = open(data, "rb") if isinstance(data, str) else data
        try:
            # Ask ArdentFS (files/touch) for the file id and encryption key.
            response = await q2_requests.get(
                self.logger,
                f"{base_url}/files/touch",
                params={"filename": name, ttl_type.as_query_param(): ttl},
                verify_whitelist=False,
                minimal_logging=minimal_logging,
            )
            response.raise_for_status()
            touched = response.json()["data"]
            file_id = touched["id"]
            encryption_key = touched["key"]

            # POST the content under that id/key.
            response = await q2_requests.post(
                self.logger,
                f"{base_url}/files",
                params={
                    "id": file_id,
                    "key": encryption_key,
                },
                data={"content-type": content_type},
                files={"file": (name, handle.read())},
                verify_whitelist=False,
                timeout=timeout,
                minimal_logging=minimal_logging,
            )
        except BaseException:
            # Never leak a file we opened ourselves; a caller-supplied stream
            # is left open on failure so the caller can still use it.
            if handle is not data:
                handle.close()
            raise
        handle.close()
        response.raise_for_status()

        file_key = f"{file_id}:{encryption_key}"
        if serialize_for_cli:
            return "\n".join(
                [
                    "File uploaded successfully!",
                    f"To download call ``q2 ardent file_download {file_key}``",
                ]
            )
        return file_key

    @staticmethod
    def build_upload_body(aws_file):
        """
        Build the form body for a POST to ``aws_file.url``.

        Every authorization field is copied, with ``key`` placed first.
        """
        body = {
            "key": aws_file.authorization_fields["key"],
            **{x: y for (x, y) in aws_file.authorization_fields.items() if x != "key"},
        }
        return body

    async def upload(
        self,
        data: Union[str, bytes],
        name: Optional[str] = None,
        timeout=30,
        ttl_days=1,
        content_type: Optional[str] = None,
        minimal_logging=False,
    ) -> str:
        """
        Uploads to AWS using a temporary upload URL gathered from ArdentFS

        :param data: String/Bytes to upload to AWS (str is encoded with the default UTF-8)
        :param name: Name in AWS is always a guid, but this adds a metadata field with the name
        :param timeout: Max upload time. Defaults to 30 seconds
        :param ttl_days: Time to Live for file in AWS S3. Allows [1, 2, 3]
        :param content_type: Corresponds to the Content-Type header of the request
        :param minimal_logging: (optional) flag to turn off debug logging of request header/body/etc.
        :return: ``"<file_id>:<encryption_key>"`` of the uploaded file
        """
        if isinstance(data, str):
            data = data.encode()
        upload_file = BytesIO(data)
        return await self.upload_from_file(
            upload_file,
            name=name,
            timeout=timeout,
            ttl_days=ttl_days,
            content_type=content_type,
            minimal_logging=minimal_logging,
        )

    async def get_public_download_url(
        self,
        file_key: str,
        expiry=DEFAULT_EXPIRY,
        render_in_browser=False,
        serialize_for_cli=False,
    ) -> Optional[PublicDownloadUrl]:
        """
        Generates several download urls accessible from outside the Q2 network.

        The response.aws_url combined with response.headers will download directly from Amazon S3.
        The response.proxy_url can be used as a standalone url with a hop in between Amazon S3.

        :param file_key: Name of file in AWS. Corresponds to AwsFile.id, optionally
                         followed by ``:<encryption_key>``
        :param expiry: TTL (in seconds)
        :param render_in_browser: If True, browser will attempt to render inline rather than auto download, if
                                  content-type was provided at the time of upload
        :param serialize_for_cli: Used when running from the command line; prints the
                                  urls and returns None
        :return: PublicDownloadUrl instance (None when serialize_for_cli is set)
        :raises ConfigurationError: If ARDENTFS_URL is not set
        """
        file_id, headers = self._split_file_key(file_key)
        params = {"expiry": expiry, "includeContentDisposition": not render_in_browser}
        url = f"{self._base_url()}/files/{file_id}/singleUseUrl"
        response = await q2_requests.get(
            self.logger, url, params=params, headers=headers, verify_whitelist=False
        )
        response.raise_for_status()
        download_url = PublicDownloadUrl(response.json())
        if serialize_for_cli:
            import shlex

            puts(colored.yellow("--CURL CMD--\n\n"))
            # Quote every argument so the printed command can be pasted into a
            # shell even when the url's query string holds characters such as '&'.
            curl_parts = ["curl", shlex.quote(download_url.aws_url)]
            for key, value in download_url.headers.items():
                curl_parts += ["-H", shlex.quote(f"{key}:{value}")]
            puts(" ".join(curl_parts))
            puts("\n")
            puts(colored.yellow("--Proxy Link--\n\n"))
            puts(download_url.proxy_url)
            return None
        return download_url

    @staticmethod
    def get_download_url(file_key: str):
        """
        Returns the url to download a file stored in AWS

        :param file_key: Name of file in AWS. Corresponds to AwsFile.id
        """
        return f"{settings.ARDENTFS_URL}/{settings.ARDENTFS_BUCKET}/files/{file_key}"

    async def download(
        self,
        file_key: str,
        output_path: Optional[str] = None,
        serialize_for_cli=False,
        timeout=30,
    ) -> Union[bytes, str, bool]:
        """
        Pull file_key contents from AWS, optionally writing it to the filesystem.

        :param file_key: Name of file in AWS. Corresponds to AwsFile.id, optionally
                         followed by ``:<encryption_key>``
        :param output_path: If specified, will write the data to a file
        :param serialize_for_cli: Used when running from the command line; returns text
        :param timeout: Max download time. Defaults to 30 seconds
        :return: True if written to output_path, the decoded text if
                 serialize_for_cli, otherwise the raw bytes
        :raises ConfigurationError: If ARDENTFS_URL is not set
        """
        if not settings.ARDENTFS_URL:
            raise ConfigurationError("No ArdentFsUrl set in settings file")
        file_id, headers = self._split_file_key(file_key)
        response = await q2_requests.get(
            self.logger,
            self.get_download_url(file_id),
            headers=headers,
            verify_whitelist=False,
            timeout=timeout,
        )
        # Fail loudly rather than saving/returning an HTTP error body as if it
        # were the file's content.
        response.raise_for_status()
        if output_path:
            with open(output_path, "wb") as handle:
                handle.write(response.content)
            return True
        if serialize_for_cli:
            return response.text
        return response.content