From 79fe7b027ac543db241b5036f6cb6c0d5c8e1708 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 12 Jul 2022 23:45:29 +0800
Subject: [PATCH] [hotfix] test model unittest hotfix (#1281)

---
 tests/test_tensor/test_model.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index b44e7af01..3431336cb 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -12,7 +12,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
-    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
+    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup, ReplicaSpec
 from colossalai.nn.optimizer import ColoOptimizer
 
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -76,22 +76,23 @@ def run_1d_hybrid_tp(model_name):
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
-            # print(name)
+
             # num_class = type_vocab_size = 2 | (8, 2)
-            # TODO(jiaruifang) has bug if open the following 2 comments
             if 'classifier' in name and 'weight' in name:
                 init_1d_row_linear(p, pg)
             # num_class = vocab_size = 30524 | (30524, 8)
-            if 'word_embeddings' in name and 'weight' in name:
+            elif 'word_embeddings' in name and 'weight' in name:
                 init_1d_row_embedding(p, pg)
             # num_class = seq_len = 512 | (512, 8)
-            if 'position_embeddings' in name and 'weight' in name:
+            elif 'position_embeddings' in name and 'weight' in name:
                 init_1d_row_embedding(p, pg)
             # num_class = type_vocab_size = 2 | (2, 8)
-            if 'token_type_embeddings' in name and 'weight' in name:
+            elif 'token_type_embeddings' in name and 'weight' in name:
                 init_1d_col_embedding(p, pg)
-            if p.process_group.tp_world_size() == 1:
-                p.set_process_group(pg)
+            elif p.process_group.tp_world_size() == 1:
+                with DistSpecManager.no_grad():
+                    p.redistribute(ReplicaSpec(), pg)
+
     elif "simple_net" == model_name:
         # A naive way to set spec for all weights in Linear
         for name, p in model.named_parameters():
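
The change above moves the fallback branch from a bare p.set_process_group(pg) call to an explicit redistribution to a replicated layout under DistSpecManager.no_grad(). A minimal sketch of that pattern, assuming `model` was built under ColoInitContext and `pg` is the test's ProcessGroup (both placeholders here; the helper name is illustrative):

    from colossalai.tensor import ColoTensor, DistSpecManager, ReplicaSpec

    def replicate_leftover_params(model, pg):
        # Walk the parameters the same way the test does and replicate any
        # ColoTensor that was not assigned a tensor-parallel shard spec above.
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            if p.process_group.tp_world_size() == 1:
                # Redistribute outside autograd tracking, as the hotfix does.
                with DistSpecManager.no_grad():
                    p.redistribute(ReplicaSpec(), pg)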