@@ -33,11 +33,11 @@ def check_linear_module_handler(rank, bias, world_size, port):
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    input = torch.rand(2, 2, 4, 16).cuda()
+    input = torch.rand(4, 4, 4, 16).cuda()
     # the index of linear node in computation graph
     node_index = 1
     # strategy number of linear node
-    strategy_number = 10
+    strategy_number = 24
     # construct input args
     input_args = [input]
     # construct meta arg names
@@ -50,7 +50,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
                                      meta_arg_names=meta_arg_names)
 
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
 
     linear_mod_node = list(graph.nodes)[1]
@@ -69,9 +69,9 @@ def check_linear_module_handler(rank, bias, world_size, port):
         assert op_data.data is not None
 
     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
+    assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([16, 16])
+    assert mapping['input'].logical_shape == torch.Size([64, 16])
 
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -85,9 +85,9 @@ def check_linear_module_handler(rank, bias, world_size, port):
     assert mapping['bias'].logical_shape == torch.Size([32])
 
     assert mapping['output'].name == "_0"
-    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
+    assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
-    assert mapping['output'].logical_shape == torch.Size([16, 32])
+    assert mapping['output'].logical_shape == torch.Size([64, 32])
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
@@ -96,11 +96,19 @@ def check_linear_module_handler(rank, bias, world_size, port):
 
     # SS = SR x RS
     assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_1' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_2' in strategy_name_list
     assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+    assert 'S1S0 = S1R x RS0_1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0_2' in strategy_name_list
 
     # SR = SS x SR
     assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+    assert 'S0R = S0S1 x S1R_1' in strategy_name_list
+    assert 'S0R = S0S1 x S1R_2' in strategy_name_list
     assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_1' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_2' in strategy_name_list
 
     # RS = RS x SS
     assert 'RS0 = RS1 x S1S0' in strategy_name_list
@@ -114,6 +122,20 @@ def check_linear_module_handler(rank, bias, world_size, port):
     assert 'RS0 = RR x RS0' in strategy_name_list
     assert 'RS1 = RR x RS1' in strategy_name_list
 
+    # S01R = S01R x RR
+    assert 'S01R = S01R x RR_0' in strategy_name_list
+    assert 'S01R = S01R x RR_1' in strategy_name_list
+    assert 'S01R = S01R x RR_2' in strategy_name_list
+
+    # RR = RS01 x S01R
+    assert 'RR = RS01 x S01R' in strategy_name_list
+
+    # RS01 = RR x RS01
+    assert 'RS01 = RR x RS01' in strategy_name_list
+
+    # RR = RR x RR
+    assert 'RR = RR x RR' in strategy_name_list
+
     for strategy in strategies_vector:
         strategy: ShardingStrategy
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
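Note on the naming scheme asserted above: a name like 'S0S1 = S0R x RS1' reads as "an input whose rows are sharded along mesh dim 0 (S0R), times a weight whose columns are sharded along mesh dim 1 (RS1), yields an output sharded along both mesh dims (S0S1)"; the _0/_1/_2 suffixes distinguish multiple registered strategies that share the same layout. A minimal plain-PyTorch sketch of the S0S1 = S0R x RS1 tiling on the 2x2 mesh (illustrative names only, not the handler's API):

    import torch

    # Split A's rows along mesh dim 0 and B's columns along mesh dim 1;
    # each of the four devices then owns one output tile, computed locally.
    A = torch.rand(64, 16)    # logical input, layout S0R
    B = torch.rand(16, 32)    # logical weight, layout RS1
    A_shards = A.chunk(2, dim=0)
    B_shards = B.chunk(2, dim=1)
    tiles = [[A_shards[i] @ B_shards[j] for j in range(2)] for i in range(2)]
    full = torch.cat([torch.cat(row, dim=1) for row in tiles], dim=0)
    assert torch.allclose(full, A @ B, atol=1e-6)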
@@ -150,12 +172,12 @@ def check_linear_function_handler(rank, bias, world_size, port):
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
 
-    input = torch.rand(2, 2, 4, 16).cuda()
+    input = torch.rand(4, 4, 4, 16).cuda()
     other = torch.rand(32, 16).cuda()
     # the index of linear node in computation graph
     node_index = 2
     # strategy number of linear node
-    strategy_number = 10
+    strategy_number = 24
     # construct input args
     input_args = [input, other]
     # construct meta arg names
@@ -170,7 +192,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
     tracer = ColoTracer()
     graph = tracer.trace(model,
                          meta_args={
-                             "input": torch.rand(2, 2, 4, 16).to('meta'),
+                             "input": torch.rand(4, 4, 4, 16).to('meta'),
                              'others': torch.rand(32, 16).to('meta')
                          })
     gm = ColoGraphModule(model, graph)
@@ -187,9 +209,9 @@ def check_linear_function_handler(rank, bias, world_size, port):
     mapping = handler.get_operation_data_mapping()
 
     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
+    assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([16, 16])
+    assert mapping['input'].logical_shape == torch.Size([64, 16])
 
     assert mapping['other'].name == "others"
     assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -203,7 +225,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert mapping['other'].logical_shape == torch.Size([16, 32])
 
     assert mapping['output'].name == "linear"
-    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
+    assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -213,11 +235,19 @@ def check_linear_function_handler(rank, bias, world_size, port):
 
     # SS = SR x RS
    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_1' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_2' in strategy_name_list
     assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+    assert 'S1S0 = S1R x RS0_1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0_2' in strategy_name_list
 
     # SR = SS x SR
     assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+    assert 'S0R = S0S1 x S1R_1' in strategy_name_list
+    assert 'S0R = S0S1 x S1R_2' in strategy_name_list
     assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_1' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_2' in strategy_name_list
 
     # RS = RS x SS
     assert 'RS0 = RS1 x S1S0' in strategy_name_list
@@ -231,6 +261,20 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert 'RS0 = RR x RS0' in strategy_name_list
     assert 'RS1 = RR x RS1' in strategy_name_list
 
+    # S01R = S01R x RR
+    assert 'S01R = S01R x RR_0' in strategy_name_list
+    assert 'S01R = S01R x RR_1' in strategy_name_list
+    assert 'S01R = S01R x RR_2' in strategy_name_list
+
+    # RR = RS01 x S01R
+    assert 'RR = RS01 x S01R' in strategy_name_list
+
+    # RS01 = RR x RS01
+    assert 'RS01 = RR x RS01' in strategy_name_list
+
+    # RR = RR x RR
+    assert 'RR = RR x RR' in strategy_name_list
+
     for strategy in strategies_vector:
         strategy: ShardingStrategy
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
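The updated shape assertions all follow from the handler treating the n-D linear op as a 2-D matmul: the leading input dims collapse into the row axis, and the nn.Linear weight of shape [out_features, in_features] is logically transposed. A quick sanity check in plain PyTorch (variable names here are illustrative):

    import torch

    x = torch.rand(4, 4, 4, 16)        # physical input used by the updated test
    weight = torch.rand(32, 16)        # nn.Linear stores [out_features, in_features]

    x2d = x.reshape(-1, x.shape[-1])   # [64, 16]  -> input logical_shape
    w2d = weight.t()                   # [16, 32]  -> 'other' logical_shape
    out2d = x2d @ w2d                  # [64, 32]  -> output logical_shape

    assert x2d.shape == torch.Size([64, 16])
    assert w2d.shape == torch.Size([16, 32])
    assert out2d.shape == torch.Size([64, 32])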