diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py index 4f72ca4e0..4c8935809 100644 --- a/colossalai/auto_parallel/solver/conv_handler.py +++ b/colossalai/auto_parallel/solver/conv_handler.py @@ -93,6 +93,45 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' + + dim_partition_dict_for_input = {0: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] + channel_in = self.input_data.shape[1] + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel = self.output_data.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + sharding_size = self.device_mesh.shape[mesh_dim_0] + memory_cost = numel * size_per_elem_bytes / sharding_size + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' @@ -161,7 +200,45 @@ class ConvHandler(OperatorHandler): memory_cost = numel * size_per_elem_bytes / sharding_size # compute the communication cost of this strategy - communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) + communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + def split_input_in_channel_weight_in_channel(self, mesh_dim_0): + name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' + + dim_partition_dict_for_input = {1: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0] + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel = self.output_data.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + memory_cost = numel * size_per_elem_bytes + + # compute the communication cost of this strategy + communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0) sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_ouput, compute_cost=compute_cost, @@ -250,6 +327,86 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + + dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]) + channel_in = self.input_data.shape[1] + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel = self.output_data.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + memory_cost = numel * size_per_elem_bytes / sharding_size + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + + dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] * + self.device_mesh.shape[mesh_dim_1]) + channel_out = self.weight.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in, channel_out) + + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel = self.output_data.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + memory_cost = numel * size_per_elem_bytes + + # compute communication cost + communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0) + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + def register_strategy(self) -> StrategiesVector: ''' Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector. @@ -283,24 +440,34 @@ class ConvHandler(OperatorHandler): conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2], device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager) conv_handler.register_strategy_into_strategies_vector() - for strategy in conv_handler.strategies_vector.strategies: + for strategy in conv_handler.strategies_vector: print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}') Output: - S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {0: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} - S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {0: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} - S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]} - S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]} - RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]} - RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {0: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]} - RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} - RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} - RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {0: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} + S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} + S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} + S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} + S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} + S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]} + S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]} + RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]} + RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]} + RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]} + RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]} + RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} + RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} + RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]} + S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]} + RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]} ''' # SS = SR x RS self.split_input_batch_weight_out_channel(0, 1) self.split_input_batch_weight_out_channel(1, 0) + # SR = SR x RR + self.split_input_batch(0) + self.split_input_batch(1) + # SR = SS x SR self.split_input_both_dim_weight_in_channel(0, 1) self.split_input_both_dim_weight_in_channel(1, 0) @@ -309,6 +476,10 @@ class ConvHandler(OperatorHandler): self.split_input_in_channel_weight_both_channel(0, 1) self.split_input_in_channel_weight_both_channel(1, 0) + # RR = RS x SR + self.split_input_in_channel_weight_in_channel(0) + self.split_input_in_channel_weight_in_channel(1) + # RS = RR x RS self.split_weight_out_channel(0) self.split_weight_out_channel(1) @@ -316,4 +487,10 @@ class ConvHandler(OperatorHandler): # RR= RR x RR self.non_split() + # S01R = S01R x RR + self.split_1d_parallel_on_input_batch(0, 1) + + # RR = RS01 x S01R + self.split_1d_parallel_on_in_channel(0, 1) + return self.strategies_vector diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py index 1cacc9324..85174d9f4 100644 --- a/colossalai/auto_parallel/solver/operator_handler.py +++ b/colossalai/auto_parallel/solver/operator_handler.py @@ -87,4 +87,4 @@ class OperatorHandler(ABC): _, _, resharding_cost = self.shape_consistency_manager.shape_consistency( strategy.output_sharding_spec, input_spec) resharding_costs[input_node].append(resharding_cost) - return resharding_cost + return resharding_costs diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_conv_handler.py index 3cda3bd80..52b8ba28a 100644 --- a/tests/test_auto_parallel/test_conv_handler.py +++ b/tests/test_auto_parallel/test_conv_handler.py @@ -83,7 +83,7 @@ def test_conv_handler(): shape_consistency_manager=shape_consistency_manager) conv_handler.register_strategy() - # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR'] + # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'] strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector] # SS = SR x RS @@ -105,6 +105,16 @@ def test_conv_handler(): # RR= RR x RR assert 'RR = RR x RR' in strategy_name_list + # SR = SR x RR + assert 'S0R = S0R x RR' in strategy_name_list + assert 'S1R = S1R x RR' in strategy_name_list + assert 'S01R = S01R x RR' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + assert 'RR = RS01 x S01R' in strategy_name_list + if __name__ == '__main__': test_conv_handler()