ColossalAI/colossalai/cli/benchmark/models.py


import torch

import colossalai.nn as col_nn


class MLP(torch.nn.Module):
    """A simple multi-layer perceptron built from ColossalAI linear layers."""

    def __init__(self, dim: int, layers: int):
        super().__init__()
        # Stack `layers` square linear layers of size dim x dim.
        self.layers = torch.nn.ModuleList()
        for _ in range(layers):
            self.layers.append(col_nn.Linear(dim, dim))

    def forward(self, x):
        # Apply each linear layer in sequence; the feature dimension is preserved.
        for layer in self.layers:
            x = layer(x)
        return x
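
A minimal usage sketch (not part of the file above): it assumes any ColossalAI runtime context required by col_nn.Linear has already been set up, and the dimensions and batch size are illustrative only.

# Hypothetical usage: build a 4-layer MLP over 256-dim features
# and run a forward pass on a random batch.
model = MLP(dim=256, layers=4)
x = torch.randn(8, 256)   # batch of 8 feature vectors
out = model(x)            # output shape matches the input: (8, 256)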