From cdf537a648df68d65c5427ecf3ecce58d5fe0ed2 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Tue, 6 Dec 2022 10:19:33 +0800 Subject: [PATCH] [autoparallel] add non_split linear strategy (#2078) * [autoparallel] add non_split linear stategy * polish --- .../strategy/matmul_strategy_generator.py | 26 +++++++ .../tensor_shard/sharding_strategy.py | 12 +++- .../test_bias_linear_module_node.py | 36 +++++++--- .../test_node_handler/test_linear_handler.py | 70 +++++++++++++++---- 4 files changed, 120 insertions(+), 24 deletions(-) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index 043bb8654..fa2246f95 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -263,6 +263,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): # RS01 = RR x RS01 strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) + # RR = RR x RR + strategies.append(self.non_split()) + return strategies @ignore_sharding_exception @@ -665,6 +668,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception + def non_split(self): + name = f'RR = RR x RR' + + # get sharding spec + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "bias": {}, + "output": {}, + } + + # We don't have to do anything special for bias here, because + # the bias is already the same sharding spec as the output. + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # get communication action + communication_action_mapping = {} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + def validate(self) -> bool: assert "input" in self.op_data assert "other" in self.op_data diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index d40988250..4929e09ad 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -204,9 +204,15 @@ class ShardingStrategy: def _deepcopy_dict_vals(data: Dict): return {k: deepcopy(v) for k, v in data.items()} - sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None - communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None - resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None + sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None + # We need to deepcopy it when self.communication_actions is not None, instead of checking its __bool__ value. + # Consider the examples below: + # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False. + # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items. 
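# --- editor's aside (illustrative sketch, not part of the patch) ---
# Why the guard above must be "is not None" rather than a truthiness check: an
# empty dict is falsy yet is still a valid (empty) set of communication actions,
# and downstream code calls .items() on it. The names actions/copied/wrong below
# are hypothetical, purely for illustration.
actions = {}                                    # valid, merely empty
copied = dict(actions) if actions is not None else None
assert copied == {}                             # still a dict, .items() works

wrong = dict(actions) if actions else None      # truthiness check (the pitfall)
assert wrong is None                            # later wrong.items() would raise AttributeError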
+ communication_actions = _deepcopy_dict_vals( + self.communication_actions) if self.communication_actions is not None else None + # same reason as communication_actions + resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None compute_cost = deepcopy(self.compute_cost) communication_cost = deepcopy(self.communication_cost) memory_cost = deepcopy(self.memory_cost) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py index 6c788b60e..c5c3f3781 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -45,11 +45,11 @@ def check_linear_module_handler(rank, bias, world_size, port): physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - input = torch.rand(2, 2, 4, 16).cuda() + input = torch.rand(4, 4, 4, 16).cuda() # the index of linear node in computation graph node_index = 3 # strategy number of linear node - strategy_number = 10 + strategy_number = 24 # construct input args input_args = [input] # construct meta arg names @@ -63,7 +63,7 @@ def check_linear_module_handler(rank, bias, world_size, port): node_type='bias_module') tracer = ColoTracer() - graph = tracer.trace(model, meta_args={"x": torch.rand(2, 2, 4, 16).to('meta')}) + graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) gm = ColoGraphModule(model, graph) linear_mod_node = list(graph.nodes)[3] @@ -81,9 +81,9 @@ def check_linear_module_handler(rank, bias, world_size, port): assert op_data.data is not None assert mapping['input'].name == "x" - assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16]) + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([16, 16]) + assert mapping['input'].logical_shape == torch.Size([64, 16]) assert mapping['other'].name == "linear_weight" assert mapping['other'].data.shape == torch.Size([32, 16]) @@ -93,21 +93,27 @@ def check_linear_module_handler(rank, bias, world_size, port): assert 'bias' not in mapping assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32]) + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) assert mapping['output'].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - # one strategy will be converted to different physical sharding spec - assert len(strategy_name_list) > 8 # SS = SR x RS assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list # SR = SS x SR assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list # RS = RS 
x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list @@ -121,6 +127,20 @@ def check_linear_module_handler(rank, bias, world_size, port): assert 'RS0 = RR x RS0' in strategy_name_list assert 'RS1 = RR x RS1' in strategy_name_list + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + for strategy in strategies_vector: strategy: ShardingStrategy input_sharding_spec = strategy.get_sharding_spec_by_name('x') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index 5e9061568..e0130936d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -33,11 +33,11 @@ def check_linear_module_handler(rank, bias, world_size, port): physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - input = torch.rand(2, 2, 4, 16).cuda() + input = torch.rand(4, 4, 4, 16).cuda() # the index of linear node in computation graph node_index = 1 # strategy number of linear node - strategy_number = 10 + strategy_number = 24 # construct input args input_args = [input] # construct meta arg names @@ -50,7 +50,7 @@ def check_linear_module_handler(rank, bias, world_size, port): meta_arg_names=meta_arg_names) tracer = ColoTracer() - graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')}) + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')}) gm = ColoGraphModule(model, graph) linear_mod_node = list(graph.nodes)[1] @@ -69,9 +69,9 @@ def check_linear_module_handler(rank, bias, world_size, port): assert op_data.data is not None assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16]) + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([16, 16]) + assert mapping['input'].logical_shape == torch.Size([64, 16]) assert mapping['other'].name == "weight" assert mapping['other'].data.shape == torch.Size([32, 16]) @@ -85,9 +85,9 @@ def check_linear_module_handler(rank, bias, world_size, port): assert mapping['bias'].logical_shape == torch.Size([32]) assert mapping['output'].name == "_0" - assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32]) + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([16, 32]) + assert mapping['output'].logical_shape == torch.Size([64, 32]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -96,11 +96,19 @@ def check_linear_module_handler(rank, bias, world_size, port): # SS = SR x RS assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 
'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list # SR = SS x SR assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list # RS = RS x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list @@ -114,6 +122,20 @@ def check_linear_module_handler(rank, bias, world_size, port): assert 'RS0 = RR x RS0' in strategy_name_list assert 'RS1 = RR x RS1' in strategy_name_list + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + for strategy in strategies_vector: strategy: ShardingStrategy input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') @@ -150,12 +172,12 @@ def check_linear_function_handler(rank, bias, world_size, port): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - input = torch.rand(2, 2, 4, 16).cuda() + input = torch.rand(4, 4, 4, 16).cuda() other = torch.rand(32, 16).cuda() # the index of linear node in computation graph node_index = 2 # strategy number of linear node - strategy_number = 10 + strategy_number = 24 # construct input args input_args = [input, other] # construct meta arg names @@ -170,7 +192,7 @@ def check_linear_function_handler(rank, bias, world_size, port): tracer = ColoTracer() graph = tracer.trace(model, meta_args={ - "input": torch.rand(2, 2, 4, 16).to('meta'), + "input": torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta') }) gm = ColoGraphModule(model, graph) @@ -187,9 +209,9 @@ def check_linear_function_handler(rank, bias, world_size, port): mapping = handler.get_operation_data_mapping() assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16]) + assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([16, 16]) + assert mapping['input'].logical_shape == torch.Size([64, 16]) assert mapping['other'].name == "others" assert mapping['other'].data.shape == torch.Size([32, 16]) @@ -203,7 +225,7 @@ def check_linear_function_handler(rank, bias, world_size, port): assert mapping['other'].logical_shape == torch.Size([16, 32]) assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32]) + assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) assert mapping['output'].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) @@ -213,11 +235,19 @@ def check_linear_function_handler(rank, bias, world_size, port): # SS = SR x RS assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_1' in strategy_name_list + assert 'S0S1 = S0R x RS1_2' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S1S0 = S1R x RS0_1' in strategy_name_list + assert 'S1S0 = S1R x RS0_2' in strategy_name_list # SR = SS x SR assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S0R 
= S0S1 x S1R_1' in strategy_name_list + assert 'S0R = S0S1 x S1R_2' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_1' in strategy_name_list + assert 'S1R = S1S0 x S0R_2' in strategy_name_list # RS = RS x SS assert 'RS0 = RS1 x S1S0' in strategy_name_list @@ -231,6 +261,20 @@ def check_linear_function_handler(rank, bias, world_size, port): assert 'RS0 = RR x RS0' in strategy_name_list assert 'RS1 = RR x RS1' in strategy_name_list + # S01R = S01R x RR + assert 'S01R = S01R x RR_0' in strategy_name_list + assert 'S01R = S01R x RR_1' in strategy_name_list + assert 'S01R = S01R x RR_2' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + for strategy in strategies_vector: strategy: ShardingStrategy input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
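For context, here is a minimal sketch (editor's illustration, not part of the patch) of what the newly asserted 'RR = RR x RR' (non_split) strategy means in practice: the generator leaves every dim_partition_dict empty, so input, weight, bias and output are all fully replicated, each device computes the whole linear layer locally, and no communication action is needed. The shapes below are taken from the updated tests.

import torch
import torch.nn.functional as F

# Fully replicated operands: every device holds the complete tensors.
x = torch.rand(4, 4, 4, 16)    # input shape used in the updated tests
w = torch.rand(32, 16)         # weight of nn.Linear(16, 32)

# Each device computes the full matmul on its own; the output is replicated too.
out = F.linear(x, w)
assert out.shape == torch.Size([4, 4, 4, 32])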