mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] update test for latest version (#2060)
parent
19438ea0ef
commit
e4293e5077
|
@ -126,12 +126,13 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||
if method in (torch.Tensor.view, torch.Tensor.reshape):
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node):
|
||||
if isinstance(arg._meta_data, int):
|
||||
if isinstance(arg._meta_data, (int, tuple, list)):
|
||||
new_args.append(arg._meta_data)
|
||||
else:
|
||||
new_args.append(arg)
|
||||
else:
|
||||
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
|
||||
assert isinstance(
|
||||
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
|
||||
new_args.append(arg)
|
||||
|
||||
for dim, shard_dims in output_dim_partition_dict.items():
|
||||
|
|
|
@ -102,12 +102,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
|||
assert len(strategy_name_list) > 8
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
|
|
|
@ -95,12 +95,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
|||
assert len(strategy_name_list) > 8
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
|
@ -212,12 +212,12 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
|||
assert len(strategy_name_list) > 8
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
|
|
Loading…
Reference in New Issue