From 62f059251bfcfffcc2f1d89303e27628e5ea8d04 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Sun, 24 Apr 2022 16:43:44 +0800
Subject: [PATCH] [Tensor] init a tp network training unittest (#849)

---
 colossalai/tensor/colo_tensor.py            |  7 ++-
 colossalai/utils/model/colo_init_context.py |  2 +-
 tests/components_to_test/__init__.py        |  2 +-
 tests/components_to_test/simple_net.py      | 44 +++++++++++++++
 tests/test_tensor/test_linear_tp.py         |  3 +-
 tests/test_tensor/test_net_tp.py            | 59 ++++++++++++++++++++
 6 files changed, 111 insertions(+), 6 deletions(-)
 create mode 100644 tests/components_to_test/simple_net.py
 create mode 100644 tests/test_tensor/test_net_tp.py

diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 3a567f223..ad2b28e7f 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,7 +1,9 @@
-from numpy import product
+from .op_wrapper import _COLOSSAL_OPS
+
 import torch
 from typing import Tuple, Optional
-from .op_wrapper import _COLOSSAL_OPS
+from numpy import product
+
 
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
@@ -52,7 +54,6 @@ class ColoTensor(object):
         return product(self._size)
 
     @staticmethod
-
     def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 1e9efec0a..d6cb197eb 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -26,4 +26,4 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         save_torch_payload = True if not self._lazy_memory_allocate else False
         for name, param in name_list:
             delattr(module, name)
-            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param.data, save_payload=save_torch_payload))
+            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=save_torch_payload))
diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py
index 590314de8..099bbe813 100644
--- a/tests/components_to_test/__init__.py
+++ b/tests/components_to_test/__init__.py
@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module
+from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net
diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py
new file mode 100644
index 000000000..487de2062
--- /dev/null
+++ b/tests/components_to_test/simple_net.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from colossalai.nn import CheckpointModule
+from .utils.dummy_data_generator import DummyDataGenerator
+from .registry import non_distributed_component_funcs
+
+
+class SimpleNet(CheckpointModule):
+    """
+    A simple network with two stacked nn.Linear layers.
+ """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.proj2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.proj1(x) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='simple_net') +def get_training_components(): + + def model_builder(checkpoint=True): + return SimpleNet(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py index bd3adcf8f..4119d60b3 100644 --- a/tests/test_tensor/test_linear_tp.py +++ b/tests/test_tensor/test_linear_tp.py @@ -12,10 +12,10 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.core import global_context as gpc -import torch.distributed as dist from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk + def run_linear_tp1d_row_test(): device = get_current_device() dtype = torch.float32 @@ -73,6 +73,7 @@ def run_linear_tp1d_row_test(): B_grad = B_master.grad check_equal(B_grad, layer.bias.grad) + def run_dist(rank, world_size, port): config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') diff --git a/tests/test_tensor/test_net_tp.py b/tests/test_tensor/test_net_tp.py new file mode 100644 index 000000000..c39fa34c5 --- /dev/null +++ b/tests/test_tensor/test_net_tp.py @@ -0,0 +1,61 @@ +from cProfile import label +from statistics import mode +from tests.components_to_test.registry import non_distributed_component_funcs + +import colossalai +import pytest +import torch +import torch.multiprocessing as mp +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.core import global_context as gpc +from colossalai.utils import ColoInitContext + +import torch.distributed as dist +from functools import partial + + +def run_simple_net(): + # A simple net with two stacked nn.Linear + get_components_func = non_distributed_component_funcs.get_callable('simple_net') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + with ColoInitContext(): + model = model_builder(checkpoint=True) + + # TODO(jzy) we set the Specs for weight of each linear. + # model.proj1.weight.set_spec('1Drow') + # model.proj2.weight.set_spec('1Drow') + + for i, (data, label) in enumerate(train_dataloader): + output = model(data) + print(output) + if criterion: + loss = criterion(output, label) + else: + loss = output + + loss.backward() + + if i > 5: + break + + # TODO(jzy) check the results with col.nn.Linear? 
+
+
+def run_dist(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_simple_net()
+
+
+@pytest.mark.dist
+@parameterize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_simple_net(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_simple_net()
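
For reference, below is a minimal single-process sketch (not part of the patch) of how the
'simple_net' components registered above can be exercised. It is a sketch under assumptions:
ColoInitContext is skipped and torch.optim.Adam stands in for HybridAdam, so the loop runs in a
plain PyTorch environment without launching a distributed backend; everything else follows the
registry API used in test_net_tp.py.

    import torch
    from tests.components_to_test import simple_net  # noqa: F401  (importing the module registers 'simple_net')
    from tests.components_to_test.registry import non_distributed_component_funcs

    # Fetch the registered components, as test_net_tp.py does.
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    model = model_builder(checkpoint=False)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # stand-in for HybridAdam

    # Run a handful of steps on the dummy data, mirroring the loop in run_simple_net().
    for i, (data, label) in enumerate(train_dataloader):
        optimizer.zero_grad()
        loss = criterion(model(data), label)
        loss.backward()
        optimizer.step()
        if i > 5:
            break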