import operator
from functools import reduce
import warnings

import torch

from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from .._utils import generate_sharding_spec

__all__ = ['ConvHandler']


class ConvHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of Convolution.

    Strategy names follow the pattern ``<output> = <input> x <weight>`` where each operand is
    described by the sharding of its first two dimensions: ``R`` means replicated and ``S<k>``
    means sharded along device-mesh axis ``k`` (``S01`` means sharded along both axes).
    Throughout this class dim 0 of the weight is treated as the in-channel dimension and
    dim 1 as the out-channel dimension.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.output_data = self.node._meta_data
        self._sanity_check()

    def _sanity_check(self):
        '''
        In sanity check, we need to make sure the input data has a correct number of dimensions.
        For Conv1d, the dim of input data should be 3([N, C, L]).
        For Conv2d, the dim of input data should be 4([N, C, H, W]).
        For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
        '''
        assert self.input_data.dim() in (3, 4,
                                         5), 'We suppose the dim of input fed into conv op should in range of [3, 5].'

    def _generate_compute_cost(self, bs, channel_in, channel_out):
        '''
        Compute the computation cost per device with this specific strategy.

        Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.

        Argument:
            bs(int): Batch size of the input data.
            channel_in(int): The channel dimension of input data.
            channel_out(int): The out channel of the conv weight.

        Return:
            compute_cost(float): Computation cost per device with this specific strategy
        '''
        # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
        # 1D: (L) * N * Cout * Cin * kernel
        # 2D: (H * W) * N * Cout * Cin * kernel
        # 3D: (H * W * D) * N * Cout * Cin * kernel
        output_size = self.output_data.shape[2:]
        output_size_product = reduce(operator.mul, output_size, 1)
        input_size = self.input_data.shape[2:]
        input_size_product = reduce(operator.mul, input_size, 1)
        kernel_size = self.weight.shape[2:]
        kernel_size_product = reduce(operator.mul, kernel_size, 1)

        # Work shared by every term; only the spatial extent differs between them.
        per_position_cost = bs * channel_in * channel_out * kernel_size_product
        forward_compute_cost = output_size_product * per_position_cost
        backward_activation_cost = input_size_product * per_position_cost
        backward_weight_cost = output_size_product * per_position_cost
        compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
        return compute_cost

    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            sharding_size_forward(int): The forward activation will be divided into
                sharding_size_forward number of partitions.
            sharding_size_backward_activation(int): The backward activation will be divided into
                sharding_size_backward_activation number of partitions.
            sharding_size_weight(int): The backward weight will be divided into
                sharding_size_weight number of partitions.

        Return:
            A 4-tuple of:
            memory_cost(Tuple[float]): Memory cost per device with this specific strategy, the first element of
                this tuple is forward memory cost, and the second element of this tuple is backward memory cost.
            memory_cost_forward_activation(float): Memory cost of forward activation per device with this
                specific strategy.
            memory_cost_backward_activation(float): Memory cost of backward activation per device with this
                specific strategy.
            memory_cost_backward_weight(float): Memory cost of the weight gradient per device with this
                specific strategy.
        '''
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel_output = self.output_data.numel()
        numel_input = self.input_data.numel()
        numel_weight = self.weight.numel()
        # an empty tensor of the input dtype is only used to look up the per-element byte size
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # forward memory_cost
        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

        # backward memory_cost
        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

        # memory_cost pair
        memory_cost = (memory_cost_forward, memory_cost_backward)

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

    def _build_sharding_specs(self, dim_partition_dict_for_input, dim_partition_dict_for_weight,
                              dim_partition_dict_for_output):
        '''
        Build the sharding specs of input, weight and output (in this order) on self.device_mesh.

        Any exception raised by generate_sharding_spec is propagated to the caller unchanged.
        '''
        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
                                                         dim_partition_dict_for_input)
        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
                                                          dim_partition_dict_for_output)
        return sharding_spec_for_input, sharding_spec_for_weight, sharding_spec_for_output

    def _append_strategy(self, name, sharding_spec_for_input, sharding_spec_for_weight, sharding_spec_for_output,
                         compute_cost, communication_cost, memory_cost, resharding_costs):
        '''
        Wrap the given specs and costs into a ShardingStrategy and append it to self.strategies_vector.
        '''
        sharding_strategy = ShardingStrategy(name,
                                             output_sharding_spec=sharding_spec_for_output,
                                             compute_cost=compute_cost,
                                             communication_cost=communication_cost,
                                             memory_cost=memory_cost,
                                             resharding_costs=resharding_costs,
                                             input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategy)

    def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
        '''
        Register the strategy SS = SR x RS: input batch sharded on mesh_dim_0, weight out channel
        sharded on mesh_dim_1.
        '''
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        dim_partition_dict_for_weight = {1: [mesh_dim_1]}
        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        mesh_shape = self.device_mesh.shape
        bs = self.input_data.shape[0] // mesh_shape[mesh_dim_0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1] // mesh_shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = mesh_shape[mesh_dim_0] * mesh_shape[mesh_dim_1]
        sharding_size_backward_activation = mesh_shape[mesh_dim_0]
        sharding_size_weight = mesh_shape[mesh_dim_1]
        memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # This strategy does not need to do all_reduce operation during forward
        communication_cost_forward = 0
        # compute the backward communication cost to all reduce the input activation grad
        communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
                                                                                  mesh_dim_1)
        # compute the backward communication cost to all reduce the weight due to data parallel
        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
        # total communication cost
        communication_cost = (communication_cost_forward + communication_cost_backward_activation +
                              communication_cost_backward_weight)

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def split_input_batch(self, mesh_dim_0):
        '''
        Register the strategy SR = SR x RR: input batch sharded on mesh_dim_0, weight replicated.
        '''
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        dim_partition_dict_for_weight = {}
        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        mesh_shape = self.device_mesh.shape
        bs = self.input_data.shape[0] // mesh_shape[mesh_dim_0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = mesh_shape[mesh_dim_0]
        sharding_size_backward_activation = mesh_shape[mesh_dim_0]
        sharding_size_weight = 1
        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # This strategy does not need to do all_reduce operation in forward phase.
        communication_cost_forward = 0
        # compute the backward communication cost to all reduce the weight due to data parallel
        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
        # compute the total cost
        communication_cost = communication_cost_forward + communication_cost_backward_weight

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
        '''
        Register the strategy SR = SS x SR: input batch sharded on mesh_dim_0, input channel and
        weight in channel both sharded on mesh_dim_1.
        '''
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        # The weight in-channel dim must be sharded on the same mesh axis as the input channel dim
        # (mesh_dim_1); this matches the 'S{mesh_dim_1}R' operand in the strategy name and the
        # sharding_size_weight used for the memory cost below.
        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        mesh_shape = self.device_mesh.shape
        bs = self.input_data.shape[0] // mesh_shape[mesh_dim_0]
        channel_in = self.input_data.shape[1] // mesh_shape[mesh_dim_1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = mesh_shape[mesh_dim_0]
        sharding_size_backward_activation = mesh_shape[mesh_dim_0] * mesh_shape[mesh_dim_1]
        sharding_size_weight = mesh_shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
        # This strategy does not need to do all_reduce operation to compute the input activation grad
        communication_cost_backward_activation = 0
        # compute the backward communication cost to all reduce the weight due to data parallel
        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
        # compute total cost
        communication_cost = (communication_cost_forward + communication_cost_backward_activation +
                              communication_cost_backward_weight)

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
        '''
        Register the strategy RS = RS x SS: input channel and weight in channel sharded on
        mesh_dim_0, weight out channel sharded on mesh_dim_1.
        '''
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        dim_partition_dict_for_output = {1: [mesh_dim_1]}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        mesh_shape = self.device_mesh.shape
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // mesh_shape[mesh_dim_0]
        channel_out = self.weight.shape[1] // mesh_shape[mesh_dim_1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = mesh_shape[mesh_dim_1]
        sharding_size_backward_activation = mesh_shape[mesh_dim_0]
        sharding_size_weight = mesh_shape[mesh_dim_0] * mesh_shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        # compute the communication cost of this strategy during backward phase
        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
        communication_cost = communication_cost_forward + communication_cost_backward

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
        '''
        Register the strategy RR = RS x SR: input channel and weight in channel sharded on mesh_dim_0.
        '''
        name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        dim_partition_dict_for_weight = {0: [mesh_dim_0]}
        dim_partition_dict_for_output = {}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        mesh_shape = self.device_mesh.shape
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // mesh_shape[mesh_dim_0]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = mesh_shape[mesh_dim_0]
        sharding_size_weight = mesh_shape[mesh_dim_0]
        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        # This strategy does NOT need all_reduce during backward phase
        communication_cost_backward = 0
        communication_cost = communication_cost_forward + communication_cost_backward

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def split_weight_out_channel(self, mesh_dim_0):
        '''
        Register the strategy RS = RR x RS: input replicated, weight out channel sharded on mesh_dim_0.
        '''
        name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

        dim_partition_dict_for_input = {}
        dim_partition_dict_for_weight = {1: [mesh_dim_0]}
        dim_partition_dict_for_output = {1: [mesh_dim_0]}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        mesh_shape = self.device_mesh.shape
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1] // mesh_shape[mesh_dim_0]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = mesh_shape[mesh_dim_0]
        sharding_size_backward_activation = 1
        sharding_size_weight = mesh_shape[mesh_dim_0]
        memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # This strategy does not need to do all_reduce during forward phase
        communication_cost_forward = 0
        # compute the communication cost of this strategy during backward phase
        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
        communication_cost = communication_cost_forward + communication_cost_backward

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def non_split(self):
        '''
        Register the strategy RR = RR x RR: input, weight and output are all replicated.
        '''
        name = 'RR = RR x RR'

        dim_partition_dict_for_input = {}
        dim_partition_dict_for_weight = {}
        dim_partition_dict_for_output = {}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = 1
        sharding_size_weight = 1
        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
                                                          sharding_size_weight)

        # This strategy does not need to do all_reduce in both forward and backward phase
        communication_cost = 0

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
        '''
        Register the strategy S01R = S01R x RR: input batch sharded on both mesh axes, weight replicated.
        '''
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
        dim_partition_dict_for_weight = {}
        dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        mesh_shape = self.device_mesh.shape
        total_mesh_size = mesh_shape[mesh_dim_0] * mesh_shape[mesh_dim_1]
        bs = self.input_data.shape[0] // total_mesh_size
        channel_in = self.input_data.shape[1]
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = total_mesh_size
        sharding_size_backward_activation = total_mesh_size
        sharding_size_weight = 1
        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # This strategy does not need to do all_reduce in forward phase
        communication_cost_forward = 0
        # compute the backward communication cost to all reduce the weight due to data parallel;
        # the reduction is costed on axis 0 of the flattened device mesh
        communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            memory_cost_backward_weight, 0)
        # compute the total communication cost
        communication_cost = communication_cost_backward_weight + communication_cost_forward

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
        '''
        Register the strategy RR = RS01 x S01R: input channel and weight in channel sharded on both mesh axes.
        '''
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
        dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
        dim_partition_dict_for_output = {}
        spec_input, spec_weight, spec_output = self._build_sharding_specs(dim_partition_dict_for_input,
                                                                          dim_partition_dict_for_weight,
                                                                          dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([spec_input, spec_weight])

        # compute the computation cost of this strategy
        mesh_shape = self.device_mesh.shape
        total_mesh_size = mesh_shape[mesh_dim_0] * mesh_shape[mesh_dim_1]
        bs = self.input_data.shape[0]
        channel_in = self.input_data.shape[1] // total_mesh_size
        channel_out = self.weight.shape[1]
        compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = total_mesh_size
        sharding_size_weight = total_mesh_size
        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute communication cost during forward phase on axis 0 of the flattened device mesh
        communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            memory_cost_forward_activation, 0)
        # This strategy does NOT need to do all_reduce during backward phase
        communication_cost_backward = 0
        communication_cost = communication_cost_forward + communication_cost_backward

        self._append_strategy(name, spec_input, spec_weight, spec_output, compute_cost, communication_cost,
                              memory_cost, resharding_costs)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.

        Strategies wrapped in try/except below are best-effort: if building one raises, the error
        is reported through warnings.warn and that strategy is skipped.

        NOTE(review): the example below looks like it predates the current API (it calls
        register_strategy_into_strategies_vector and prints a scalar memory_cost) -- treat the
        printed numbers as illustrative only and confirm against the current cost model.

        Example:
            physical_mesh_id = torch.arange(0, 4)
            mesh_shape = (2, 2)
            # [[0, 1]
            #  [2, 3]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
            shape_consistency_manager = ShapeConsistencyManager()

            model = ConvModel(16, 32)
            input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
            # graph():
            #     %x : torch.Tensor [#users=1] = placeholder[target=x]
            #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
            #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
            #     return conv
            graph = tracer.trace(root=model, meta_args=input_sample)
            gm = GraphModule(model, graph, model.__class__.__name__)
            gm.recompile()
            # [x, mul, conv, output]
            nodes = [node for node in gm.graph.nodes]

            # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
            strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
            setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

            strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
            conv_handler = ConvHandler(input_node=nodes[1],
                                       input_index=0,
                                       weight=dict(gm.named_modules())[nodes[2].name].weight,
                                       output_node=nodes[2],
                                       device_mesh=device_mesh,
                                       strategies_vector=strategies_vector,
                                       shape_consistency_manager=shape_consistency_manager)
            conv_handler.register_strategy_into_strategies_vector()
            for strategy in conv_handler.strategies_vector:
                print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')

        Output:
            S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
            S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
            S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
            S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
            S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
            S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
            RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
            RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
            RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
            RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
            RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
            RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
            RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
            S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]}
            RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
        '''
        # SS = SR x RS
        try:
            self.split_input_batch_weight_out_channel(0, 1)
        except Exception as e:
            warnings.warn(f'{e}')
        try:
            self.split_input_batch_weight_out_channel(1, 0)
        except Exception as e:
            warnings.warn(f'{e}')

        # SR = SR x RR
        self.split_input_batch(0)
        self.split_input_batch(1)

        # SR = SS x SR
        try:
            self.split_input_both_dim_weight_in_channel(0, 1)
        except Exception as e:
            warnings.warn(f'{e}')
        try:
            self.split_input_both_dim_weight_in_channel(1, 0)
        except Exception as e:
            warnings.warn(f'{e}')

        # RS = RS x SS
        try:
            self.split_input_in_channel_weight_both_channel(0, 1)
        except Exception as e:
            warnings.warn(f'{e}')
        try:
            self.split_input_in_channel_weight_both_channel(1, 0)
        except Exception as e:
            warnings.warn(f'{e}')

        # RR = RS x SR
        try:
            self.split_input_in_channel_weight_in_channel(0)
        except Exception as e:
            warnings.warn(f'{e}')
        try:
            self.split_input_in_channel_weight_in_channel(1)
        except Exception as e:
            warnings.warn(f'{e}')

        # RS = RR x RS
        self.split_weight_out_channel(0)
        self.split_weight_out_channel(1)

        # RR= RR x RR
        self.non_split()

        # S01R = S01R x RR
        self.split_1d_parallel_on_input_batch(0, 1)

        # RR = RS01 x S01R
        try:
            self.split_1d_parallel_on_in_channel(0, 1)
        except Exception as e:
            warnings.warn(f'{e}')

        return self.strategies_vector


# Names of the strategies register_strategy attempts, in registration order.
CONV_STRATEGIES_LIST = [
    'S0S1 = S0R x RS1',
    'S1S0 = S1R x RS0',
    'S0R = S0R x RR',
    'S1R = S1R x RR',
    'S0R = S0S1 x S1R',
    'S1R = S1S0 x S0R',
    'RS1 = RS0 x S0S1',
    'RS0 = RS1 x S1S0',
    'RR = RS0 x S0R',
    'RR = RS1 x S1R',
    'RS0 = RR x RS0',
    'RS1 = RR x RS1',
    'RR = RR x RR',
    'S01R = S01R x RR',
    'RR = RS01 x S01R',
]