|
|
|
@ -95,12 +95,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
|
|
|
|
assert len(strategy_name_list) > 8 |
|
|
|
|
|
|
|
|
|
# SS = SR x RS |
|
|
|
|
assert 'S0S1 = S0R x RS1' in strategy_name_list |
|
|
|
|
assert 'S1S0 = S1R x RS0' in strategy_name_list |
|
|
|
|
assert 'S0S1 = S0R x RS1_0' in strategy_name_list |
|
|
|
|
assert 'S1S0 = S1R x RS0_0' in strategy_name_list |
|
|
|
|
|
|
|
|
|
# SR = SS x SR |
|
|
|
|
assert 'S0R = S0S1 x S1R' in strategy_name_list |
|
|
|
|
assert 'S1R = S1S0 x S0R' in strategy_name_list |
|
|
|
|
assert 'S0R = S0S1 x S1R_0' in strategy_name_list |
|
|
|
|
assert 'S1R = S1S0 x S0R_0' in strategy_name_list |
|
|
|
|
|
|
|
|
|
# RS = RS x SS |
|
|
|
|
assert 'RS0 = RS1 x S1S0' in strategy_name_list |
|
|
|
@ -212,12 +212,12 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
|
|
|
|
assert len(strategy_name_list) > 8 |
|
|
|
|
|
|
|
|
|
# SS = SR x RS |
|
|
|
|
assert 'S0S1 = S0R x RS1' in strategy_name_list |
|
|
|
|
assert 'S1S0 = S1R x RS0' in strategy_name_list |
|
|
|
|
assert 'S0S1 = S0R x RS1_0' in strategy_name_list |
|
|
|
|
assert 'S1S0 = S1R x RS0_0' in strategy_name_list |
|
|
|
|
|
|
|
|
|
# SR = SS x SR |
|
|
|
|
assert 'S0R = S0S1 x S1R' in strategy_name_list |
|
|
|
|
assert 'S1R = S1S0 x S0R' in strategy_name_list |
|
|
|
|
assert 'S0R = S0S1 x S1R_0' in strategy_name_list |
|
|
|
|
assert 'S1R = S1S0 x S0R_0' in strategy_name_list |
|
|
|
|
|
|
|
|
|
# RS = RS x SS |
|
|
|
|
assert 'RS0 = RS1 x S1S0' in strategy_name_list |
|
|
|
|