ColossalAI/model_zoo/helper.py

27 lines
798 B
Python
Raw Normal View History

2022-01-07 07:08:36 +00:00
import torch
import torch.nn as nn
from colossalai.nn.layer import WrappedDropPath as DropPath
class TransformerLayer(nn.Module):
    """Pre-normalization transformer block assembled from caller-supplied parts.

    The block applies two residual sub-layers in sequence::

        x = x + droppath(att(norm1(x)))
        x = x + droppath(ffn(norm2(x)))

    A single ``droppath`` module is shared by both residual branches.

    Args:
        att: Attention sub-layer. It is applied to ``norm1(x)``, and its
            output must be addable to ``x``.
        ffn: Feed-forward sub-layer. It is applied to ``norm2(x)``, and its
            output must be addable to ``x``.
        norm1: Normalization applied before the attention branch.
        norm2: Normalization applied before the feed-forward branch.
        droppath: Optional pre-built module wrapped around each branch
            output. If ``None``, ``DropPath(droppath_rate)`` is constructed
            instead.
        droppath_rate: Rate forwarded to ``DropPath``. It is only used when
            ``droppath`` is ``None``.
    """

    def __init__(self,
                 att: nn.Module,
                 ffn: nn.Module,
                 norm1: nn.Module,
                 norm2: nn.Module,
                 droppath=None,
                 droppath_rate: float = 0):
        super().__init__()
        # Registration order is kept as att, ffn, norm1, norm2, droppath so
        # that parameter / state_dict ordering is unchanged.
        self.att = att
        self.ffn = ffn
        self.norm1 = norm1
        self.norm2 = norm2
        if droppath is None:
            # Only build the default wrapper when the caller did not supply one.
            droppath = DropPath(droppath_rate)
        self.droppath = droppath

    def forward(self, x):
        """Run the attention and feed-forward residual branches on ``x``."""
        attn_branch = self.att(self.norm1(x))
        x = x + self.droppath(attn_branch)
        ffn_branch = self.ffn(self.norm2(x))
        return x + self.droppath(ffn_branch)