mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
798 B
27 lines
798 B
3 years ago
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from colossalai.nn.layer import WrappedDropPath as DropPath
|
||
|
|
||
|
|
||
|
class TransformerLayer(nn.Module):
    """Compose a single transformer layer from caller-supplied sub-modules.

    Two residual branches are applied one after the other. Each branch
    feeds the running value through a norm module, then the branch module,
    then ``droppath``, and adds the result back onto the running value::

        x = x + droppath(att(norm1(x)))
        x = x + droppath(ffn(norm2(x)))

    Args:
        att: Module run inside the first residual branch.
        ffn: Module run inside the second residual branch.
        norm1: Module applied to the input of ``att``.
        norm2: Module applied to the input of ``ffn``.
        droppath: Module applied to the output of both branches. When it is
            ``None`` a ``DropPath`` is built from ``droppath_rate`` instead.
        droppath_rate: Value handed to ``DropPath`` when ``droppath`` is
            ``None``; unused when ``droppath`` is supplied.
    """

    def __init__(self,
                 att: nn.Module,
                 ffn: nn.Module,
                 norm1: nn.Module,
                 norm2: nn.Module,
                 droppath=None,
                 droppath_rate: float = 0):
        super().__init__()

        # Registration order is kept as att, ffn, norm1, norm2, droppath so
        # that the ordering of children / state-dict entries is stable.
        self.att = att
        self.ffn = ffn
        self.norm1 = norm1
        self.norm2 = norm2

        # Only build the default drop-path module when none was passed in.
        if droppath is None:
            droppath = DropPath(droppath_rate)
        self.droppath = droppath

    def forward(self, x):
        """Run both residual branches on ``x`` and return the result.

        The outputs of ``att`` and ``ffn`` (after ``droppath``) are added to
        the running value, so they must be compatible with ``x`` under ``+``.
        """
        # First branch: norm1 -> att -> droppath, added onto the input.
        att_branch = self.droppath(self.att(self.norm1(x)))
        x = x + att_branch

        # Second branch: norm2 -> ffn -> droppath, added onto the updated value.
        ffn_branch = self.droppath(self.ffn(self.norm2(x)))
        return x + ffn_branch
|