diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index b6c1fc5c5..29b6a6db6 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -126,12 +126,13 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): if method in (torch.Tensor.view, torch.Tensor.reshape): for arg in node.args: if isinstance(arg, Node): - if isinstance(arg._meta_data, int): + if isinstance(arg._meta_data, (int, tuple, list)): new_args.append(arg._meta_data) else: new_args.append(arg) else: - assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.' + assert isinstance( + arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.' new_args.append(arg) for dim, shard_dims in output_dim_partition_dict.items(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py index 1bc556209..6c788b60e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -102,12 +102,12 @@ def check_linear_module_handler(rank, bias, world_size, port): assert len(strategy_name_list) > 8 # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list # RS = RS x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index acb12eec0..5e9061568 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -95,12 +95,12 @@ def check_linear_module_handler(rank, bias, world_size, port): assert len(strategy_name_list) > 8 # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list # RS = RS x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list @@ -212,12 +212,12 @@ def check_linear_function_handler(rank, bias, world_size, port): assert len(strategy_name_list) > 8 # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_0' in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list # RS = RS x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list