# Mirror of https://github.com/hpcaitech/ColossalAI
# Original file: 20 lines, 557 B, Python.
import torch
import colossalai
import colossalai.nn as col_nn


class MLP(torch.nn.Module):
    """Two-layer feed-forward block built from ColossalAI ``col_nn`` layers.

    Computes ``dropout(dense_2(gelu(dense_1(x))))``:
    1. Project the input from ``dim`` features up to a hidden width of
       ``int(dim * mlp_ratio)``.
    2. Apply GELU.
    3. Project back down to ``dim`` features.
    4. Apply dropout.

    Args:
        dim: Size of the input and output feature dimension. Defaults to 256.
        mlp_ratio: Expansion factor for the hidden layer; the hidden width is
            ``int(dim * mlp_ratio)``. Defaults to 4, giving a hidden width of
            ``4 * dim``. Non-integer ratios (e.g. 2.5) are accepted.
        dropout: Drop probability handed to ``col_nn.Dropout``, applied to the
            output of the second projection. Defaults to 0.1.
    """

    def __init__(self, dim: int = 256, mlp_ratio: float = 4, dropout: float = 0.1):
        super().__init__()
        # int() keeps the layer width integral when a fractional ratio is used;
        # for the default integer ratio it yields exactly dim * 4.
        intermediate_dim = int(dim * mlp_ratio)
        self.dense_1 = col_nn.Linear(dim, intermediate_dim)
        self.activation = torch.nn.GELU()
        self.dense_2 = col_nn.Linear(intermediate_dim, dim)
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the expand -> GELU -> contract -> dropout pipeline on ``x``.

        Args:
            x: Input tensor. Presumably its last dimension equals ``dim``;
               shape handling is delegated to ``col_nn.Linear`` (TODO confirm
               expected layout under ColossalAI's parallel modes).

        Returns:
            The output of the dropout layer.
        """
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x