# ColossalAI/tests/test_elixir/utils/small.py

import torch
import torch.nn as nn
from tests.test_elixir.utils.mlp import MlpModule
from tests.test_elixir.utils.registry import TEST_MODELS
def small_data_fn():
    """Create one random input batch for ``SmallModel``.

    The upper bound of 20 matches ``SmallModel``'s default ``num_embeddings``,
    so every id is a valid embedding index for the default model.

    Returns:
        dict: a single entry ``'x'`` holding a ``(4, 8)`` int64 tensor whose
        values are drawn uniformly from ``[0, 20)``.
    """
    token_ids = torch.randint(0, 20, (4, 8))
    return {'x': token_ids}
class SmallModel(nn.Module):
    """Small embedding -> MLP -> tied-projection network for registry tests.

    The output projection reuses the input embedding's weight tensor, so one
    Parameter is reachable from both ``embed`` and ``proj``.

    Args:
        num_embeddings: number of embedding rows; also the projection's
            output width.
        hidden_dim: width of the embedding, the norms and the MLP.
    """

    def __init__(self, num_embeddings: int = 20, hidden_dim: int = 16) -> None:
        super().__init__()
        vocab_size, width = num_embeddings, hidden_dim
        # Registration order (embed, norm1, mlp, norm2, proj) fixes the order
        # of parameters and state_dict keys; keep it stable.
        self.embed = nn.Embedding(vocab_size, width)
        self.norm1 = nn.LayerNorm(width)
        self.mlp = MlpModule(hidden_dim=width)
        self.norm2 = nn.LayerNorm(width)
        self.proj = nn.Linear(width, vocab_size, bias=False)
        # Weight tying: Embedding.weight is (vocab_size, width) and
        # Linear.weight is (out_features, in_features) == (vocab_size, width),
        # so the same Parameter can back both modules.
        self.proj.weight = self.embed.weight

    def forward(self, x):
        """Return a scalar: logits averaged over dim -2, then summed.

        ``x`` is fed to ``nn.Embedding``, so it must hold integer ids.
        Presumably it is (batch, seq), which makes dim -2 the sequence axis
        (TODO confirm against callers).
        """
        hidden = self.embed(x)
        # The MLP branch is normalized before being added back onto its input.
        branch = self.norm1(self.mlp(hidden))
        hidden = hidden + branch
        logits = self.proj(self.norm2(hidden))
        pooled = logits.mean(dim=-2)
        return pooled.sum()
# Register SmallModel together with its input factory under the key 'small'
# in TEST_MODELS. NOTE(review): the register signature and lookup API live in
# tests.test_elixir.utils.registry, which is not visible here. This presumably
# stores (model class, data fn) for tests to build and feed; verify.
TEST_MODELS.register('small', SmallModel, small_data_fn)