from typing import Optional

import torch
import torch.nn as nn
from colossalai.nn.layer import WrappedDropPath as DropPath


class TransformerLayer(nn.Module):
    """Transformer layer builder.
    """
    def __init__(self,
                 att: nn.Module,
                 ffn: nn.Module,
                 norm1: nn.Module,
                 norm2: nn.Module,
                 droppath: Optional[nn.Module] = None,
                 droppath_rate: float = 0.):
        super().__init__()
        self.att = att
        self.ffn = ffn
        self.norm1 = norm1
        self.norm2 = norm2
        # Fall back to a rate-configured DropPath when no module is injected.
        self.droppath = DropPath(droppath_rate) if droppath is None else droppath

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual branches: normalize, transform, apply stochastic
        # depth to the branch output, then add it back to the residual stream.
        x = x + self.droppath(self.att(self.norm1(x)))
        x = x + self.droppath(self.ffn(self.norm2(x)))
        return x
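

# A minimal usage sketch, not part of the original module: it assumes a thin
# wrapper around torch's nn.MultiheadAttention so the attention block fits the
# single-input API expected by TransformerLayer, plus a two-layer MLP as the
# feed-forward block. Dimensions and the droppath rate are illustrative only.
if __name__ == "__main__":
    dim, heads = 256, 8

    class SelfAttention(nn.Module):
        """Hypothetical wrapper: self-attention with a single-tensor interface."""

        def __init__(self, dim: int, heads: int):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out, _ = self.attn(x, x, x, need_weights=False)
            return out

    layer = TransformerLayer(
        att=SelfAttention(dim, heads),
        ffn=nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)),
        norm1=nn.LayerNorm(dim),
        norm2=nn.LayerNorm(dim),
        droppath_rate=0.1,
    )
    out = layer(torch.randn(2, 16, dim))  # (batch, seq_len, dim)
    assert out.shape == (2, 16, dim)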