#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import hashlib
import io
import os
import re
import socket
from enum import Enum
from typing import Any, Dict, List, Union

import boto3
import botocore
import torch

from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger

logger = get_logger(__file__)

# Matches "<bucket_name>.<ip>" at the start of a boto3 url, e.g. "model_weights.10.1.2.3".
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")

MB = 1024**2

storage_manager = None


def check_folder(fp: str):
    storage_manager.assert_fp_exists(fp)


def get_fns(fp: str):
    return storage_manager.get_fns(fp)


def llm_load(fp: str, *args, **kwargs):
    return storage_manager.load(fp, *args, **kwargs)


def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
    storage_manager.save(save_path, saved_obj, *args, **kwargs)
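
# A usage sketch for the module-level helpers (paths are illustrative; the backend
# prefix selects the client, see StorageManager._get_client below):
#   llm_save("local:/mnt/ckpt/500/model_tp0_pp0.pt", saved_obj=model.state_dict())
#   states = llm_load("local:/mnt/ckpt/500/model_tp0_pp0.pt", map_location="cpu")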


class CheckpointType(Enum):
    NORMAL_CHECKPOINT = 1


class StorageClient:
    """
    StorageClient as a client for s3 storage access.
    """

    def __init__(self, handler) -> None:
        self.handler = handler

    @staticmethod
    def load(client, load_path: str, map_location):
        raise NotImplementedError

    @staticmethod
    def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
        raise NotImplementedError

    @staticmethod
    def assert_fp_exists(client):
        raise NotImplementedError

    @staticmethod
    def get_fns(client):
        raise NotImplementedError


class Boto3MetaInfo:
    def __init__(self, client: StorageClient, bucket_name: str, endpoint: str, file_path: str) -> None:
        self.client = client
        self.bucket_name = bucket_name
        self.endpoint = endpoint
        self.file_path = file_path


class LocalMetaInfo:
    def __init__(self, client: StorageClient, dest_path: str) -> None:
        self.client = client
        self.dest_path = dest_path


def unpack_meta(meta):
    """Unpack a meta-info object into the positional args expected by the client's
    static methods. Relies on the attribute insertion order of ``__dict__``;
    ``endpoint`` is skipped because it is only needed when constructing the client."""
    args = []
    for k, v in meta.__dict__.items():
        if k == "endpoint":
            continue
        args.append(v)
    return args
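
# For example (a sketch): given Boto3MetaInfo(client, "bucket", "http://10.1.2.3:80", "a/b.pt"),
# unpack_meta returns [client, "bucket", "a/b.pt"], matching the (handler, bucket_name, fp)
# parameters of Boto3Client's static methods.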


def compute_file_md5_by_chunk(file_name: str):
    hash_md5 = hashlib.md5()
    with open(file_name, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
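
# e.g. compute_file_md5_by_chunk("/mnt/ckpt/500/model_tp0_pp0.pt") returns the file's
# hex md5 digest without ever holding more than 4 KiB of it in memory.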


def get_boto3_meta(fp: str) -> Boto3MetaInfo:
    assert fp.startswith("s3://"), f"Path '{fp}' is not a boto3 url"
    # Note: str.lstrip() strips a character set, not a prefix, so slice the scheme off instead.
    # S3 keys always use "/" as the separator, regardless of the local OS.
    parts = fp[len("s3://") :].split("/")
    match = boto3_url_re.match(parts[0])
    assert match is not None, f"url '{fp}' is not a valid boto3 url"
    bucket_name, endpoint = match.group(1), match.group(2)
    endpoint = "http://" + endpoint + ":80"
    return Boto3MetaInfo(None, bucket_name, endpoint, "/".join(parts[1:]))
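
# For instance (illustrative values): "s3://model_weights.10.1.2.3/0331/120bi/model.pt"
# parses to bucket_name="model_weights", endpoint="http://10.1.2.3:80",
# file_path="0331/120bi/model.pt".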


def get_local_meta(fp: str) -> LocalMetaInfo:
    assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
    return LocalMetaInfo(None, fp)


class Boto3Client(StorageClient):
    """
    Boto3Client
    """

    def __init__(
        self,
        s3_endpoint_url: str,
        use_threads: bool = True,
        multipart_chunksize=8 * MB,
        max_concurrency: int = 10,
        multipart_threshold=100 * MB,
    ) -> None:
        """S3 object/file storage management class

        Args:
            s3_access_keys_id (str): S3 access key ID.
            s3_secret_access_key (str): S3 secret access key.
            use_threads (bool, optional): Whether to enable multipart. Defaults to True.
            multipart_chunksize (_type_, optional): Defaults to 8*MB.
            max_concurrency (int, optional): Defaults to 10.

        Raises:
            RuntimeError: Connection failures caused by misconfiguration or network problems.
        """
        super().__init__(boto3)
        self.botocore = botocore
        try:
            s3_access_key_id = os.environ["S3_ACCESS_KEY_ID"]
            s3_secret_access_key = os.environ["S3_SECRET_ACCESS_KEY_ID"]
        except KeyError as exc:
            raise RuntimeError(
                "Please set boto3 bucket 'S3_ACCESS_KEY_ID' and 'S3_SECRET_ACCESS_KEY_ID' using environment variable!"
            ) from exc

        self.client = self.handler.client(
            "s3",
            "",  # empty region name; the endpoint URL fully determines the target
            use_ssl=False,
            verify=False,
            endpoint_url=s3_endpoint_url,
            aws_access_key_id=s3_access_key_id,
            aws_secret_access_key=s3_secret_access_key,
        )

        self.config = self.handler.s3.transfer.TransferConfig(
            multipart_threshold=multipart_threshold,
            max_concurrency=max_concurrency,
            multipart_chunksize=multipart_chunksize,
            use_threads=use_threads,
        )

    @staticmethod
    def sync_upload_fileobj(handler, bucket_name: str, fp: str, *args, saved_obj=None, **kwargs):
        assert saved_obj is not None, "saved_obj is None!"
        try:
            with io.BytesIO() as f:
                torch.save(saved_obj, f, *args, **kwargs)
                f.seek(0)
                handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 network error: please check the network connection on {socket.gethostname()}"
            ) from exc

    @staticmethod
    def load(handler, bucket_name: str, fp: str, *args, map_location="cpu", **kwargs) -> Dict:
        """
        Args:
            fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
        """
        try:
            with io.BytesIO() as f:
                handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
                f.seek(0)
                states = torch.load(f, *args, map_location=map_location, **kwargs)
        except handler.botocore.exceptions.EndpointConnectionError as exc:
            raise RuntimeError(
                f"Boto3 network error: please check the network connection on {socket.gethostname()}"
            ) from exc
        return states

    @staticmethod
    def assert_fp_exists(handler, bucket_name: str, fp: str):
        # list_objects omits the "Contents" key when nothing matches the prefix.
        assert len(handler.client.list_objects(Bucket=bucket_name, Prefix=fp).get("Contents", [])) > 0, fp

    @staticmethod
    def get_fns(handler, bucket_name: str, fp: str):
        """
        Ref: https://stackoverflow.com/questions/54314563/
        how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
        """
        paginator = handler.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)

        folder_name_list = []
        for page in pages:
            for obj in page.get("Contents", []):
                obj_key: str = obj["Key"]
                folder_name_list.append(obj_key.rsplit("/", maxsplit=1)[-1])
        return folder_name_list
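
# Constructing a Boto3Client directly (a sketch; the endpoint is illustrative and the
# credential environment variables must be exported beforehand):
#   cli = Boto3Client("http://10.1.2.3:80")
#   Boto3Client.sync_upload_fileobj(cli, "model_weights", "0331/120bi/model.pt", saved_obj={"step": 0})
#   states = Boto3Client.load(cli, "model_weights", "0331/120bi/model.pt", map_location="cpu")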


class LocalClient(StorageClient):
    """
    Storage Client for local NFS.
    """

    def __init__(self, *args, **kwargs) -> None:  # pylint: disable=W0613
        super().__init__(None)

    @staticmethod
    def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
        assert isinstance(handler, LocalClient)
        assert saved_obj is not None
        fp_dirname = os.path.dirname(fp)
        if fp_dirname:  # guard against an empty dirname; exist_ok makes a separate exists() check redundant
            os.makedirs(fp_dirname, exist_ok=True)
        torch.save(saved_obj, fp, *args, **kwargs)

    @staticmethod
    def load(handler, fp: str, *args, map_location="cpu", **kwargs):
        assert isinstance(handler, LocalClient)
        assert os.path.exists(fp), f"{fp} is not found!"
        with open(fp, "rb") as f:
            states = torch.load(f, *args, map_location=map_location, **kwargs)
        return states

    @staticmethod
    def assert_fp_exists(handler, folder):
        assert isinstance(handler, LocalClient)
        assert os.path.exists(folder), folder

    @staticmethod
    def get_fns(handler, folder):
        assert isinstance(handler, LocalClient)
        assert os.path.exists(folder), f"folder '{folder}' not exists!"
        fns = os.listdir(folder)
        return fns

    @staticmethod
    def delete_obj(handler, fp: str):
        assert isinstance(handler, LocalClient)
        if not os.path.isdir(fp):
            os.remove(fp)
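
# A minimal LocalClient round trip (a sketch; the path is illustrative, and the handler
# argument mirrors the call shape of the boto3 backend):
#   cli = LocalClient()
#   LocalClient.sync_upload_fileobj(cli, "/tmp/demo/x.pt", saved_obj={"step": 0})
#   states = LocalClient.load(cli, "/tmp/demo/x.pt", map_location="cpu")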


class StorageManager(metaclass=SingletonMeta):
    """
    Storage Manager for saving or loading checkpoint.
    """

    BACKEND_TYPE = {"boto3", "local"}
    BACKEND_INIT_METHOD = {
        "boto3": Boto3Client,
        "local": LocalClient,
    }
    CLI_DICT = {}

    def __init__(self) -> None:
        pass

    def _get_client(self, path: str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
        """Resolve a backend-prefixed path into a meta-info object with a bound client.

        Examples:
            local:/path/to/checkpoint
            boto3:s3://model_weights/0331/120bi

        Args:
            path (str): Path prefixed with its backend, i.e. "local:" or "boto3:".
        """
        try:
            backend, path = path.split(":", maxsplit=1)
        except Exception as exc:
            raise AttributeError(
                f"Given path '{path}' does not start with a backend prefix ('local:' or 'boto3:')"
            ) from exc

        init_args = (None,)
        if backend == "local":
            meta_info = get_local_meta(path)
            backend_key = backend
        elif backend == "boto3":
            meta_info = get_boto3_meta(path)
            backend_key = backend + ":" + meta_info.endpoint
            init_args = (meta_info.endpoint,)
            if (
                "http_proxy" in os.environ
                or "https_proxy" in os.environ
                or "HTTP_PROXY" in os.environ
                or "HTTPS_PROXY" in os.environ
            ):
                raise RuntimeWarning(
                    "HTTP/HTTPS proxy detected while using boto3; an incorrectly set proxy "
                    "may make boto3 unavailable or degrade performance."
                )

        assert backend in StorageManager.BACKEND_TYPE, f"Unknown backend: {backend}"

        # The boto3 backend needs one client per endpoint, hence the composite cache key.
        if backend_key not in StorageManager.CLI_DICT:
            StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)})

        meta_info.client = StorageManager.CLI_DICT[backend_key]

        return meta_info

    def assert_fp_exists(self, folder) -> None:
        meta = self._get_client(path=folder)
        meta.client.assert_fp_exists(*unpack_meta(meta))

    def get_fns(self, folder) -> List[str]:
        meta = self._get_client(path=folder)
        return meta.client.get_fns(*unpack_meta(meta))

    def save(self, save_path: str, saved_obj: Any, *args, **kwargs):
        meta = self._get_client(path=save_path)
        meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)

    def load(self, load_path: str, *args, map_location="cpu", **kwargs) -> Any:
        meta = self._get_client(path=load_path)
        return meta.client.load(*unpack_meta(meta), *args, map_location=map_location, **kwargs)

    def delete_obj(self, fp: str):
        # Note: only LocalClient implements delete_obj; calling this on a boto3 path raises AttributeError.
        meta = self._get_client(path=fp)
        meta.client.delete_obj(*unpack_meta(meta))


storage_manager = StorageManager()
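

if __name__ == "__main__":
    # Smoke test against the local backend (a sketch; the path is illustrative).
    demo_dir = "/tmp/storage_manager_demo"
    llm_save(f"local:{demo_dir}/demo.pt", saved_obj={"step": 0})
    check_folder(f"local:{demo_dir}")
    logger.info(get_fns(f"local:{demo_dir}"))  # -> ["demo.pt"]
    logger.info(llm_load(f"local:{demo_dir}/demo.pt", map_location="cpu"))  # -> {"step": 0}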