[hotfix] update test for latest version (#2060)

pull/2071/head
YuliangLiu0306 2 years ago committed by GitHub
parent 19438ea0ef
commit e4293e5077
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -126,12 +126,13 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
if method in (torch.Tensor.view, torch.Tensor.reshape): if method in (torch.Tensor.view, torch.Tensor.reshape):
for arg in node.args: for arg in node.args:
if isinstance(arg, Node): if isinstance(arg, Node):
if isinstance(arg._meta_data, int): if isinstance(arg._meta_data, (int, tuple, list)):
new_args.append(arg._meta_data) new_args.append(arg._meta_data)
else: else:
new_args.append(arg) new_args.append(arg)
else: else:
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.' assert isinstance(
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
new_args.append(arg) new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items(): for dim, shard_dims in output_dim_partition_dict.items():

@ -102,12 +102,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8 assert len(strategy_name_list) > 8
# SS = SR x RS # SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR # SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS # RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list assert 'RS0 = RS1 x S1S0' in strategy_name_list

@ -95,12 +95,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8 assert len(strategy_name_list) > 8
# SS = SR x RS # SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR # SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS # RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list assert 'RS0 = RS1 x S1S0' in strategy_name_list
@ -212,12 +212,12 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8 assert len(strategy_name_list) > 8
# SS = SR x RS # SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR # SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS # RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list assert 'RS0 = RS1 x S1S0' in strategy_name_list

Loading…
Cancel
Save