mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] add an inline_op_module to common test models and polish unitests. (#2004)
parent
56a3dcdabd
commit
3d907faede
|
@ -1 +1 @@
|
|||
from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net, gpt
|
||||
from . import bert, gpt, inline_op_model, nested_model, no_leaf_module, repeated_computed_layer, resnet, simple_net
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.nn import CheckpointModule
|
||||
|
||||
from .registry import non_distributed_component_funcs
|
||||
from .utils.dummy_data_generator import DummyDataGenerator
|
||||
|
||||
|
||||
class InlineOpModule(CheckpointModule):
|
||||
"""
|
||||
a module with inline Ops
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__(checkpoint=checkpoint)
|
||||
self.proj1 = nn.Linear(4, 8)
|
||||
self.weight = nn.Parameter(torch.randn(8, 8))
|
||||
self.proj2 = nn.Linear(8, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj1(x)
|
||||
# inline add_
|
||||
x.add_(10)
|
||||
x = F.linear(x, self.weight)
|
||||
# inline relu_
|
||||
x = torch.relu_(x)
|
||||
x = self.proj2(x)
|
||||
return x
|
||||
|
||||
|
||||
class DummyDataLoader(DummyDataGenerator):
|
||||
|
||||
def generate(self):
|
||||
data = torch.rand(16, 4)
|
||||
label = torch.randint(low=0, high=2, size=(16,))
|
||||
return data, label
|
||||
|
||||
|
||||
@non_distributed_component_funcs.register(name='inline_op_module')
|
||||
def get_training_components():
|
||||
|
||||
def model_builder(checkpoint=True):
|
||||
return InlineOpModule(checkpoint)
|
||||
|
||||
trainloader = DummyDataLoader()
|
||||
testloader = DummyDataLoader()
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
return model_builder, trainloader, testloader, HybridAdam, criterion
|
|
@ -1,38 +1,9 @@
|
|||
from colossalai.gemini.paramhooks import BaseParamHookMgr
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
class SubNet(nn.Module):
|
||||
|
||||
def __init__(self, out_features) -> None:
|
||||
super().__init__()
|
||||
self.bias = nn.Parameter(torch.zeros(out_features))
|
||||
|
||||
def forward(self, x, weight):
|
||||
return F.linear(x, weight, self.bias)
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
self.sub_fc = SubNet(5)
|
||||
self.fc2 = nn.Linear(5, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.sub_fc(x, self.fc1.weight)
|
||||
x = self.fc1(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
def net_data():
|
||||
return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
|
||||
from colossalai.gemini.paramhooks import BaseParamHookMgr
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
|
||||
|
@ -41,54 +12,68 @@ def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> boo
|
|||
return torch.allclose(tensor_a, tensor_b)
|
||||
|
||||
|
||||
def run_model(model, inputs, label, criterion, use_param_hook=False):
|
||||
if use_param_hook:
|
||||
|
||||
class HooKWrapper:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.hook_triggered_times = 0
|
||||
|
||||
def wrapper_func(self):
|
||||
|
||||
def hook(param, grad) -> torch.Tensor or None:
|
||||
self.hook_triggered_times += 1
|
||||
return grad
|
||||
|
||||
return hook
|
||||
|
||||
hookwrapper = HooKWrapper()
|
||||
param_list = [p for p in model.parameters()]
|
||||
hook_mgr = BaseParamHookMgr(param_list)
|
||||
hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
|
||||
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
if criterion:
|
||||
y = model(inputs)
|
||||
loss = criterion(y, label)
|
||||
else:
|
||||
loss = model(inputs, label)
|
||||
loss = loss.float()
|
||||
loss.backward()
|
||||
|
||||
if use_param_hook:
|
||||
hook_mgr.remove_hooks()
|
||||
return hookwrapper.hook_triggered_times
|
||||
|
||||
|
||||
def test_base_param_hook():
|
||||
torch.manual_seed(0)
|
||||
model = Net(checkpoint=True).cuda()
|
||||
model.train()
|
||||
inputs = net_data()
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_module']
|
||||
# test_models = ['bert']
|
||||
|
||||
def run_model(model, inputs, use_param_hook=False):
|
||||
if use_param_hook:
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, _, criterion = get_components_func()
|
||||
|
||||
class HooKWrapper:
|
||||
torch.manual_seed(0)
|
||||
model = model_builder(checkpoint=True).cuda()
|
||||
model.train()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.hook_triggered_times = 0
|
||||
for i, (inputs, label) in enumerate(train_dataloader):
|
||||
if i > 0:
|
||||
break
|
||||
model_copy = copy.deepcopy(model)
|
||||
|
||||
def wrapper_func(self):
|
||||
run_model(model, inputs.cuda(), label.cuda(), criterion, False)
|
||||
ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True)
|
||||
|
||||
def hook(param, grad) -> torch.Tensor or None:
|
||||
self.hook_triggered_times += 1
|
||||
return grad
|
||||
# Make sure param hook has only be fired once in case of parameter sharing
|
||||
assert ret2 == len(list(model.parameters()))
|
||||
|
||||
return hook
|
||||
|
||||
hookwrapper = HooKWrapper()
|
||||
param_list = [p for p in model.parameters()]
|
||||
hook_mgr = BaseParamHookMgr(param_list)
|
||||
hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
|
||||
|
||||
model.zero_grad(set_to_none=True)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
y = model(*inputs)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
|
||||
if use_param_hook:
|
||||
hook_mgr.remove_hooks()
|
||||
return hookwrapper.hook_triggered_times
|
||||
|
||||
model_copy = copy.deepcopy(model)
|
||||
|
||||
run_model(model, inputs, False)
|
||||
ret2 = run_model(model_copy, inputs, True)
|
||||
|
||||
# Make sure param hook has only be fired once in case of parameter sharing
|
||||
assert ret2 == len(list(model.parameters()))
|
||||
|
||||
for p, p_copy in zip(model.parameters(), model_copy.parameters()):
|
||||
assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
|
||||
for p, p_copy in zip(model.parameters(), model_copy.parameters()):
|
||||
assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue