from copy import deepcopy

import torch
import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

from ..registry import model_zoo

# Shapes shared by the reference MLP and its tensor-parallel twin below.
_BS = 16
_IN_DIM = 32
_HID_DIM = 128


class Net(nn.Module):
    """Single-device reference MLP."""

    def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=False, dtype=torch.float32):
        super().__init__()
        # fc0 is optionally a no-op so tests cover both a leading unsharded
        # layer and a pass-through input.
        if identity:
            self.fc0 = nn.Identity()
        else:
            self.fc0 = nn.Linear(in_dim, in_dim).to(dtype=dtype)
        self.fc1 = nn.Linear(in_dim, hid_dim).to(dtype=dtype)
        self.fc2 = nn.Linear(hid_dim, in_dim).to(dtype=dtype)

    def forward(self, x):
        return self.fc2(self.fc1(self.fc0(x)))


class TPNet(nn.Module):
    """Tensor-parallel twin of ``Net`` built from shardformer layers.

    The ``nn.Linear`` defaults are evaluated once at import time, but every
    argument is ``deepcopy``-ed below, so the usual mutable-default pitfall
    does not apply.
    """

    def __init__(
        self,
        fc0=nn.Linear(_IN_DIM, _IN_DIM),
        fc1=nn.Linear(_IN_DIM, _HID_DIM),
        fc2=nn.Linear(_HID_DIM, _IN_DIM),
        tp_group=None,
        dtype=torch.float32,
    ):
        super().__init__()
        self.fc0 = deepcopy(fc0)
        # fc1/fc2 form a Megatron-style column/row pair: fc1 keeps its output
        # sharded (gather_output=False) and fc2 consumes that shard directly
        # (parallel_input=True), so the pair needs only one all-reduce.
        self.fc1 = Linear1D_Col.from_native_module(
            deepcopy(fc1), process_group=tp_group, gather_output=False, overlap=True, dtype=dtype
        )
        self.fc2 = Linear1D_Row.from_native_module(
            deepcopy(fc2), process_group=tp_group, parallel_input=True, dtype=dtype
        )

    def forward(self, x):
        return self.fc2(self.fc1(self.fc0(x)))


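# A self-contained numeric sketch of the column/row split above (pure torch,
# no process group; the local names are illustrative, not shardformer API):
#
#     fc1, fc2 = nn.Linear(_IN_DIM, _HID_DIM), nn.Linear(_HID_DIM, _IN_DIM)
#     x = torch.randn(_BS, _IN_DIM)
#     cols, b1 = fc1.weight.chunk(2, dim=0), fc1.bias.chunk(2, dim=0)
#     rows = fc2.weight.chunk(2, dim=1)
#     partials = [(x @ c.T + b) @ r.T for c, b, r in zip(cols, b1, rows)]
#     y = sum(partials) + fc2.bias  # the role played by the all-reduce
#     torch.testing.assert_close(y, fc2(fc1(x)))

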
def data_gen():
    # One batch shaped to match both models' first layer.
    return torch.randn(_BS, _IN_DIM)


def output_transform(x: torch.Tensor):
    # Outputs are compared as-is; no post-processing is needed.
    return x


model_zoo.register(name="simple_mlp", model_fn=Net, data_gen_fn=data_gen, output_transform_fn=output_transform)
|
|
model_zoo.register(name="simple_tp_mlp", model_fn=TPNet, data_gen_fn=data_gen, output_transform_fn=output_transform)
|
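# A minimal end-to-end equivalence sketch (not part of the fixture; assumes a
# default process group initialized via e.g. ``torchrun --nproc_per_node=2``):
#
#     import torch.distributed as dist
#     net = Net()
#     tp_net = TPNet(net.fc0, net.fc1, net.fc2, tp_group=dist.group.WORLD)
#     x = data_gen()
#     torch.testing.assert_close(net(x), tp_net(x), rtol=1e-4, atol=1e-4)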