[Tensor] init a tp network training unittest (#849)

Jiarui Fang 2022-04-24 16:43:44 +08:00 committed by GitHub
parent 0dea140760
commit 62f059251b
6 changed files with 113 additions and 6 deletions


@@ -1,7 +1,9 @@
-from numpy import product
-from .op_wrapper import _COLOSSAL_OPS
 import torch
+from typing import Tuple, Optional
+
+from .op_wrapper import _COLOSSAL_OPS
+from numpy import product
 
 
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
@@ -52,7 +54,6 @@ class ColoTensor(object):
        return product(self._size)

    @staticmethod
    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
        colo_t = ColoTensor(*tensor.size(),
                            dtype=tensor.dtype,
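For orientation, a minimal usage sketch of the API touched above; the import path and the exact payload semantics are assumptions, not something verified from this commit:

    import torch
    from colossalai.tensor import ColoTensor  # import path assumed

    t = torch.randn(4, 8)
    # save_payload=True keeps the original torch payload; ColoInitContext below
    # passes False when lazy memory allocation is enabled.
    colo_t = ColoTensor.init_from_torch_tensor(t, save_payload=True)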


@@ -26,4 +26,4 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         save_torch_payload = True if not self._lazy_memory_allocate else False
         for name, param in name_list:
             delattr(module, name)
-            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param.data, save_payload=save_torch_payload))
+            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=save_torch_payload))
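Passing the nn.Parameter itself instead of its raw .data tensor means the object handed to init_from_torch_tensor still carries autograd metadata such as requires_grad. A quick illustration of that difference in plain PyTorch (not part of this commit):

    import torch.nn as nn

    layer = nn.Linear(4, 8)
    p = layer.weight
    print(type(p).__name__, p.requires_grad)            # Parameter True
    print(type(p.data).__name__, p.data.requires_grad)  # Tensor False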


@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module
+from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net


@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.nn import CheckpointModule

from .utils.dummy_data_generator import DummyDataGenerator
from .registry import non_distributed_component_funcs


class SimpleNet(CheckpointModule):
    """
    A simple network with two stacked nn.Linear layers.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 4)

    def forward(self, x):
        x = self.proj1(x)
        x = self.proj2(x)
        return x


class DummyDataLoader(DummyDataGenerator):

    def generate(self):
        data = torch.rand(16, 4)
        label = torch.randint(low=0, high=2, size=(16,))
        return data, label


@non_distributed_component_funcs.register(name='simple_net')
def get_training_components():

    def model_builder(checkpoint=True):
        return SimpleNet(checkpoint)

    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    criterion = torch.nn.CrossEntropyLoss()
    from colossalai.nn.optimizer import HybridAdam
    return model_builder, trainloader, testloader, HybridAdam, criterion
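Once registered under the name 'simple_net', the whole bundle (model builder, dataloaders, optimizer class, criterion) can be fetched back by name; the new unittest below does exactly that:

    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, trainloader, testloader, optimizer_class, criterion = get_components_func()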


@@ -12,10 +12,10 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import torch.distributed as dist
from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk


def run_linear_tp1d_row_test():
    device = get_current_device()
    dtype = torch.float32
@@ -73,6 +73,7 @@ def run_linear_tp1d_row_test():
    B_grad = B_master.grad
    check_equal(B_grad, layer.bias.grad)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')


@@ -0,0 +1,61 @@
from tests.components_to_test.registry import non_distributed_component_funcs

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.utils import ColoInitContext

import torch.distributed as dist
from functools import partial


def run_simple_net():
    # A simple net with two stacked nn.Linear layers
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext():
        model = model_builder(checkpoint=True)

    # TODO(jzy) set the spec for the weight of each linear layer.
    # model.proj1.weight.set_spec('1Drow')
    # model.proj2.weight.set_spec('1Drow')

    for i, (data, label) in enumerate(train_dataloader):
        output = model(data)
        print(output)
        if criterion:
            loss = criterion(output, label)
        else:
            loss = output
        loss.backward()

        if i > 5:
            break

    # TODO(jzy) check the results against colossalai.nn.Linear?


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_simple_net()


@pytest.mark.dist
@parameterize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_simple_net(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_simple_net()
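A note on the launcher pattern used in test_simple_net: torch.multiprocessing.spawn invokes its target as fn(rank, *args), so wrapping run_dist in functools.partial pins world_size and port while each spawned worker only receives its rank. A standalone toy sketch of the same pattern (nothing Colossal-AI specific):

    from functools import partial
    import torch.multiprocessing as mp

    def worker(rank, world_size, port):
        # rank is supplied by mp.spawn; the other two arguments are pinned by partial().
        print(f"rank {rank}/{world_size} would init the process group on port {port}")

    if __name__ == '__main__':
        mp.spawn(partial(worker, world_size=2, port=29500), nprocs=2)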