From 62f059251bfcfffcc2f1d89303e27628e5ea8d04 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Sun, 24 Apr 2022 16:43:44 +0800
Subject: [PATCH] [Tensor] init a tp network training unittest (#849)

---
 colossalai/tensor/colo_tensor.py            |  7 ++-
 colossalai/utils/model/colo_init_context.py |  2 +-
 tests/components_to_test/__init__.py        |  2 +-
 tests/components_to_test/simple_net.py      | 44 +++++++++++++++
 tests/test_tensor/test_linear_tp.py         |  3 +-
 tests/test_tensor/test_net_tp.py            | 59 ++++++++++++++++++++
 6 files changed, 111 insertions(+), 6 deletions(-)
 create mode 100644 tests/components_to_test/simple_net.py
 create mode 100644 tests/test_tensor/test_net_tp.py

diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 3a567f223..ad2b28e7f 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,7 +1,9 @@
-from numpy import product
+from .op_wrapper import _COLOSSAL_OPS
+
 import torch
 from typing import Tuple, Optional
-from .op_wrapper import _COLOSSAL_OPS
+from numpy import product
+
 
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
@@ -52,7 +54,6 @@ class ColoTensor(object):
         return product(self._size)
 
     @staticmethod
-
     def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 1e9efec0a..d6cb197eb 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -26,4 +26,4 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         save_torch_payload = True if not self._lazy_memory_allocate else False
         for name, param in name_list:
             delattr(module, name)
-            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param.data, save_payload=save_torch_payload))
+            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=save_torch_payload))
diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py
index 590314de8..099bbe813 100644
--- a/tests/components_to_test/__init__.py
+++ b/tests/components_to_test/__init__.py
@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module
+from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net
diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py
new file mode 100644
index 000000000..487de2062
--- /dev/null
+++ b/tests/components_to_test/simple_net.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from colossalai.nn import CheckpointModule
+from .utils.dummy_data_generator import DummyDataGenerator
+from .registry import non_distributed_component_funcs
+
+
+class SimpleNet(CheckpointModule):
+    """
+    A simple network with two stacked nn.Linear layers.
+ """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.proj2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.proj1(x) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='simple_net') +def get_training_components(): + + def model_builder(checkpoint=True): + return SimpleNet(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py index bd3adcf8f..4119d60b3 100644 --- a/tests/test_tensor/test_linear_tp.py +++ b/tests/test_tensor/test_linear_tp.py @@ -12,10 +12,10 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.core import global_context as gpc -import torch.distributed as dist from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk + def run_linear_tp1d_row_test(): device = get_current_device() dtype = torch.float32 @@ -73,6 +73,7 @@ def run_linear_tp1d_row_test(): B_grad = B_master.grad check_equal(B_grad, layer.bias.grad) + def run_dist(rank, world_size, port): config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') diff --git a/tests/test_tensor/test_net_tp.py b/tests/test_tensor/test_net_tp.py new file mode 100644 index 000000000..c39fa34c5 --- /dev/null +++ b/tests/test_tensor/test_net_tp.py @@ -0,0 +1,61 @@ +from cProfile import label +from statistics import mode +from tests.components_to_test.registry import non_distributed_component_funcs + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.core import global_context as gpc +from colossalai.utils import ColoInitContext + +import torch.distributed as dist +from functools import partial + + +def run_simple_net(): + # A simple net with two stacked nn.Linear + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + with ColoInitContext(): + model = model_builder(checkpoint=True) + + # TODO(jzy) we set the Specs for weight of each linear. + # model.proj1.weight.set_spec('1Drow') + # model.proj2.weight.set_spec('1Drow') + + for i, (data, label) in enumerate(train_dataloader): + output = model(data) + print(output) + if criterion: + loss = criterion(output, label) + else: + loss = output + + loss.backward() + + if i > 5: + break + + # TODO(jzy) check the results with col.nn.Linear? 
+
+
+def run_dist(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_simple_net()
+
+
+@pytest.mark.dist
+@parameterize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_simple_net(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_simple_net()
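
For reference, below is a minimal single-process sketch (not part of the patch) of how the
'simple_net' components registered above can be exercised. It is a sketch under assumptions:
ColoInitContext is skipped and torch.optim.Adam stands in for HybridAdam, so the loop runs in a
plain PyTorch environment without launching a distributed backend; everything else follows the
registry API used in test_net_tp.py.

    import torch
    from tests.components_to_test import simple_net  # noqa: F401  (importing the module registers 'simple_net')
    from tests.components_to_test.registry import non_distributed_component_funcs

    # Fetch the registered components, as test_net_tp.py does.
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    model = model_builder(checkpoint=False)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # stand-in for HybridAdam

    # Run a handful of steps on the dummy data, mirroring the loop in run_simple_net().
    for i, (data, label) in enumerate(train_dataloader):
        optimizer.zero_grad()
        loss = criterion(model(data), label)
        loss.backward()
        optimizer.step()
        if i > 5:
            break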