import torch
import torch.nn as nn
import torch.nn.functional as F

from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl


class TransformerMLP(nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """

    def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True):
        super(TransformerMLP, self).__init__()

        # Project to 4h.
        self.dense_h_to_4h = Linear(
            hidden_size,
            int(hidden_size*mlp_ratio),
            skip_bias_add=True)

        self.bias_gelu_fusion = fuse_gelu
        self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = Linear(
            int(hidden_size*mlp_ratio),
            hidden_size,
            skip_bias_add=True)

    def forward(self, hidden_states):
        # hidden states should be in the shape of [s, b, h]
        # it will be projects into [s, b, 4h]
        # and projected back to [s, b, h]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias