# Mirror of https://github.com/hpcaitech/ColossalAI
# File info: 32 lines, 927 B, Python
import torch
import torch.nn as nn
from tests.test_elixir.utils.mlp import MlpModule
from tests.test_elixir.utils.registry import TEST_MODELS


def small_data_fn():
    """Build the keyword arguments for one forward pass of ``SmallModel``.

    Returns:
        dict: A single entry ``x`` holding an integer tensor of shape
        ``(4, 8)`` with values sampled from the half-open range ``[0, 20)``.
    """
    # The exclusive upper bound 20 equals SmallModel's default
    # ``num_embeddings``, so every sampled id is a valid embedding index.
    return dict(x=torch.randint(low=0, high=20, size=(4, 8)))


class SmallModel(nn.Module):
    """Tiny embedding -> MLP -> projection model that reduces to a scalar.

    The output projection shares its weight tensor with the embedding table,
    so the two layers expose one common parameter.

    Args:
        num_embeddings: Number of rows in the embedding table; also the
            output width of the final projection.
        hidden_dim: Width of the embeddings, both layer norms and the MLP.
    """

    def __init__(self, num_embeddings: int = 20, hidden_dim: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(num_embeddings, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MlpModule(hidden_dim=hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, num_embeddings, bias=False)
        # Weight tying: nn.Linear(hidden_dim, num_embeddings) stores its weight
        # as (num_embeddings, hidden_dim), the same layout as the embedding
        # table, so the parameter object can be shared directly.
        self.proj.weight = self.embed.weight

    def forward(self, x):
        """Run the model and collapse the result to a single scalar.

        Args:
            x: Integer tensor of embedding indices.
                NOTE(review): the registered data function supplies shape
                (4, 8); other shapes are not exercised here.

        Returns:
            A 0-dim tensor: the projected activations averaged over the
            second-to-last dimension and then summed over what remains.
        """
        x = self.embed(x)
        # Residual connection around the normalised MLP output.
        x = x + self.norm1(self.mlp(x))
        x = self.proj(self.norm2(x))
        x = x.mean(dim=-2)
        return x.sum()


# Register SmallModel together with its input-building function under the
# key 'small'. This runs as a side effect of importing this module.
# NOTE(review): how the registry stores and serves entries is defined in
# tests.test_elixir.utils.registry and is not visible here.
TEST_MODELS.register('small', SmallModel, small_data_fn)