"""
Utils for model inference
"""

import math
import os
import re
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline
from torch import nn

from colossalai.logging import get_dist_logger
from colossalai.testing import free_port

logger = get_dist_logger(__name__)


def init_to_get_rotary(self, base=10000, use_elem=False):
    """
    Precompute the rotary positional embedding tables and cache them on the model.

    The cosine and sine tables are stored as ``self._cos_cached`` and ``self._sin_cached``, cast to
    ``self.dtype`` and moved to CUDA. ``self.config.head_dim_`` is set as a side effect.
    According to the original documentation this is compatible with all models and is called in ShardFormer.

    Args:
        self : Model that holds the rotary positional embedding. It must expose ``config`` (with
            ``hidden_size`` and ``num_attention_heads``) and ``dtype``.
        base : calculation arg, the base of the rotary frequencies (rescaled when NTK is enabled)
        use_elem : activated when using chatglm-based models; halves the number of rotary elements

    Environment:
        INFER_NTK_ALPHA : optional float, must be >= 1. When set, the maximum sequence length is multiplied
            by it and ``base`` is rescaled with the NTK base-change formula.

    Raises:
        AssertionError : if ``INFER_NTK_ALPHA`` is set to a value smaller than 1.
    """
    config = self.config
    head_dim = config.hidden_size // config.num_attention_heads
    config.head_dim_ = head_dim

    # NOTE(review): ``rope_scaling`` is read by attribute (``.factor``); a plain-dict ``rope_scaling``
    # would fail here -- confirm what the callers pass.
    rope_scaling = getattr(config, "rope_scaling", None)
    rope_scaling_factor = rope_scaling.factor if rope_scaling is not None else 1.0

    # An explicit ``max_sequence_length`` is used as-is; the other sources are stretched by the scaling factor.
    if hasattr(config, "max_sequence_length"):
        max_seq_len = config.max_sequence_length
    elif hasattr(config, "max_position_embeddings"):
        max_seq_len = config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)

    # NTK  ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)

    if ntk_alpha is not None:
        ntk_alpha = float(ntk_alpha)
        assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
        if ntk_alpha > 1:
            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
        max_seq_len *= ntk_alpha
        # Base change formula. Uses the head dim computed above: it is stored on the config,
        # not on the model, so reading ``self.head_dim_`` here raised AttributeError.
        base = base * (ntk_alpha ** (head_dim / (head_dim - 2)))

    n_elem = head_dim
    if use_elem:
        n_elem //= 2

    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
    # The table holds 64K positions more than ``max_seq_len``; positions are divided by the scaling factor.
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
    self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()


def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Union[Path, str]]]:
    """
    Check whether the checkpoint has an index file.

    For a file, the file itself must be named ``<name>.index[.<anything>].json``.
    For a directory, its direct children matching ``*.index.*json`` are inspected; an index file whose
    name contains ``safetensors`` takes precedence over any other index file.

    Args:
        checkpoint_path (str): path to the checkpoint (a file or a directory).

    Returns:
        Tuple[bool, Optional[Union[Path, str]]]: a tuple of (has_index_file, index_file_path).
            ``index_file_path`` is ``None`` when no index file is found. It is a ``str`` when a safetensors
            index is picked from a directory (kept for backward compatibility) and a ``Path`` otherwise.

    Raises:
        AssertionError: if a directory holds several index files and none of them is a safetensors index.
        RuntimeError: if ``checkpoint_path`` is neither an existing file nor an existing directory.
    """
    checkpoint_path = Path(checkpoint_path)

    if checkpoint_path.is_file():
        # check if it is .index.json (raw string with the literal dots escaped, so that
        # names such as "myindex.json" are not mistaken for an index file)
        index_name_pattern = re.compile(r"(.*?)\.index((\..*)?)\.json")
        if index_name_pattern.fullmatch(checkpoint_path.name) is not None:
            return True, checkpoint_path
        return False, None

    if checkpoint_path.is_dir():
        index_files = list(checkpoint_path.glob("*.index.*json"))

        # Return the safetensors file first. Only the file name is inspected so that a parent
        # directory called e.g. "my-safetensors-model" cannot trigger this branch.
        for index_file in index_files:
            if "safetensors" in index_file.name:
                return True, str(index_file)

        if not index_files:
            return False, None
        if len(index_files) > 1:
            raise AssertionError(
                f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
            )
        return True, index_files[0]

    raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")


def get_model_size(model: nn.Module) -> float:
    """Calculates the total size of the model parameters (weights and biases) in GiB.

    Each parameter contributes ``element_size() * numel()`` bytes; the byte total is divided by ``1024**3``.

    Args:
        model: The PyTorch model to analyze.
    Returns:
        The total size of the model parameters in GiB (bytes / 1024**3), as a float.
    """
    total_bytes = sum(param.element_size() * param.numel() for param in model.parameters())
    return total_bytes / (1024**3)


