import math
import importlib
from dataclasses import dataclass

import torch
from torch import nn
from torch.autograd import Function


def check_config(config):
    """Validate that ``config`` satisfies the shape constraints of the fused kernels.

    Args:
        config: object exposing ``hidden_size``, ``nhead`` and ``fp16``
            (e.g. a :class:`Config`).

    Raises:
        ValueError: if ``nhead`` is not positive, ``hidden_size`` is not
            divisible by ``nhead``, ``hidden_size`` exceeds the kernel limit
            (``factor * 4096``), or the per-head dimension is not a multiple of
            ``factor`` (8 for fp16, 4 for fp32).  ``ValueError`` subclasses
            ``Exception``, so callers catching ``Exception`` keep working.
    """
    if config.nhead <= 0:
        # Guard explicitly: 0 would crash with ZeroDivisionError below and a
        # negative head count would otherwise pass every check.
        raise ValueError(f"nhead({config.nhead}) must be positive")
    if config.hidden_size % config.nhead != 0:
        raise ValueError("hidden_size % nhead != 0")

    # Element granularity the kernels work with (presumably elements per
    # 128-bit vectorized access: 8 halves or 4 floats -- TODO confirm).
    factor = 8 if config.fp16 else 4
    upbound = factor * 1024 * 4
    if config.hidden_size > upbound:
        # as required by ln backward kernel currently
        raise ValueError(f"hidden_size > {upbound}")

    head_dim = config.hidden_size // config.nhead
    if head_dim % factor != 0:
        # as required by reshape kernel
        raise ValueError(f"head_dim({head_dim}) % {factor} != 0")


def calc_offset(sizes):
    """Return running offsets for ``sizes``, beginning at 0.

    The result is one element longer than ``sizes``: entry ``i`` is the sum
    of the first ``i`` sizes, e.g. ``calc_offset([2, 3, 5]) == [0, 2, 5, 10]``.
    """
    total = 0
    boundaries = [total]
    for size in sizes:
        total += size
        boundaries.append(total)
    return boundaries


# Handle to the compiled ``colossal_multihead_attention`` CUDA extension.
# It stays ``None`` until the first ``MultiHeadAttention.__init__`` imports it
# lazily; ``MultiHeadAttention1DFunc`` reads it to dispatch to the fp16/fp32
# forward and backward kernels.
colossal_multihead_attention = None


@dataclass
class Config:
    """Hyper-parameters shared by the Python wrapper and the fused CUDA kernels.

    The first eight fields are supplied at construction time.  The trailing
    fields have defaults and are (re)assigned by ``MultiHeadAttention``:
    ``layer_id`` once in ``__init__``, ``training`` and ``is_grad_enabled`` on
    every ``forward`` call.  Declaring them here (instead of attaching them
    dynamically) keeps them visible to type checkers and ``repr``.
    """

    max_batch_tokens: int  # max batch token numbers
    max_seq_len: int  # max sequence length
    hidden_size: int  # size of transformer hidden layers
    nhead: int  # number of heads in attention
    attn_prob_dropout_ratio: float  # attention score dropout ratio
    hidden_dropout_ratio: float  # dropout ratio before residual
    norm_first: bool  # norm_first (normalization before attention)
    fp16: bool  # fp16 precision
    layer_id: int = 0  # id passed to every kernel call of the owning layer
    training: bool = True  # mirrors the owning module's ``training`` flag
    is_grad_enabled: bool = True  # torch.is_grad_enabled() at the last forward


