From ed6426c300a4f3ed343ca9d7b4aad1b6199d1469 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Fri, 6 May 2022 17:07:56 +0800
Subject: [PATCH] [Tensor] polish model test (#915)

---
 tests/test_tensor/test_model.py | 196 +++++++++++++-------------------
 1 file changed, 77 insertions(+), 119 deletions(-)

diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 06544401c..4a8c8f8d1 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -21,7 +21,9 @@ import numpy as np
 # Make it available to our ColoTensor
 from transformers.file_utils import ModelOutput
 from dataclasses import fields
-def post_init_colo(self):
+
+
+def _post_init_colo(self):
     class_fields = fields(self)
     # Safety and consistency checks
     if not len(class_fields):
@@ -38,7 +40,7 @@ def post_init_colo(self):
         """
         if isinstance(x, torch.Tensor):
             return True
-    
+
         return isinstance(x, ColoTensor)
 
     if other_fields_are_none and not is_tensor_with_colo(first_field):
@@ -56,11 +58,7 @@ def post_init_colo(self):
         # set the associated fields
         if first_field_iterator:
             for element in iterator:
-                if (
-                    not isinstance(element, (list, tuple))
-                    or not len(element) == 2
-                    or not isinstance(element[0], str)
-                ):
+                if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)):
                     break
                 setattr(self, element[0], element[1])
                 if element[1] is not None:
@@ -73,9 +71,11 @@ def post_init_colo(self):
             if v is not None:
                 self[field.name] = v
 
-ModelOutput.__post_init__ = post_init_colo
+
+ModelOutput.__post_init__ = _post_init_colo
 # complete the hack
 
+
 def set_seed(seed):
     random.seed(seed)
     os.environ['PYTHONHASHSEED'] = str(seed)
@@ -85,9 +85,9 @@ def set_seed(seed):
     torch.backends.cudnn.deterministic = True
 
 
-def run_1d_col_tp():
+def run_1d_col_tp(model_name):
     # A simple net with two stacked nn.Linear
-    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
 
@@ -95,43 +95,66 @@ def run_1d_col_tp():
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
-    parallel_action_list_row = [
-        ParallelAction(priority=1,
-                       compute_pattern=ComputePattern.TP1DRow_Linear,
-                       parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec_row = TensorSpec(parallel_action_list_row)
+    if 'bert' == model_name:
+        parallel_action_list_col = [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1DCol_Linear,
+                           parallel_mode=ParallelMode.PARALLEL_1D)
+        ]
+        spec_col = TensorSpec(parallel_action_list_col)
 
-    parallel_action_list_col = [
-        ParallelAction(priority=1,
-                       compute_pattern=ComputePattern.TP1DCol_Linear,
-                       parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec_col = TensorSpec(parallel_action_list_col)
+        parallel_action_list_embedding_col = [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1DCol_Embedding,
+                           parallel_mode=ParallelMode.PARALLEL_1D)
+        ]
+        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
 
-    parallel_action_list_embedding_col = [
-        ParallelAction(priority=1,
-                       compute_pattern=ComputePattern.TP1DCol_Embedding,
-                       parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
+        for name, p in model.colo_named_parameters():
+            if not isinstance(p, ColoTensor):
+                continue
+            #print(name)
+            if 'classifier' in name and ('weight' in name or 'bias' in name):
+                p.set_spec(spec_col)
+            if '_embeddings' in name and 'weight' in name:
+                p.set_spec(spec_embedding_col)
+    elif "simple_net" == model_name:
+        parallel_action_list_row = [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1DRow_Linear,
+                           parallel_mode=ParallelMode.PARALLEL_1D)
+        ]
+        spec_row = TensorSpec(parallel_action_list_row)
+
+        parallel_action_list_col = [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1DCol_Linear,
+                           parallel_mode=ParallelMode.PARALLEL_1D)
+        ]
+        spec_col = TensorSpec(parallel_action_list_col)
+
+        parallel_action_list_embedding_col = [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1DCol_Embedding,
+                           parallel_mode=ParallelMode.PARALLEL_1D)
+        ]
+        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
+        # A naive way to set spec for all weights in Linear
+        for name, p in model.colo_named_parameters():
+            if not isinstance(p, ColoTensor):
+                continue
+            if 'proj1' in name and ('weight' in name or 'bias' in name):
+                p.set_spec(spec_col)
+            if 'proj2' in name and 'weight' in name:
+                p.set_spec(spec_row)
+            if 'embed' in name and 'weight' in name:
+                p.set_spec(spec_embedding_col)
 
     set_seed(1)
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
 
-    # A naive way to set spec for all weights in Linear
-    for name, p in model.colo_named_parameters():
-        if not isinstance(p, ColoTensor):
-            continue
-        if 'proj1' in name and ('weight' in name or 'bias' in name):
-            p.set_spec(spec_col)
-        if 'proj2' in name and 'weight' in name:
-            p.set_spec(spec_row)
-        if 'embed' in name and 'weight' in name:
-            p.set_spec(spec_embedding_col)
-
     model = model.cuda()
 
     for i, (data, label) in enumerate(train_dataloader):
@@ -231,9 +254,9 @@ def test_colo_optimizer():
             break
 
 
-def run_1d_row_tp():
+def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
-    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
 
@@ -241,6 +264,11 @@ def run_1d_row_tp():
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
+    set_seed(1)
+    if rank == 0:
+        model_torch = model_builder(checkpoint=True)
+        model_torch = model_torch.cuda()
+
     parallel_action_list = [
         ParallelAction(priority=1,
                        compute_pattern=ComputePattern.TP1DRow_Linear,
@@ -255,11 +283,6 @@ def run_1d_row_tp():
     ]
     spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
 
-    set_seed(1)
-    if rank == 0:
-        model_torch = model_builder(checkpoint=True)
-        model_torch = model_torch.cuda()
-
     # A naive way to set spec for all weights in Linear
     for name, p in model.colo_named_parameters():
         if not isinstance(p, ColoTensor):
@@ -307,91 +330,26 @@ def run_1d_row_tp():
         if i > 5:
             break
 
-def run_bert_1d():
-    get_components_func = non_distributed_component_funcs.get_callable('bert')
-    model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-    device = get_current_device()
-    
-    set_seed(1)
-    with ColoInitContext(device=device):
-        model = model_builder(checkpoint=True)
-    
-    # parallel_action_list_row = [
-    #     ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
-    # ]
-    # spec_row = TensorSpec(parallel_action_list_row)
-
-    parallel_action_list_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec_col = TensorSpec(parallel_action_list_col)
-
-    parallel_action_list_embedding_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
-
-    for name, p in model.colo_named_parameters():
-        if not isinstance(p, ColoTensor):
-            continue
-        #print(name)
-        if 'classifier' in name and ('weight' in name or 'bias' in name):
-            p.set_spec(spec_col)
-        if '_embeddings' in name and 'weight' in name:
-            p.set_spec(spec_embedding_col)
-    # for name, p in model.colo_named_parameters():
-    #     if not isinstance(p, ColoTensor):
-    #         continue
-    #     print(f"{name}: is_gathered {p.is_gathered()}")
-
-    model = model.cuda()
-
-    for i, (data, label) in enumerate(train_dataloader):
-        if i > 5:
-            break
-        data = data.to(device)
-        label = label.to(device)
-
-        model.train()
-        if criterion:
-            output = model(data)
-            loss = criterion(output, label)
-        else:
-            output = model(data, label)
-            loss = output
-
-        loss.backward()
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_1d_row_tp()
-    run_1d_col_tp()
+    for name in ['bert', 'simple_net']:
+        run_1d_row_tp(name)
+        run_1d_col_tp(name)
 
-def run_dist_bert(rank, world_size, port):
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_bert_1d()
 
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@rerun_if_address_is_in_use()
-def test_simple_net(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-@pytest.mark.dist
-#@pytest.mark.parametrize('world_size', [1, 4])
-#Don't really add it to pytest now. After finishing Classifier and Loss, I(jzy) will remove this annotation.
+# FIXME(jzy) world size = 4 will fialed
+# @pytest.mark.parametrize('world_size', [4])
 @parameterize('world_size', [1])
 @rerun_if_address_is_in_use()
-def test_bert(world_size):
-    run_func = partial(run_dist_bert, world_size=world_size, port=free_port())
+def test_model(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    # test_simple_net()
     # test_model_parameters()
     # test_colo_optimizer()
-    test_bert()
+    test_model()