Browse Source

[hotfix] update test for latest version (#2060)

pull/2071/head
YuliangLiu0306 2 years ago committed by GitHub
parent
commit
e4293e5077
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      colossalai/auto_parallel/passes/runtime_preparation_pass.py
  2. 8
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
  3. 16
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

5
colossalai/auto_parallel/passes/runtime_preparation_pass.py

@ -126,12 +126,13 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
if method in (torch.Tensor.view, torch.Tensor.reshape):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, int):
if isinstance(arg._meta_data, (int, tuple, list)):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
assert isinstance(
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items():

8
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py

@ -102,12 +102,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list

16
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

@ -95,12 +95,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
@ -212,12 +212,12 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list

Loading…
Cancel
Save