2022-03-08 02:19:18 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
2022-11-29 05:42:06 +00:00
|
|
|
|
2023-09-11 08:24:28 +00:00
|
|
|
from colossalai.legacy.nn import CheckpointModule
|
2022-11-29 05:42:06 +00:00
|
|
|
|
2022-03-08 02:19:18 +00:00
|
|
|
from .registry import non_distributed_component_funcs
|
2022-11-29 05:42:06 +00:00
|
|
|
from .utils import DummyDataGenerator
|
2022-03-08 02:19:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
class SubNet(nn.Module):
    """Linear transform that owns only a bias term.

    The weight matrix is not stored on this module; the caller passes it
    to :meth:`forward` on every call, so the same weight tensor can be
    shared with another module.
    """

    def __init__(self, out_features) -> None:
        """Create a zero-initialised bias of length ``out_features``."""
        super().__init__()
        zero_bias = torch.zeros(out_features)
        self.bias = nn.Parameter(zero_bias)

    def forward(self, x, weight):
        """Apply ``F.linear`` to ``x`` using the given ``weight`` and own bias.

        Args:
            x: Input tensor.
            weight: Weight tensor supplied by the caller.

        Returns:
            The result of ``F.linear(x, weight, self.bias)``.
        """
        return F.linear(x, weight, bias=self.bias)
|
|
|
|
|
|
|
|
|
2022-03-08 06:45:01 +00:00
|
|
|
class NestedNet(CheckpointModule):
    """Small network that reuses one layer and shares its weight with a child.

    ``fc1`` is applied twice in the forward pass, and ``fc1.weight`` is also
    handed to the nested ``sub_fc`` module, so the same parameter takes part
    in three separate operations.
    """

    def __init__(self, checkpoint=False) -> None:
        """Build the layers; ``checkpoint`` is forwarded to ``CheckpointModule``."""
        super().__init__(checkpoint)
        # Creation order is kept as fc1 -> sub_fc -> fc2 so parameter
        # registration order stays the same.
        self.fc1 = nn.Linear(5, 5)
        self.sub_fc = SubNet(5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        """Run ``fc1 -> sub_fc (with fc1.weight) -> fc1 -> fc2`` on ``x``.

        Args:
            x: Input tensor whose last dimension is 5 (``fc1`` is
                ``nn.Linear(5, 5)``).

        Returns:
            The output of ``fc2``, whose last dimension is 2.
        """
        shared_weight = self.fc1.weight
        hidden = self.fc1(x)
        hidden = self.sub_fc(hidden, shared_weight)
        hidden = self.fc1(hidden)
        return self.fc2(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
class DummyDataLoader(DummyDataGenerator):
    """Generator of random ``(data, label)`` batches for the nested model."""

    def generate(self):
        """Produce one random batch.

        Returns:
            A ``(data, label)`` tuple: ``data`` comes from ``torch.rand`` with
            shape ``(16, 5)`` and ``label`` from ``torch.randint`` with shape
            ``(16,)`` and values in ``{0, 1}``.
        """
        batch_size, num_features, num_classes = 16, 5, 2
        # Data is drawn before labels, so the random stream is consumed in
        # the same order as before.
        features = torch.rand(batch_size, num_features)
        targets = torch.randint(low=0, high=num_classes, size=(batch_size,))
        return features, targets
|
|
|
|
|
|
|
|
|
|
|
|
@non_distributed_component_funcs.register(name='nested_model')
def get_training_components():
    """Assemble the pieces needed to train the nested model.

    Registered in ``non_distributed_component_funcs`` under the name
    ``nested_model``.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)``, where ``model_builder(checkpoint=False)`` returns a new
        ``NestedNet``, both loaders are ``DummyDataLoader`` instances,
        ``optimizer_class`` is ``torch.optim.Adam`` (the class, not an
        instance) and ``criterion`` is a ``CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        """Return a new ``NestedNet`` built with the given ``checkpoint`` flag."""
        net = NestedNet(checkpoint)
        return net

    criterion = torch.nn.CrossEntropyLoss()
    optimizer_class = torch.optim.Adam

    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    return model_builder, trainloader, testloader, optimizer_class, criterion
|