import torch
import colossalai
import colossalai.nn as col_nn


class MLP(torch.nn.Module):
    """A simple feed-forward block built from Colossal-AI parallel layers."""

    def __init__(self, dim: int = 256):
        super().__init__()
        intermediate_dim = dim * 4
        # col_nn.Linear / col_nn.Dropout dispatch to the parallel
        # implementation configured at launch time, and behave like
        # their vanilla torch.nn counterparts when no tensor
        # parallelism is configured.
        self.dense_1 = col_nn.Linear(dim, intermediate_dim)
        self.activation = torch.nn.GELU()
        self.dense_2 = col_nn.Linear(intermediate_dim, dim)
        self.dropout = col_nn.Dropout(0.1)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x
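

# A minimal usage sketch: the col_nn layers above expect the Colossal-AI
# runtime to be initialized before the model is built. This assumes a
# Colossal-AI release where colossalai.launch_from_torch(config=...) sets
# up the global parallel context; the CONFIG values below (1D tensor
# parallelism over 2 GPUs) are illustrative assumptions, not prescriptive.
if __name__ == '__main__':
    CONFIG = dict(parallel=dict(
        data=1,
        pipeline=1,
        tensor=dict(size=2, mode='1d'),  # assumed: 1D tensor parallelism
    ))
    colossalai.launch_from_torch(config=CONFIG)

    model = MLP(dim=256).cuda()
    x = torch.rand(8, 256, device='cuda')  # (batch, dim)
    out = model(x)                         # same shape: (batch, dim)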