class MultiHeadAttention1DFunc(Function):
    """Autograd bridge to the fused multi-head attention CUDA kernels.

    Both passes dispatch through the module-level ``colossal_multihead_attention``
    handle, choosing the fp16 or fp32 entry point from ``config.fp16`` and
    passing ``config.layer_id`` as the first kernel argument.
    """

    @staticmethod
    def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
                out_proj_bias, norm_weight, norm_bias, config):
        """Run the fused forward kernel and return its single output tensor.

        ``config`` must already carry ``layer_id``, ``training`` and
        ``is_grad_enabled``; ``MultiHeadAttention`` sets them before calling
        ``apply``.  Tensors are saved for backward only when both
        ``is_grad_enabled`` and ``training`` are true.
        """
        cuda_module = colossal_multihead_attention
        forward_func = (cuda_module.multihead_attention_fw_fp16
                        if config.fp16 else cuda_module.multihead_attention_fw_fp32)
        if config.fp16:
            # The fp16 kernel is fed half-precision activations and mask.
            input = input.to(torch.half)
            input_mask = input_mask.to(torch.half)

        # The kernel returns a sequence with exactly one element: the output.
        (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias,
                                 out_proj_weight, out_proj_bias, norm_weight, norm_bias,
                                 config.training, config.norm_first)

        # The order of the saved tensors must match the unpacking in backward().
        # NOTE(review): ctx.config is the caller's shared, mutable Config object
        # (MultiHeadAttention.forward rewrites config.training on every call),
        # not a snapshot -- backward sees whatever values it holds at that time.
        if config.is_grad_enabled and config.training:
            ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias,
                                  out_proj_weight, out_proj_bias, norm_weight, norm_bias)
            ctx.config = config
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the fused backward kernel.

        Returns one gradient per ``forward`` input, in the same order:
        ``input_mask`` and ``config`` get ``None``; the other seven come from
        the kernel.
        """
        # forward() only saves tensors (and sets ctx.config) in training mode.
        assert ctx.config.training

        cuda_module = colossal_multihead_attention
        backward_func = (cuda_module.multihead_attention_bw_fp16
                         if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32)

        output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \
            out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors

        # Placeholders; all of them are overwritten by the kernel results below.
        grad_input = None
        grad_in_proj_weight = None
        grad_in_proj_bias = None
        grad_out_proj_weight = None
        grad_out_proj_bias = None
        grad_norm_weight = None
        grad_norm_bias = None

        # input/input_mask were already cast to half in forward() before being
        # saved, so those .to() calls return the same tensor; grad_output (and
        # presumably output, depending on the kernel's output dtype) may need it.
        if ctx.config.fp16:
            grad_output = grad_output.to(torch.half)
            output = output.to(torch.half)
            input = input.to(torch.half)
            input_mask = input_mask.to(torch.half)
        grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \
            grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func(
                ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
                in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)

        # One entry per forward() argument: None for input_mask and config.
        return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
                grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None)


class MultiHeadAttention(nn.Module):
    """Multi-head attention layer executed by the fused ``colossal_multihead_attention`` kernels.

    Static variable:

        layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
        e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.

    Arguments:
        hidden_size: Total dimension of hidden_size.
        nhead: Number of parallel attention heads.
        batch_size: Max batch size of one forward; ``batch_size * max_seq_len`` is the
            largest number of tokens a single forward may process.
        max_seq_len: Max length of input sequence.
        dropout: Dropout probability, used both for the attention probabilities and
            for the hidden output.
        norm_first: perform LayerNorms before attention.
        fp16: Use the fp16 kernels (inputs are cast to half and the output is returned
            as ``torch.half``); otherwise the fp32 kernels are used.
        pg: Optional process group. When given, the projection weights are split
            evenly over its ``pg.size()`` ranks along the hidden dimension.

    Raises:
        Exception: from ``check_config`` when the kernel shape constraints are violated.
        ValueError: if ``hidden_size`` is not divisible by the process-group size.
        RuntimeError: if the CUDA extension cannot be imported.
    """

    layer_id = 0

    def __init__(self,
                 hidden_size,
                 nhead,
                 batch_size,
                 max_seq_len,
                 dropout=0.0,
                 norm_first=False,
                 fp16=True,
                 pg=None):
        super(MultiHeadAttention, self).__init__()

        self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout,
                             dropout, norm_first, fp16)
        check_config(self.config)
        self.pg = pg
        self.pg_size = 1
        if self.pg:
            self.pg_size = pg.size()

        # Every rank owns an equal slice of the hidden dimension. Reject uneven
        # splits up front: otherwise the per-rank parameter shapes and the
        # slices copied in reset_parameters disagree, and that failure would
        # only surface after the kernel-side layer has already been created.
        if self.config.hidden_size % self.pg_size != 0:
            raise ValueError(f"hidden_size({self.config.hidden_size}) is not divisible by "
                             f"the process group size({self.pg_size})")
        self.hs_per_rank = self.config.hidden_size // self.pg_size

        # Give this instance a unique id; it is passed to every kernel call.
        self.config.layer_id = MultiHeadAttention.layer_id
        MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1

        # Load cuda modules if needed (imported once, then shared by all layers).
        global colossal_multihead_attention
        if colossal_multihead_attention is None:
            try:
                colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
            except ImportError as err:
                raise RuntimeError('MultiHeadAttention requires cuda extensions') from err

        # create the layer in cuda kernels.
        cuda_module = colossal_multihead_attention
        create_layer_func = (cuda_module.create_multihead_attention_fp16
                             if self.config.fp16 else cuda_module.create_multihead_attention_fp32)

        create_layer_func(
            self.config.layer_id,
            self.config.max_batch_tokens,
            self.config.max_seq_len,
            self.config.hidden_size,
            self.config.nhead,
            self.config.attn_prob_dropout_ratio,
            self.config.hidden_dropout_ratio,
            self.config.norm_first,
            self.pg,
        )

        hs = self.config.hidden_size

        # dtype that forward() casts its output to.
        self.precision = torch.half if self.config.fp16 else torch.float32

        self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs))
        self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank))
        self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank))
        self.out_proj_bias = nn.Parameter(torch.Tensor(hs))
        self.norm_weight = nn.Parameter(torch.Tensor(hs))
        self.norm_bias = nn.Parameter(torch.Tensor(hs))

        self.reset_parameters()
        torch.cuda.empty_cache()

    def calc_bound(self, w):
        """Return ``1 / sqrt(fan_in)`` of weight ``w``, the bias init bound."""
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
        bound = 1.0 / math.sqrt(fan_in)
        return bound

    def reset_parameters(self):
        """(Re)initialise all parameters in place.

        ``out_proj_bias`` and ``norm_bias`` are zeroed and ``norm_weight`` is
        set to one.  The (3*hs, hs) in-projection weight gets Xavier-uniform
        init with gain ``1/sqrt(2)``, its bias is uniform in ``[-b, b]`` with
        ``b = calc_bound(weight)``, and the output weight gets Xavier-uniform
        init with gain 1.  With a process group of size > 1 the full matrices
        are drawn, broadcast so every rank holds the same values, and each
        rank copies out its own ``hs_per_rank`` slice.
        """
        hs = self.config.hidden_size

        nn.init.zeros_(self.out_proj_bias)

        nn.init.ones_(self.norm_weight)
        nn.init.zeros_(self.norm_bias)

        if self.pg_size > 1:
            rank_in_pg = torch.distributed.get_rank(self.pg)
            # This rank's slice of the hidden dimension; exact because __init__
            # guarantees hs == pg_size * hs_per_rank.
            start = rank_in_pg * self.hs_per_rank
            end = start + self.hs_per_rank

            attn_qkvw_global = torch.empty(hs * 3, hs)
            attn_qkvb_global = torch.empty(hs * 3)
            nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0))
            bound = self.calc_bound(attn_qkvw_global)
            nn.init.uniform_(attn_qkvb_global, -bound, bound)

            # NOTE(review): torch.distributed.broadcast interprets ``src`` as a
            # global rank regardless of ``group``; if self.pg does not contain
            # global rank 0 these broadcasts fail -- confirm the intended source.
            attn_qkvw_global = attn_qkvw_global.cuda()
            attn_qkvb_global = attn_qkvb_global.cuda()
            torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg)
            torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg)
            attn_qkvw_global = attn_qkvw_global.cpu()
            attn_qkvb_global = attn_qkvb_global.cpu()

            with torch.no_grad():
                self.in_proj_weight.copy_(attn_qkvw_global.view(3, hs, hs)[:, start:end, :])
                self.in_proj_bias.copy_(attn_qkvb_global.view(3, hs)[:, start:end])

            attn_ow_global = torch.empty(hs, hs)
            nn.init.xavier_uniform_(attn_ow_global, 1.0)
            attn_ow_global = attn_ow_global.cuda()
            torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
            attn_ow_global = attn_ow_global.cpu()
            with torch.no_grad():
                self.out_proj_weight.copy_(attn_ow_global[:, start:end])

        else:
            attn_qkvw = self.in_proj_weight.view(-1, hs)
            nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0))
            bound = self.calc_bound(attn_qkvw)
            nn.init.uniform_(self.in_proj_bias, -bound, bound)

            nn.init.xavier_uniform_(self.out_proj_weight, 1.0)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Pass-through to ``nn.Module.state_dict`` (explicit override point)."""
        destination = torch.nn.Module.state_dict(self,
                                                 destination=destination,
                                                 prefix=prefix,
                                                 keep_vars=keep_vars)
        return destination

    def forward(self, hidden_states, encoder_padding_mask):
        """Apply fused attention to ``hidden_states``.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_size)``.
            encoder_padding_mask: tensor of shape ``(batch, seq_len)``, or
                ``(seq_len,)`` when ``batch == 1``. It is multiplied by ``-1e8``
                and cast to the dtype of ``hidden_states`` to form an additive
                mask (presumably 1 marks padding and 0 a real token -- confirm).

        Returns:
            The attention output cast to ``torch.half`` when ``fp16`` is set,
            otherwise to ``torch.float32``.

        Raises:
            ValueError: if the token/sequence limits given at construction are
                exceeded, the last dimension is not ``hidden_size``, or the mask
                shape does not match ``hidden_states``.
        """
        # Record mode and grad state here, before entering the autograd Function.
        self.config.training = self.training
        self.config.is_grad_enabled = torch.is_grad_enabled()
        hidden_states = hidden_states.contiguous()
        encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous())

        bs, sl, dim = hidden_states.size()
        if bs * sl > self.config.max_batch_tokens:
            raise ValueError(
                f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
        if sl > self.config.max_seq_len:
            raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
        # The projection weights are (.., hidden_size); a mismatched input must
        # never reach the fused kernel.
        if dim != self.config.hidden_size:
            raise ValueError(
                f"Hidden size {dim} does not match the configured {self.config.hidden_size}.")
        # Raise instead of assert so the check survives ``python -O``.
        if len(encoder_padding_mask.size()) == 1:
            mask_ok = bs == 1 and sl == encoder_padding_mask.size(0)
        else:
            mask_ok = bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
        if not mask_ok:
            raise ValueError(
                f"encoder_padding_mask shape {tuple(encoder_padding_mask.size())} does not match "
                f"hidden_states batch/sequence dims ({bs}, {sl}).")

        output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask,
                                                self.in_proj_weight, self.in_proj_bias,
                                                self.out_proj_weight, self.out_proj_bias,
                                                self.norm_weight, self.norm_bias, self.config)

        return output.to(self.precision)