# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py

import logging

import torch
import torch.nn as nn

from .bnb_config import BnbQuantizationConfig

try:
    import bitsandbytes as bnb

    IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
    IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
except ImportError:
    IS_4BIT_BNB_AVAILABLE = False
    IS_8BIT_BNB_AVAILABLE = False


logger = logging.getLogger(__name__)


def quantize_model(
    model: torch.nn.Module,
    bnb_quantization_config: BnbQuantizationConfig,
):
    """
    Quantize an already-loaded model with the parameters in `bnb_quantization_config` and move it to the GPU.

    Args:
        model (`torch.nn.Module`):
            The input model, already loaded in memory.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters

    Returns:
        `torch.nn.Module`: The quantized model
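
    Example (a minimal usage sketch; it assumes a CUDA device is available and that the `BnbQuantizationConfig`
    in this package accepts `load_in_8bit` and `torch_dtype` as constructor arguments):

    ```py
    >>> import torch
    >>> import torch.nn as nn

    >>> model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))  # toy CPU model
    >>> config = BnbQuantizationConfig(load_in_8bit=True, torch_dtype=torch.float16)
    >>> model = quantize_model(model, config)  # eligible Linear layers become bnb.nn.Linear8bitLt on the GPU
    ```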
    """

    load_in_4bit = bnb_quantization_config.load_in_4bit
    load_in_8bit = bnb_quantization_config.load_in_8bit

    if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )
    if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )

    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
    if bnb_quantization_config.skip_modules is None:
        bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)

    modules_to_not_convert = bnb_quantization_config.skip_modules

    # We add the modules we want to keep in full precision
    if bnb_quantization_config.keep_in_fp32_modules is None:
        bnb_quantization_config.keep_in_fp32_modules = []
    keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules

    # compatibility with peft
    model.is_loaded_in_4bit = load_in_4bit
    model.is_loaded_in_8bit = load_in_8bit

    # remember the device the model was loaded on, so we can decide how to move it after conversion
    model_device = next(model.parameters()).device

    model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)

    # convert floating-point params and buffers to the right dtype
    # (`Tensor.to` is not in-place, so we reassign `.data` on the actual tensors)
    dtype = bnb_quantization_config.torch_dtype
    for name, param in list(model.named_parameters()) + list(model.named_buffers()):
        if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
            param.data = param.data.to(torch.float32)
        elif torch.is_floating_point(param):
            param.data = param.data.to(dtype)
    if model_device.type == "cuda":
        # moving the replaced modules to the GPU triggers the actual bitsandbytes quantization
        model.cuda(torch.cuda.current_device())
        torch.cuda.empty_cache()
    elif torch.cuda.is_available():
        model.to(torch.cuda.current_device())
        logger.info(
            f"The model device type is {model_device.type}, but CUDA is needed for quantization."
            " Moving the model to the current CUDA device."
        )
    else:
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")
    return model


def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
    """
    A helper function that replaces all `torch.nn.Linear` modules with `bnb.nn.Linear8bitLt` or `bnb.nn.Linear4bit`
    modules from the `bitsandbytes` library. The function runs recursively over the model's submodules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters.
        modules_to_not_convert (`List[str]`):
            Names of the modules to not convert. In practice we keep the `lm_head` in full precision for
            numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert.
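
    Example (a minimal sketch; the `BnbQuantizationConfig` constructor arguments below are assumptions about the
    local config class):

    ```py
    >>> import torch.nn as nn

    >>> config = BnbQuantizationConfig(load_in_8bit=True)
    >>> toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    >>> toy = replace_with_bnb_layers(toy, config, modules_to_not_convert=["1"])  # keep the last layer ("1") as-is
    ```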
    """

    if modules_to_not_convert is None:
        modules_to_not_convert = []

    model, has_been_replaced = _replace_with_bnb_layers(
        model, bnb_quantization_config, modules_to_not_convert, current_key_name
    )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " This can happen for some architectures such as gpt2 that use Conv1D instead of Linear layers."
            " Please double check your model architecture, or submit an issue on GitHub if you think this is"
            " a bug."
        )
    return model


def _replace_with_bnb_layers(
    model,
    bnb_quantization_config,
    modules_to_not_convert=None,
    current_key_name=None,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates whether the conversion has been successful or not.
    """

    has_been_replaced = False
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
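            # e.g. with modules_to_not_convert=["lm_head"], both "lm_head" itself and any module under "lm_head."
            # are skipped, but a name that merely contains the substring (such as "lm_head_proj") is still converted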
            proceed = True
            for key in modules_to_not_convert:
                if (
                    (key in current_key_name_str) and (key + "." in current_key_name_str)
                ) or key == current_key_name_str:
                    proceed = False
                    break
            if proceed:
                # Create the replacement bnb module and copy over the weights of the `nn.Linear` module
                if bnb_quantization_config.load_in_8bit:
                    bnb_module = bnb.nn.Linear8bitLt(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=bnb_quantization_config.llm_int8_threshold,
                    )
                elif bnb_quantization_config.load_in_4bit:
                    bnb_module = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        bnb_quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
                        quant_type=bnb_quantization_config.bnb_4bit_quant_type,
                    )
                else:
                    raise ValueError("Either `load_in_8bit` or `load_in_4bit` must be True.")
                bnb_module.weight.data = module.weight.data
                bnb_module.weight.skip_zero_check = True
                if module.bias is not None:
                    bnb_module.bias.data = module.bias.data
                    bnb_module.bias.skip_zero_check = True
                bnb_module.requires_grad_(False)
                setattr(model, name, bnb_module)
                has_been_replaced = True
        if len(list(module.children())) > 0:
            _, _has_been_replaced = _replace_with_bnb_layers(
                module, bnb_quantization_config, modules_to_not_convert, current_key_name
            )
            has_been_replaced = has_been_replaced | _has_been_replaced
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def get_keys_to_not_convert(model):
    r"""
    A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM
    modules we may want to keep the `lm_head` in full precision for numerical stability reasons. For other
    architectures, we want to keep the tied weights of the model. The function returns a list of the keys of the
    modules to not convert.

    Parameters:
        model (`torch.nn.Module`):
            Input model
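
    Example (a small sketch with a toy model; real models such as CausalLMs would typically return the name of
    their output head, e.g. `lm_head`):

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("body", nn.Linear(4, 4)), ("head", nn.Linear(4, 2))]))
    >>> get_keys_to_not_convert(model)
    ['head']
    ```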
    """
    # No copy of the model is made here; we inspect it directly
    tied_model = model

    tied_params = find_tied_parameters(tied_model)
    # For compatibility with Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
    else:
        tied_keys = sum(tied_params, [])
    has_tied_params = len(tied_keys) > 0

    # Check if it is a base model
    is_base_model = False
    if hasattr(model, "base_model_prefix"):
        is_base_model = not hasattr(model, model.base_model_prefix)

    # Ignore this for base models (BertModel, GPT2Model, etc.)
    if (not has_tied_params) and is_base_model:
        return []

    # otherwise they have an attached head
    list_modules = list(model.named_children())
    list_last_module = [list_modules[-1][0]]

    # add the last module, unless it is already among the tied weights
    last_module_not_tied = set(list_last_module) - set(tied_keys)
    list_untouched = list(set(tied_keys)) + list(last_module_not_tied)

    # remove ".weight" and ".bias" from the keys
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            if name_to_remove in name:
                name = name.replace(name_to_remove, "")
        filtered_module_names.append(name)

    return filtered_module_names


def find_tied_parameters(model: nn.Module, **kwargs):
    """
    Find the tied parameters in a given model.

    <Tip warning={true}>

    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
    them.

    </Tip>

    Args:
        model (`torch.nn.Module`): The model to inspect.

    Returns:
        `List[List[str]]`: A list of lists of parameter names, where each inner list groups parameters that are
        tied together.

    Example:

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    >>> model.linear2.weight = model.linear1.weight
    >>> find_tied_parameters(model)
    [['linear1.weight', 'linear2.weight']]
    ```
    """
    # Initialize result and named_parameters before recursing.
    named_parameters = kwargs.get("named_parameters", None)
    prefix = kwargs.get("prefix", "")
    result = kwargs.get("result", {})

    if named_parameters is None:
        named_parameters = {n: p for n, p in model.named_parameters()}
    else:
        # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
        # of the submodule it belongs to. So while recursing we track the names that are not in the initial
        # `named_parameters`.
        for name, parameter in model.named_parameters():
            full_name = name if prefix == "" else f"{prefix}.{name}"
            if full_name not in named_parameters:
                # When we find one, it has to be one of the existing parameters.
                for new_name, new_param in named_parameters.items():
                    if new_param is parameter:
                        if new_name not in result:
                            result[new_name] = []
                        result[new_name].append(full_name)

    # Once we have treated direct parameters, we move to the child modules.
    for name, child in model.named_children():
        child_name = name if prefix == "" else f"{prefix}.{name}"
        find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)

    return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])


class FindTiedParametersResult(list):
    """
    A subclass of `list` kept for backward compatibility with Transformers. Do not rely on it being anything other
    than a plain list, or on the `values` method, as both may be removed in the future.
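
    Example (a small sketch of the backward-compatible `values` accessor):

    ```py
    >>> result = FindTiedParametersResult([["linear1.weight", "linear2.weight"]])
    >>> result.values()
    ['linear2.weight']
    ```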
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def values(self):
        return sum([x[1:] for x in self], [])