diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index fb8d029b7..4dfb292e2 100644 --- a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -15,8 +15,8 @@ def get_comm_size(prev_partition, next_partition): # If a node has input nodes from the parent partition, # the output size of those input nodes will be counted # and added to comm_size - parent_node_names = [n.name for n in parent_partition.graph.nodes] - for node in child_partition.graph.nodes: + parent_node_names = [n.name for n in prev_partition.graph.nodes] + for node in next_partition.graph.nodes: input_nodes: Dict[Node, None] = {} map_arg(node.args, lambda n: input_nodes.setdefault(n)) map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index 031bdc25f..90fd9d00e 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -74,8 +74,9 @@ def run_1d_hybrid_tp(model_name): continue # print(name) # num_class = type_vocab_size = 2 | (8, 2) - if 'classifier' in name and 'weight' in name: - init_1d_row_linear(p, pg) + # TODO(jiaruifang) has bug if open the following 2 comments + # if 'classifier' in name and 'weight' in name: + # init_1d_row_linear(p, pg) # num_class = vocab_size = 30524 | (30524, 8) if 'word_embeddings' in name and 'weight' in name: init_1d_row_embedding(p, pg) @@ -152,7 +153,6 @@ def run_1d_hybrid_tp(model_name): # Test the overrided parameters() and named_parameters() member functions -@pytest.mark.skip def test_model_parameters(): colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') @@ -186,9 +186,8 @@ def test_model_parameters(): assert param_cnt == 2 -@pytest.mark.skip +# @pytest.mark.skip def test_colo_optimizer(): - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') get_components_func = non_distributed_component_funcs.get_callable('simple_net') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() set_seed(1) @@ -323,7 +322,6 @@ def run_model_dist(rank, world_size, port): @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 4]) -@pytest.mark.skip("under development") @rerun_if_address_is_in_use() def test_model(world_size): run_func = partial(run_model_dist, world_size=world_size, port=free_port())