import math
import torch
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
    WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from typing import List


class VanillaSelfAttention(nn.Module):
    """Standard ViT self attention.
    """

    def __init__(self,
                 d_model: int,
                 n_heads: int,
                 d_kv: int,
                 attention_drop: float = 0,
                 drop_rate: float = 0,
                 bias: bool = True,
                 dropout1=None,
                 dropout2=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_kv = d_kv
        self.scale = 1.0 / math.sqrt(self.d_kv)

        self.dense1 = nn.Linear(d_model, 3 * n_heads * d_kv, bias, device=get_current_device())
        self.softmax = nn.Softmax(dim=-1)
        self.atten_drop = nn.Dropout(attention_drop) if dropout1 is None else dropout1
        self.dense2 = nn.Linear(n_heads * d_kv, d_model, device=get_current_device())
        self.dropout = nn.Dropout(drop_rate) if dropout2 is None else dropout2

    def forward(self, x):
        qkv = self.dense1(x)
        new_shape = qkv.shape[:2] + (3, self.n_heads, self.d_kv)
        qkv = qkv.view(*new_shape)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[:]

        x = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        x = self.atten_drop(self.softmax(x))

        x = torch.matmul(x, v)
        x = x.transpose(1, 2)
        new_shape = x.shape[:2] + (self.n_heads * self.d_kv,)
        x = x.reshape(*new_shape)
        x = self.dense2(x)
        x = self.dropout(x)

        return x


class VanillaFFN(nn.Module):
    """FFN composed with two linear layers, also called MLP.
    """

    def __init__(self,
                 d_model: int,
                 d_ff: int,
                 activation=None,
                 drop_rate: float = 0,
                 bias: bool = True,
                 dropout1=None,
                 dropout2=None):
        super().__init__()
        dense1 = nn.Linear(d_model, d_ff, bias, device=get_current_device())
        act = nn.GELU() if activation is None else activation
        dense2 = nn.Linear(d_ff, d_model, bias, device=get_current_device())
        drop1 = nn.Dropout(drop_rate) if dropout1 is None else dropout1
        drop2 = nn.Dropout(drop_rate) if dropout2 is None else dropout2

        self.ffn = nn.Sequential(dense1, act, drop1, dense2, drop2)

    def forward(self, x):
        return self.ffn(x)


class Widenet(nn.Module):

    def __init__(self,
                 num_experts: int,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 drop_tks: bool = True,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 d_model: int = 768,
                 num_heads: int = 12,
                 d_kv: int = 64,
                 d_ff: int = 4096,
                 attention_drop: float = 0.,
                 drop_rate: float = 0.1,
                 drop_path: float = 0.):
        super().__init__()

        embedding = VanillaPatchEmbedding(img_size=img_size,
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_size=d_model)
        embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)

        shared_sa = VanillaSelfAttention(**moe_sa_args(
            d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))

        noisy_func = NormalNoiseGenerator(num_experts)
        shared_router = Top2Router(capacity_factor_train=capacity_factor_train,
                                   capacity_factor_eval=capacity_factor_eval,
                                   noisy_func=noisy_func,
                                   drop_tks=drop_tks)
        shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = [
            TransformerLayer(att=shared_sa,
                             ffn=MoeLayer(dim_model=d_model,
                                          num_experts=num_experts,
                                          router=shared_router,
                                          experts=shared_experts),
                             norm1=nn.LayerNorm(d_model, eps=1e-6),
                             norm2=nn.LayerNorm(d_model, eps=1e-6),
                             droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)) for i in range(depth)
        ]
        norm = nn.LayerNorm(d_model, eps=1e-6)
        self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)

    def forward(self, x):
        MOE_CONTEXT.reset_loss()
        x = self.widenet(x)
        x = torch.mean(x, dim=1)
        x = self.linear(x)
        return x


class ViTMoE(nn.Module):

    def __init__(self,
                 num_experts: int or List[int],
                 use_residual: bool = False,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 drop_tks: bool = True,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 d_model: int = 768,
                 num_heads: int = 12,
                 d_kv: int = 64,
                 d_ff: int = 3072,
                 attention_drop: float = 0.,
                 drop_rate: float = 0.1,
                 drop_path: float = 0.):
        super().__init__()

        assert depth % 2 == 0, "The number of layers should be even right now"

        if isinstance(num_experts, list):
            assert len(num_experts) == depth // 2, \
                "The length of num_experts should equal to the number of MOE layers"
            num_experts_list = num_experts
        else:
            num_experts_list = [num_experts] * (depth // 2)

        embedding = VanillaPatchEmbedding(img_size=img_size,
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_size=d_model)
        embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = []
        for i in range(depth):
            sa = VanillaSelfAttention(**moe_sa_args(
                d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))

            if i % 2 == 0:
                ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
            else:
                num_experts = num_experts_list[i // 2]
                experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
                ffn = MoeModule(dim_model=d_model,
                                num_experts=num_experts,
                                top_k=1 if use_residual else 2,
                                capacity_factor_train=capacity_factor_train,
                                capacity_factor_eval=capacity_factor_eval,
                                noisy_policy='Jitter' if use_residual else 'Gaussian',
                                drop_tks=drop_tks,
                                use_residual=use_residual,
                                expert_instance=experts,
                                expert_cls=VanillaFFN,
                                **moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))

            layer = TransformerLayer(att=sa,
                                     ffn=ffn,
                                     norm1=nn.LayerNorm(d_model, eps=1e-6),
                                     norm2=nn.LayerNorm(d_model, eps=1e-6),
                                     droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR))
            blocks.append(layer)

        norm = nn.LayerNorm(d_model, eps=1e-6)
        self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm)

    def forward(self, x):
        MOE_CONTEXT.reset_loss()
        x = self.vitmoe(x)
        x = torch.mean(x, dim=1)
        x = self.linear(x)
        return x