def find_available_ports(num: int):
    """
    Collect ``num`` ports by calling ``free_port`` once per requested port.

    Args:
        num (int): The number of ports to look up.

    Returns:
        list: The values returned by ``free_port``, in call order (empty when ``num`` is not positive).

    Raises:
        RuntimeError: if any lookup fails with an ``OSError``; the original error is attached as the cause.
    """
    try:
        free_ports = [free_port() for _ in range(num)]
    except OSError as e:
        print(f"An OS error occurred: {e}")
        raise RuntimeError("Error finding available ports") from e
    return free_ports


def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """
    Compute the per-head Alibi slopes.

    Adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57

    Args:
        num_heads (int): The number of attention heads.
        device (torch.device): The device to use.

    Returns:
        torch.Tensor: The Alibi slopes, one entry per head.
    """

    def _slopes_for(power_of_2: int, exponents: torch.Tensor) -> torch.Tensor:
        # Raise the ratio 2^(-2^-(log2(power_of_2) - 3)) to each of the given integer exponents.
        ratio = torch.tensor(2 ** (-(2 ** -(math.log2(power_of_2) - 3))), dtype=torch.float32, device=device)
        return torch.pow(ratio, exponents)

    # Largest power of two that does not exceed the head count.
    floor_pow2 = 2 ** math.floor(math.log2(num_heads))
    head_slopes = _slopes_for(floor_pow2, torch.arange(1, 1 + floor_pow2, dtype=torch.int32, device=device))
    if floor_pow2 == num_heads:
        return head_slopes

    # The remaining heads use the odd exponents of the ratio for the next power of two.
    extra_count = min(floor_pow2, num_heads - floor_pow2)
    odd_exponents = torch.arange(1, 1 + 2 * extra_count, 2, dtype=torch.int32, device=device)
    return torch.cat([head_slopes, _slopes_for(2 * floor_pow2, odd_exponents)], dim=0)


def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    """
    Tell whether flash attention2 can be used for the given dtype.

    Only ``torch.float16`` and ``torch.bfloat16`` qualify, and ``flash_attn`` must be importable;
    a warning is logged when the import fails.
    """
    half_precision = dtype in (torch.float16, torch.bfloat16)
    if not half_precision:
        return False

    try:
        from flash_attn import flash_attn_varlen_func  # noqa
    except ImportError:
        logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")
        return False
    else:
        return True


class ModelType(Enum):
    """Kinds of model that ``get_model_type`` can report; each value is a human-readable label."""

    # Reported for ``DiffusionPipeline`` instances and for paths whose diffusers config can be loaded.
    DIFFUSION_MODEL = "Diffusion Model"
    # Reported for ``nn.Module`` instances and for paths that ``AutoConfig`` can load.
    LLM = "Large Language Model (LLM)"
    # Fallback for inputs whose kind is not recognised.
    UNKNOWN = "Unknown Model Type"


def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]) -> "ModelType":
    """
    Classify a model object or a model path as a diffusion model or an LLM.

    Instances are classified by type: ``DiffusionPipeline`` first, then ``nn.Module``. A string is probed
    by trying to load a transformers config from it, then a diffusers pipeline config.

    Args:
        model_or_path: A model instance, a diffusion pipeline, or a path/identifier string.

    Returns:
        ModelType: ``DIFFUSION_MODEL`` or ``LLM`` when the input is recognised, ``UNKNOWN`` otherwise
            (including a string that neither probe can load).
    """
    if isinstance(model_or_path, DiffusionPipeline):
        return ModelType.DIFFUSION_MODEL
    if isinstance(model_or_path, nn.Module):
        return ModelType.LLM
    if not isinstance(model_or_path, str):
        return ModelType.UNKNOWN

    try:
        from transformers import AutoConfig

        # NOTE(review): trust_remote_code=True allows code shipped with the model repository to run
        # while the config is loaded -- only pass paths from trusted sources.
        AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
        return ModelType.LLM
    except Exception:
        # Deliberate best-effort probe: model type is not `ModelType.LLM`.
        # ``Exception`` (not a bare ``except``) so KeyboardInterrupt/SystemExit still propagate.
        pass

    try:
        DiffusionPipeline.load_config(model_or_path)
        return ModelType.DIFFUSION_MODEL
    except Exception:
        # Deliberate best-effort probe: model type is not `ModelType.DIFFUSION_MODEL`.
        pass

    # Neither probe recognised the path; report it explicitly instead of implicitly returning None.
    return ModelType.UNKNOWN