diff --git a/colossalai/auto_parallel/solver/__init__.py b/colossalai/auto_parallel/solver/__init__.py index c20ed18ca..ec7817dfb 100644 --- a/colossalai/auto_parallel/solver/__init__.py +++ b/colossalai/auto_parallel/solver/__init__.py @@ -1,7 +1,8 @@ -from .operator_handler import OperatorHandler -from .dot_handler import DotHandler -from .conv_handler import ConvHandler from .sharding_strategy import ShardingStrategy, StrategiesVector from .graph_analysis import GraphAnalyser +from .solver import Solver +from .cost_graph import CostGraph +from .strategies_constructor import StrategiesConstructor +from .constants import * -__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser'] +__all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph'] diff --git a/colossalai/auto_parallel/solver/constants.py b/colossalai/auto_parallel/solver/constants.py index a65b2b173..773a5a566 100644 --- a/colossalai/auto_parallel/solver/constants.py +++ b/colossalai/auto_parallel/solver/constants.py @@ -3,13 +3,14 @@ import operator __all__ = [ 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP', - 'LINEAR_FUNC_OP' + 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP' ] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] +# TODO: flatten should not be added into this group ELEMENTWISE_FUNC_OP = [ torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, - operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout + operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten ] CONV_MODULE_OP = [ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, @@ -20,3 +21,7 @@ CONV_FUNC_OP = [ ] LINEAR_MODULE_OP = [torch.nn.Linear] LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] +BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm] +POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d] + +INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/solver/cost_graph.py index 220ab54a3..a4ec6c485 100644 --- a/colossalai/auto_parallel/solver/cost_graph.py +++ b/colossalai/auto_parallel/solver/cost_graph.py @@ -1,4 +1,3 @@ -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from typing import List import math from torch.fx.node import Node diff --git a/colossalai/auto_parallel/solver/graph_analysis.py b/colossalai/auto_parallel/solver/graph_analysis.py index 53469c246..831e7eadd 100644 --- a/colossalai/auto_parallel/solver/graph_analysis.py +++ b/colossalai/auto_parallel/solver/graph_analysis.py @@ -15,7 +15,7 @@ class LiveVariable: LiveVariable is a data structure to store the meta information of a variable for liveness analysis. """ name: str - meta: Union[Any, List[Any]] + node: Node is_inplace: bool @@ -80,13 +80,13 @@ class GraphAnalyser: """ return self._graph - def liveness_analysis(self) -> OrderedDict[int, LiveStage]: + def liveness_analysis(self) -> List[LiveStage]: """ Analyse the graph to obtain the variable liveness information. 
This function returns an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. """ compute_nodes = self.graph.nodes - liveness_dict = ODict() + liveness_list = [] # checked: record all variables created since the first stage # all: record the live variables only exist until the current stage. @@ -97,25 +97,6 @@ class GraphAnalyser: all_live_variables = LiveVariableVector() unique_live_vars = LiveVariableVector() - def _add_param_or_buf(node, tensor_type): - module = get_node_module(node) - - if tensor_type == 'param': - iterator = module.named_parameters() - elif tensor_type == 'buffer': - iterator = module.named_buffers() - else: - raise ValueError(f"Expected tensor_type to be param or buffer, but got {tensor_type}") - - for name, tensor in iterator: - tensor_name = f'{node.name}.{name}' - - if not checked_variables.exists(tensor_name): - live_tensor = LiveVariable(name=tensor_name, meta=tensor.to('meta'), is_inplace=False) - unique_live_vars.append(live_tensor) - checked_variables.append(live_tensor) - all_live_variables.append(live_tensor) - for idx, node in enumerate(compute_nodes): ############################# # find new living variables # @@ -135,26 +116,19 @@ class GraphAnalyser: # add the output var meta = getattr(node, '_meta_data', None) - live_var = LiveVariable(name=node.name, meta=meta, is_inplace=is_inplace) + live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) if not is_inplace: unique_live_vars.append(live_var) checked_variables.append(live_var) all_live_variables.append(live_var) - # add the model parameters - if node.op == 'call_module': - _add_param_or_buf(node, tensor_type='param') - _add_param_or_buf(node, tensor_type='buffer') - - # add this output variable to the checked list - checked_variables.append(live_var) - # check if any input is not checked yet for arg in node.args: - arg_name = str(arg) + if not isinstance(arg, Node): + continue + arg_name = arg.name if not checked_variables.exists(arg_name): - meta = getattr(node, '_meta_data', None) - live_var_from_arg = LiveVariable(name=arg_name, meta=meta, is_inplace=False) + live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False) all_live_variables.append(live_var_from_arg) checked_variables.append(live_var_from_arg) unique_live_vars.append(live_var_from_arg) @@ -167,8 +141,23 @@ class GraphAnalyser: node=node, all_live_vars=all_live_variables.copy(), unique_live_vars=unique_live_vars.copy()) - liveness_dict[idx] = stage - return liveness_dict + # if a LiveStage is covered by another LiveStage, we just keep the larger one. 
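# A minimal standalone sketch of the coverage test described in the comment above, using plain
# name sets in place of LiveVariable objects (the real check below iterates unique_live_vars):
def is_covered(prev_unique_vars, new_unique_vars):
    """Return True if every live variable of the previous stage also lives in the new stage."""
    return all(var in new_unique_vars for var in prev_unique_vars)

assert is_covered({'conv1', 'bn1'}, {'conv1', 'bn1', 'relu1'})        # covered -> the new stage replaces it
assert not is_covered({'conv1', 'fc1'}, {'conv1', 'bn1', 'relu1'})    # not covered -> both stages are kept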
+ replace = False + for index, prev_stage in enumerate(liveness_list): + all_covered = True + for ele in prev_stage.unique_live_vars: + if ele not in stage.unique_live_vars: + all_covered = False + break + if all_covered: + replace = True + break + if replace: + liveness_list[index] = stage + else: + liveness_list.append(stage) + + return liveness_list def get_alias_set(self): pass diff --git a/colossalai/auto_parallel/solver/op_handler/__init__.py b/colossalai/auto_parallel/solver/op_handler/__init__.py new file mode 100644 index 000000000..012acffe4 --- /dev/null +++ b/colossalai/auto_parallel/solver/op_handler/__init__.py @@ -0,0 +1,6 @@ +from .operator_handler import OperatorHandler +from .dot_handler import DotHandler +from .conv_handler import ConvHandler +from .batch_norm_handler import BatchNormHandler + +__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler'] \ No newline at end of file diff --git a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py new file mode 100644 index 000000000..eac2f62cc --- /dev/null +++ b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py @@ -0,0 +1,483 @@ +import operator +from functools import reduce +import warnings +import torch +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from .operator_handler import OperatorHandler + +__all__ = ['BatchNormHandler'] + + +class BatchNormHandler(OperatorHandler): + """ + A OperatorHandler which deals with the sharding strategies of normalization. + + To keep the math consistency, there are two way to do BatchNorm if the input + shards on batch dimension: + 1. We gather the input partitions through batch dimension, then do the normal BatchNorm. + 2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help + us to keep the computing correctness. + In this handler, both methods will be considered. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_data = self.predecessor_node[0]._meta_data + self.weight = self.module_named_parameters['weight'] + self.bias = self.module_named_parameters['bias'] + self.output_data = self.node._meta_data + self._sanity_check() + + def _sanity_check(self): + ''' + In sanity check, we need make sure the input data having correct dimension size. + For BatchNorm1d, the dim of input data should be 3([N, C, L]). + For BatchNorm2d, the dim of input data should be 4([N, C, H, W]). + For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). + ''' + assert self.input_data.dim() in (3, 4, + 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + + def _generate_compute_cost(self, bs, channel_in): + ''' + Compute the computation cost per device with this specific strategy. + + Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + + Argument: + bs(int): Batch size of the input data. + channel_in(int): The channel dimension of input data. + + Return: + compute_cost(float): Computation cost per device with this specific strategy + ''' + # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size. + # TODO: a constant coefficient need to be added. 
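# A worked sketch of the cost formula spelled out in the comments just below, using assumed
# sizes: a BatchNorm2d input of shape [N=16, C=64, H=32, W=32] with no sharding applied.
def _example_bn_compute_cost(bs, channel_in, spatial_size_product):
    forward = spatial_size_product * bs * channel_in
    backward_activation = spatial_size_product * bs * channel_in
    backward_weight = spatial_size_product * bs * channel_in
    return forward + backward_activation + backward_weight

# forward + backward_activation + backward_weight = 3 * (H * W) * N * Cin = 3_145_728
assert _example_bn_compute_cost(bs=16, channel_in=64, spatial_size_product=32 * 32) == 3 * 32 * 32 * 16 * 64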
+ # 1D: (L) * N * Cin + # 2D: (H * W) * N * Cin + # 3D: (H * W * D) * N * Cin + + input_size = self.input_data.shape[2:] + input_size_product = reduce(operator.mul, input_size, 1) + forward_compute_cost = input_size_product * bs * channel_in + backward_activation_compute_cost = input_size_product * bs * channel_in + backward_weight_compute_cost = input_size_product * bs * channel_in + backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost + compute_cost = forward_compute_cost + backward_compute_cost + return compute_cost + + def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): + ''' + Compute the memory cost per device with this specific strategy. + + Argument: + sharding_size_forward(int): The forward activation will be divided + into sharding_size_forward number partions. + sharding_size_backward_activation(int): The backward activation will + be divided into sharding_size_backward_activation number partions. + sharding_size_weight(int): The backward weight will be divided + into sharding_size_weight number partions. + + Return: + memory_cost(Tuple[float]): Memory cost per device with this + specific strategy, the first element of this tuple is forward + memory cost, and the second element of this tuple is backward + memory cost. + memory_cost_forward(float): Memory cost of forward activation per + device with this specific strategy. + memory_cost_backward_activation(float): Memory cost of backward activation + per device with this specific strategy. + ''' + # compute the memory cost of this strategy + dtype = self.input_data.dtype + numel_output = self.output_data.numel() + numel_input = numel_output + numel_weight = self.weight.numel() + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # forward memory_cost + memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward + memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight + + # backward memory_cost + memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation + memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight + + # memory_cost pair + memory_cost = (memory_cost_forward, memory_cost_backward) + + return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation + + def split_input_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + + dim_partition_dict_for_input = {1: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_0]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # 
compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] + memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + # shard the output batch dimension to get all possible sharding strategy from this basic strategy + new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + + dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]} + new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + # the computation cost is all the same + new_compute_cost = compute_cost + + # the memory cost need to be recomputed + # compute the memroy cost of new strategy + new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] + new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( + new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) + + # the communication cost need to count the sharding cost into this strategy + # compute the communication cost of new strategy + origin_communication_cost = communication_cost + tiny_shard_cost = 10 + new_forward_communication_cost = tiny_shard_cost + # we need to all gather the batch dimension for the basic strategy + new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1) + new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost + + sharding_strategies = ShardingStrategy(new_name, + output_sharding_spec=new_sharding_spec_for_output, + compute_cost=new_compute_cost, + communication_cost=new_communication_cost, + memory_cost=new_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' + + dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + 
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] * + self.device_mesh.shape[mesh_dim_1]) + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def non_split(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RR x R' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] + channel_in = self.input_data.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = 1 + sharding_size_backward_activation = 1 + sharding_size_weight = 1 + memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, + sharding_size_weight) + + # This strategy do not need to do all_reduce operation + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def _construct_batch_sharding_strategies(mesh_dim_list, new_name): + dim_partition_dict_for_output = {0: mesh_dim_list} + new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # the computation cost is all the same + new_compute_cost = compute_cost + + # the memory cost need to be recomputed + new_sharding_size_input = 1 + for mesh_dim in mesh_dim_list: + new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim] + new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( + new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight) + + # the communication cost need to count the sharding cost into this strategy + origin_communication_cost = communication_cost + tiny_shard_cost = 10 + new_forward_communication_cost = tiny_shard_cost + if len(mesh_dim_list) == 1: + 
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, + mesh_dim_list[0]) + else: + new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost( + memory_cost_backward_activation, 0) + new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost + + new_sharding_strategy = ShardingStrategy(new_name, + output_sharding_spec=new_sharding_spec_for_output, + compute_cost=new_compute_cost, + communication_cost=new_communication_cost, + memory_cost=new_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, + sharding_spec_for_weight)) + + return new_sharding_strategy + + # shard the output batch dimension to get all possible sharding strategy from this basic strategy + # shard on mesh_dim_0 + new_name = f'S{mesh_dim_0}R = RR x R' + mesh_dim_list = [mesh_dim_0] + new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) + self.strategies_vector.append(new_sharding_strategy) + + # shard on mesh_dim_1 + new_name = f'S{mesh_dim_1}R = RR x R' + mesh_dim_list = [mesh_dim_1] + new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) + self.strategies_vector.append(new_sharding_strategy) + + # shard on mesh_dim_0, mesh_dim_1 + new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R' + mesh_dim_list = [mesh_dim_0, mesh_dim_1] + new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) + self.strategies_vector.append(new_sharding_strategy) + + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' + + dim_partition_dict_for_input = {0: [mesh_dim_0]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] + channel_in = self.input_data.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = 1 + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + # the all reduce communication will happen during the sync bn computing. 
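# A minimal sketch of why the synchronization is needed for math consistency: when the batch
# dimension is sharded, each device only sees part of the batch, so per-device statistics have
# to be combined (the all-reduce performed inside SyncBN) to match the full-batch statistics.
# Shapes are assumed for illustration and the reduction is simulated locally, without torch.distributed.
import torch

full_batch = torch.randn(8, 4)                      # [N=8, C=4], the unsharded input
shards = torch.chunk(full_batch, chunks=2, dim=0)   # what each of two devices would hold
per_shard_sums = [shard.sum(dim=0) for shard in shards]
synced_mean = torch.stack(per_shard_sums).sum(dim=0) / full_batch.shape[0]   # the "all-reduced" mean
assert torch.allclose(synced_mean, full_batch.mean(dim=0))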
+ communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' + + dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]) + channel_in = self.input_data.shape[1] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = 1 + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + # the all reduce communication will happen during the sync bn computing. 
+ communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' + + dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0] + channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1] + compute_cost = self._generate_compute_cost(bs, channel_in) + + # compute the memory cost of this strategy + sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = self.device_mesh.shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) + + # the all reduce communication will happen during the sync bn computing. + communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_output, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + + self.strategies_vector.append(sharding_strategies) + + def register_strategy(self) -> StrategiesVector: + ''' + Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. 
+ + Example: + norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector, + self.shape_consistency_manager) + norm_handler.register_strategy() + for strategy in norm_handler.strategies_vector: + print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') + + Output: + RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0 + RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0 + RR = RR x R, computation_cost: 262144, memory_cost: 1048576 + RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0 + ''' + + # RS = RS x S and strategies based on it, such as + # SS = RS x S + self.split_input_channel(0, 1) + self.split_input_channel(1, 0) + + # RR = RR x R and strategies based on it, such as + # SR = SR x R + self.non_split(0, 1) + + # RS01 = RS01 x S01 + self.split_input_channel_1d(0, 1) + + # SR = SR x R WITH SYNC_BN + self.split_input_batch(0) + self.split_input_batch(1) + + # SS = SS x S WITH SYNC_BN + self.split_input_both_dim(0, 1) + self.split_input_both_dim(1, 0) + + # S01R = S01R x R WITH SYNC_BN + self.split_input_batch_1d(0, 1) + + return self.strategies_vector diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/op_handler/conv_handler.py similarity index 86% rename from colossalai/auto_parallel/solver/conv_handler.py rename to colossalai/auto_parallel/solver/op_handler/conv_handler.py index 3cbe43926..e3f8a6a21 100644 --- a/colossalai/auto_parallel/solver/conv_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/conv_handler.py @@ -1,5 +1,6 @@ import operator from functools import reduce +import warnings import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler @@ -9,7 +10,7 @@ __all__ = ['ConvHandler'] class ConvHandler(OperatorHandler): """ - An OperatorHandler which deals with the sharding strategies of linear matrix multiplication. + A OperatorHandler which deals with the sharding strategies of Convolution. """ def __init__(self, *args, **kwargs): @@ -59,8 +60,7 @@ class ConvHandler(OperatorHandler): compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost return compute_cost - def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, - sharding_size_backward_weight): + def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight): ''' Compute the memory cost per device with this specific strategy. @@ -69,8 +69,8 @@ class ConvHandler(OperatorHandler): into sharding_size_forward number partions. sharding_size_backward_activation(int): The backward activation will be divided into sharding_size_backward_activation number partions. - sharding_size_backward_weight(int): The backward weight will be divided - into sharding_size_backward_weight number partions. + sharding_size_weight(int): The backward weight will be divided + into sharding_size_weight number partions. 
Return: memory_cost(Tuple[float]): Memory cost per device with this @@ -90,17 +90,19 @@ class ConvHandler(OperatorHandler): size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() # forward memory_cost - memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward + memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward + memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight + memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight # backward memory_cost memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation - memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight + memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight # memory_cost pair memory_cost = (memory_cost_forward, memory_cost_backward) - return memory_cost, memory_cost_forward, memory_cost_backward_activation + return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -112,7 +114,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -126,9 +128,9 @@ class ConvHandler(OperatorHandler): # compute the memory cost of this strategy sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1] + sharding_size_weight = self.device_mesh.shape[mesh_dim_1] memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight) + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # This strategy do not need to do all_reduce operation during forward communication_cost_forward = 0 @@ -138,7 +140,7 @@ class ConvHandler(OperatorHandler): communication_cost = communication_cost_forward + communication_cost_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=memory_cost, @@ -156,7 +158,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = 
self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -170,19 +172,20 @@ class ConvHandler(OperatorHandler): # compute the memory cost of this strategy sharding_size_forward = self.device_mesh.shape[mesh_dim_0] sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_weight = 1 + sharding_size_weight = 1 memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_backward_weight) + sharding_size_weight) # This strategy do not need to do all_reduce operation in both forward and backward phase. communication_cost = 0 sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=memory_cost, resharding_costs=resharding_costs, input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): @@ -195,7 +198,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -209,18 +212,18 @@ class ConvHandler(OperatorHandler): # compute the memory cost of this strategy sharding_size_forward = self.device_mesh.shape[mesh_dim_0] sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1] - memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_backward_weight) + sharding_size_weight = self.device_mesh.shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) # compute the communication cost of this strategy during forward phase - communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1) + communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1) # This strategy do not need to do all_reduce operation during backward phase communication_cost_backward = 0 communication_cost = communication_cost_forward + communication_cost_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=memory_cost, @@ -238,7 +241,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this 
strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -252,17 +255,17 @@ class ConvHandler(OperatorHandler): # compute the memory cost of this strategy sharding_size_forward = self.device_mesh.shape[mesh_dim_1] sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight) + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, memory_cost_backward_activation = self._generate_memory_cost( + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # compute the communication cost of this strategy during forward phase - communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0) + communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) # compute the communication cost of this strategy during backward phase communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1) communication_cost = communication_cost_forward + communication_cost_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=memory_cost, @@ -280,7 +283,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -294,18 +297,18 @@ class ConvHandler(OperatorHandler): # compute the memory cost of this strategy sharding_size_forward = 1 sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] - sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] - memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_backward_weight) + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) # compute the communication cost of this strategy during forward phase - communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0) + communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0) # This strategy do NOT need all_reduce during forward phase communication_cost_backward = 0 communication_cost = communication_cost_forward + communication_cost_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, 
memory_cost=memory_cost, @@ -323,7 +326,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -337,9 +340,9 @@ class ConvHandler(OperatorHandler): # compute the memory cost of this strategy sharding_size_forward = self.device_mesh.shape[mesh_dim_0] sharding_size_backward_activation = 1 - sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] + sharding_size_weight = self.device_mesh.shape[mesh_dim_0] memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost( - sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight) + sharding_size_forward, sharding_size_backward_activation, sharding_size_weight) # This strategy do not need to do all_reduce during forward phase communication_cost_forward = 0 @@ -347,7 +350,7 @@ class ConvHandler(OperatorHandler): communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0) communication_cost = communication_cost_forward + communication_cost_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=memory_cost, @@ -365,7 +368,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -379,15 +382,15 @@ class ConvHandler(OperatorHandler): # compute the memory cost of this strategy sharding_size_forward = 1 sharding_size_backward_activation = 1 - sharding_size_backward_weight = 1 + sharding_size_weight = 1 memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_backward_weight) + sharding_size_weight) # This strategy do not need to do all_reduce in both forward and backward phase communication_cost = 0 sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=memory_cost, @@ -405,7 +408,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = 
self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -420,15 +423,15 @@ class ConvHandler(OperatorHandler): sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1] sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[ mesh_dim_1] - sharding_size_backward_weight = 1 + sharding_size_weight = 1 memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, - sharding_size_backward_weight) + sharding_size_weight) # This strategy do not need to do all_reduce in both forward and backward phase communication_cost = 0 sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=memory_cost, @@ -446,7 +449,7 @@ class ConvHandler(OperatorHandler): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -462,19 +465,20 @@ class ConvHandler(OperatorHandler): sharding_size_forward = 1 sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[ mesh_dim_1] - sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1] - memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward, - sharding_size_backward_activation, - sharding_size_backward_weight) + sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1] + memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, + sharding_size_backward_activation, + sharding_size_weight) # compute communication cost during forward phase - communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0) + communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost( + memory_cost_forward_activation, 0) # This strategy do NOT need do all_reduce during backward phase communication_cost_backward = 0 communication_cost = communication_cost_forward + communication_cost_backward sharding_strategies = ShardingStrategy(name, - output_sharding_spec=sharding_spec_for_ouput, + output_sharding_spec=sharding_spec_for_output, compute_cost=compute_cost, communication_cost=communication_cost, memory_cost=memory_cost, @@ -536,24 +540,48 @@ class ConvHandler(OperatorHandler): RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]} ''' # SS = SR x RS - self.split_input_batch_weight_out_channel(0, 1) - self.split_input_batch_weight_out_channel(1, 0) + try: + self.split_input_batch_weight_out_channel(0, 1) + except Exception as e: + warnings.warn(f'{e}') + try: + self.split_input_batch_weight_out_channel(1, 0) + except Exception as e: + warnings.warn(f'{e}') # SR = SR x RR self.split_input_batch(0) self.split_input_batch(1) # SR 
= SS x SR - self.split_input_both_dim_weight_in_channel(0, 1) - self.split_input_both_dim_weight_in_channel(1, 0) + try: + self.split_input_both_dim_weight_in_channel(0, 1) + except Exception as e: + warnings.warn(f'{e}') + try: + self.split_input_both_dim_weight_in_channel(1, 0) + except Exception as e: + warnings.warn(f'{e}') # RS = RS x SS - self.split_input_in_channel_weight_both_channel(0, 1) - self.split_input_in_channel_weight_both_channel(1, 0) + try: + self.split_input_in_channel_weight_both_channel(0, 1) + except Exception as e: + warnings.warn(f'{e}') + try: + self.split_input_in_channel_weight_both_channel(1, 0) + except Exception as e: + warnings.warn(f'{e}') # RR = RS x SR - self.split_input_in_channel_weight_in_channel(0) - self.split_input_in_channel_weight_in_channel(1) + try: + self.split_input_in_channel_weight_in_channel(0) + except Exception as e: + warnings.warn(f'{e}') + try: + self.split_input_in_channel_weight_in_channel(1) + except Exception as e: + warnings.warn(f'{e}') # RS = RR x RS self.split_weight_out_channel(0) @@ -566,7 +594,12 @@ class ConvHandler(OperatorHandler): self.split_1d_parallel_on_input_batch(0, 1) # RR = RS01 x S01R - self.split_1d_parallel_on_in_channel(0, 1) + try: + self.split_1d_parallel_on_in_channel(0, 1) + except Exception as e: + warnings.warn(f'{e}') + + # print(f'strategies num is :{len(self.strategies_vector)}') return self.strategies_vector diff --git a/colossalai/auto_parallel/solver/dot_handler.py b/colossalai/auto_parallel/solver/op_handler/dot_handler.py similarity index 56% rename from colossalai/auto_parallel/solver/dot_handler.py rename to colossalai/auto_parallel/solver/op_handler/dot_handler.py index 3ce2fedbd..26791df46 100644 --- a/colossalai/auto_parallel/solver/dot_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/dot_handler.py @@ -44,11 +44,8 @@ class DotHandler(OperatorHandler): compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel = self.output_data.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] - memory_cost = numel * size_per_elem_bytes / sharding_size + toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight) # compute the communication cost # no all-reduce required for this case @@ -59,7 +56,7 @@ class DotHandler(OperatorHandler): output_sharding_spec=sharding_spec_for_ouput, compute_cost=compute_cost, communication_cost=communication_cost, - memory_cost=memory_cost, + memory_cost=toatl_memory_cost, resharding_costs=resharding_costs, input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) @@ -86,19 +83,16 @@ class DotHandler(OperatorHandler): compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel = self.output_data.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - sharding_size = self.device_mesh.shape[mesh_dim_0] - memory_cost = numel * size_per_elem_bytes / sharding_size + toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight) # compute the communication cost of this 
strategy - communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) + communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1) sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_ouput, compute_cost=compute_cost, communication_cost=communication_cost, - memory_cost=memory_cost, + memory_cost=toatl_memory_cost, resharding_costs=resharding_costs, input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) @@ -122,19 +116,16 @@ class DotHandler(OperatorHandler): compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel = self.output_data.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - sharding_size = self.device_mesh.shape[mesh_dim_0] - memory_cost = numel * size_per_elem_bytes / sharding_size + toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight) # compute the communication cost of this strategy - communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) + communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1) sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_ouput, compute_cost=compute_cost, communication_cost=communication_cost, - memory_cost=memory_cost, + memory_cost=toatl_memory_cost, resharding_costs=resharding_costs, input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) @@ -158,18 +149,16 @@ class DotHandler(OperatorHandler): compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel = self.output_data.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - memory_cost = numel * size_per_elem_bytes + toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight) # compute the communication cost of this strategy - communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim) + communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim) sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_ouput, compute_cost=compute_cost, communication_cost=communication_cost, - memory_cost=memory_cost, + memory_cost=toatl_memory_cost, resharding_costs=resharding_costs, input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) @@ -193,19 +182,115 @@ class DotHandler(OperatorHandler): compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) # compute the memory cost of this strategy - dtype = self.input_data.dtype - numel = self.output_data.numel() - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - sharding_size = self.device_mesh.shape[mesh_dim] - memory_cost = numel * size_per_elem_bytes / sharding_size + toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight) # compute the communication cost of this strategy - communication_cost = 
self.device_mesh.all_reduce_cost(memory_cost, mesh_dim) + communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim) sharding_strategies = ShardingStrategy(name, output_sharding_spec=sharding_spec_for_ouput, compute_cost=compute_cost, communication_cost=communication_cost, - memory_cost=memory_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + + dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight) + + # compute the communication cost of this strategy + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + + dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight) + + # compute the communication cost of this strategy + communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0) + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, + resharding_costs=resharding_costs, + 
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) + self.strategies_vector.append(sharding_strategies) + + def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + + dim_partition_dict_for_input = {} + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + + dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + + dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) + + # compute the computation cost of this strategy + compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) + + # compute the memory cost of this strategy + toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost( + dim_partition_dict_for_output, dim_partition_dict_for_weight) + + # compute the communication cost of this strategy + communication_cost = 0 + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=sharding_spec_for_ouput, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=toatl_memory_cost, resharding_costs=resharding_costs, input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) @@ -236,4 +321,14 @@ class DotHandler(OperatorHandler): # RS = RR x RS self.split_rhs_space_only(0) self.split_rhs_space_only(1) + + # S01R = S01R x RR + self.split_lhs_1st_dim_1d(0, 1) + + # RR = RS01 x S01R + self.split_lhs_2nd_dim_1d(0, 1) + + # RS01 = RR x RS01 + self.split_rhs_2nd_dim_1d(0, 1) + return self.strategies_vector diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/op_handler/operator_handler.py similarity index 62% rename from colossalai/auto_parallel/solver/operator_handler.py rename to colossalai/auto_parallel/solver/op_handler/operator_handler.py index 62289b5ce..3c0e98cf4 100644 --- a/colossalai/auto_parallel/solver/operator_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/operator_handler.py @@ -8,7 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from .sharding_strategy import StrategiesVector +from ..sharding_strategy import StrategiesVector __all__ = ['OperatorHandler'] @@ -70,6 +70,48 @@ class OperatorHandler(ABC): dim_partition_dict=dim_partition_dict) return sharding_spec + def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight): + ''' + Compute the memory cost per device with this specific strategy. + + Argument: + dim_partition_dict_for_output(List[int]): The key is the dimension of output to be sharded, + and the value of the key decribe which logical axis will be sharded in that dimension. + dim_partition_dict_for_weight(List[int]): The key is the dimension of weight to be sharded, + and the value of the key decribe which logical axis will be sharded in that dimension. 
+ Return: + total_memory_cost(float): total memory cost per device with this specific strategy + activation_cost(float): the memory cost of activation per device with this specific strategy + weight_memory_cost(float): the memory cost of weight per device with this specific strategy + ''' + # compute the size of one element with specific dtype + dtype = self.input_data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # compute the memory cost of activation + activation_numel = self.output_data.numel() + output_mesh_dims = [] + for sharding_dim, mesh_dims in dim_partition_dict_for_output.items(): + output_mesh_dims.extend(mesh_dims) + activation_sharding_size = 1 + for mesh_dim in output_mesh_dims: + activation_sharding_size *= self.device_mesh.shape[mesh_dim] + activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes + + # compute the memory cost of weight + weight_numel = self.weight.numel() + weight_sharding_size = 1 + weight_mesh_dims = [] + for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items(): + weight_mesh_dims.extend(mesh_dims) + for mesh_dim in weight_mesh_dims: + weight_sharding_size *= self.device_mesh.shape[mesh_dim] + weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes + + total_memory_cost = activation_memory_cost + weight_memory_cost + + return total_memory_cost, activation_memory_cost, weight_memory_cost + def _generate_resharding_costs(self, sharding_spec_for_input): ''' Compute the resharding costs with this specific strategy. @@ -85,17 +127,20 @@ class OperatorHandler(ABC): ''' # The resharding_cost of weight is counted due to sharing weight cases. resharding_costs = {} - for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input): + dtype = self.node._meta_data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input): resharding_costs[input_node] = [] for strategy in input_node.strategies_vector: input_sharding_spec = strategy.output_sharding_spec assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' # compute the resharding cost during forward phase _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency( - input_sharding_spec, target_spec) + input_sharding_spec, input_spec) # In backward phase, we should convert grad with target_spec into input_sharding_spec _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency( - target_spec, input_sharding_spec) - resharding_cost = resharding_cost_forward + resharding_cost_backward + input_spec, input_sharding_spec) + # we need multiply the size of elem dtype to get correct communication cost + resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes resharding_costs[input_node].append(resharding_cost) return resharding_costs diff --git a/colossalai/auto_parallel/solver/solver.py b/colossalai/auto_parallel/solver/solver.py new file mode 100644 index 000000000..63c35c2fc --- /dev/null +++ b/colossalai/auto_parallel/solver/solver.py @@ -0,0 +1,444 @@ +import warnings + +import time +import numpy as np +import multiprocessing +from torch.fx.node import Node +from torch.fx.graph import Graph +from . 
import GraphAnalyser +from colossalai.auto_parallel.solver.cost_graph import CostGraph +from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from typing import Dict +from .constants import INFINITY_COST +try: + import pulp + from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus +except: + warnings.warn(f'please install the pulp') + +__all___ = ['Solver'] + + +class Solver: + + def __init__(self, + graph: Graph, + strategies_constructor: StrategiesConstructor, + cost_graph: CostGraph, + graph_analyser: GraphAnalyser, + memory_budget: float = -1.0, + solution_numbers: int = 1, + memory_increasing_coefficient: float = 1.3): + ''' + Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. + + Argument: + graph: The computing graph to be optimized. + strategies_constructor: It will provide all the possible strategies for each node in the computing graph. + cost_graph: A graph data structure to simplify the edge cost graph. + graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. + memory_budget: Memory constraint for the solution. + solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. + memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. + ''' + self.graph = graph + self.strategies_constructor = strategies_constructor + self.cost_graph = cost_graph + self.graph_analyser = graph_analyser + self.nodes = list(self.graph.nodes) + self.leaf_strategies = self.strategies_constructor.leaf_strategies + self.strategy_map = self.strategies_constructor.strategy_map + self.memory_budget = memory_budget + self.solution_numbers = solution_numbers + if self.solution_numbers > 1: + self.memory_increasing_coefficient = memory_increasing_coefficient + else: + self.memory_increasing_coefficient = 1 + self.liveness_list = self.graph_analyser.liveness_analysis() + self.node_index_dict = self._generate_node_index_dict() + # The last solution vector of auto sharding. + self.last_s_val = None + # The last objective value of the best ILP solution. + self.last_objective = None + + def _generate_node_index_dict(self) -> Dict[Node, int]: + node_index_dict = {} + for index, strategies_vector in enumerate(self.leaf_strategies): + node_index_dict[strategies_vector.node] = index + return node_index_dict + + def _prepare_data_for_solver(self): + ''' + Extract information from components for solver. 
+ ''' + node_nums = len(self.leaf_strategies) + memory_budget = self.memory_budget + + # prepare strategies_len + strategies_len = [] + for node in self.nodes: + strategies_len.append(self.cost_graph.node_lens[node]) + strategies_len = np.array(strategies_len) + + # prepare following_nodes + following_nodes = self.cost_graph.following_dict + index_following_nodes = {} + for src, target in following_nodes.items(): + src_index = self.node_index_dict[src] + target_index = self.node_index_dict[target] + index_following_nodes[src_index] = target_index + following_nodes = index_following_nodes + for index in range(node_nums): + if index not in following_nodes: + following_nodes[index] = -1 + + # prepare edge_pairs and resharding costs + edge_pairs = [] + resharding_costs = [] + for pairs, edge_cost in self.cost_graph.edge_costs.items(): + src_node = pairs[0] + dst_node = pairs[1] + src_node_index = self.node_index_dict[src_node] + dst_node_index = self.node_index_dict[dst_node] + edge_pairs.append(src_node_index) + edge_pairs.append(dst_node_index) + + for i in range(strategies_len[src_node_index]): + for j in range(strategies_len[dst_node_index]): + resharding_costs.append(edge_cost[(i, j)]) + edge_pairs = np.array(edge_pairs) + resharding_costs = np.array(resharding_costs) + + # prepare liveness_set + liveness_set = self.liveness_list + + # omit alias_set now + alias_set = None + alias_convert_costs = None + + # prepare compute_costs, communication_costs and memory_costs + compute_costs = [] + communication_costs = [] + memory_costs = [] + extra_node_costs = self.cost_graph.extra_node_costs + for strategies_vector in self.leaf_strategies: + node = strategies_vector.node + for index, strategy in enumerate(strategies_vector): + compute_costs.append(strategy.compute_cost) + # node in extra_node_costs means it has some extra communication + # cost from node merging, so we need to add those extra communication + # cost into + if node in extra_node_costs: + origin_communication_cost = strategy.communication_cost + extra_node_cost = extra_node_costs[node][index] + communication_cost = origin_communication_cost + extra_node_cost + communication_costs.append(communication_cost) + else: + communication_costs.append(strategy.communication_cost) + # temporarily we just consider the forward memory cost + memory_cost = strategy.memory_cost + if isinstance(memory_cost, tuple): + memory_costs.append(memory_cost[0]) + else: + memory_costs.append(memory_cost) + compute_costs = np.array(compute_costs) + communication_costs = np.array(communication_costs) + memory_costs = np.array(memory_costs) + + # omit initial value for nodes + s_init_np = None + + return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np + + def _call_solver_serialized_args(self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None): + """ + Call the solver with serialized arguments. + """ + + tic = time.time() + + for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]: + assert isinstance(x, np.ndarray) + assert len(strategies_len) == node_nums, "strategies_len" + + def get_non_zero_index(binary_vector): + """ + Get the index of non-zero item in a vector. 
+ """ + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + + assert ct == 1 + return ret + + # 0. Unpack flatten numpy arrays + s_follow = following_nodes + + E = edge_pairs.reshape((-1, 2)) # noqa + r = [] + pt = 0 + edge_set = set() + for (i, j) in E: + prod_length = strategies_len[i] * strategies_len[j] + + if (i, j) in edge_set: + raise ValueError(f"Duplicated edges: {(i, j)}") + + edge_set.add((i, j)) + r.append(resharding_costs[pt:pt + prod_length]) + pt += prod_length + assert pt == len(resharding_costs) + + ###################### + # omit alias set now # + ###################### + + # A = alias_set.reshape((-1, 2)) # noqa + # for (i, j) in A: + # prod_length = strategies_len[i] * strategies_len[j] + # v.append(alias_convert_costs[pt:pt + prod_length]) + # pt += prod_length + # assert pt == len(alias_convert_costs) + + # L = [] # noqa + # pt = node_nums + # for i in range(node_nums): + # length = liveness_set[i] + # L.append(liveness_set[pt:pt + length]) + # pt += length + # assert pt == len(liveness_set) + v = [] + pt = 0 + + c = [] + d = [] + m = [] + pt = 0 + for i in range(node_nums): + length = strategies_len[i] + c.append(compute_costs[pt:pt + length]) + d.append(communication_costs[pt:pt + length]) + m.append(memory_costs[pt:pt + length]) + pt += length + assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" + assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" + assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}" + + # 1. Create variables + + ############################# + # create variables for node # + ############################# + s = [] + num_nodes = 0 + reverse_follow_backpatch = [] + for i in range(node_nums): + if s_follow[i] < 0: + if strategies_len[i] == 1: + s.append([1]) + else: + num_nodes += 1 + s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) + else: + if s_follow[i] < len(s): + s.append(s[s_follow[i]]) + else: + s.append(None) + reverse_follow_backpatch.append(i) + + for i in reverse_follow_backpatch: + s[i] = s[s_follow[i]] + + ############################# + # create variables for edge # + ############################# + e = [] + num_edges = 0 + for (idx, (i, j)) in enumerate(E): + if len(s[i]) == 1: + e.append(s[j]) + elif len(s[j]) == 1: + e.append(s[i]) + else: + num_edges += 1 + e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) + assert len(e[idx]) == len(r[idx]) + + # 2. Set initial value + ###################################### + # set a initial value for warm start # + ###################################### + if s_init_np is not None: + s_init = s_init_np.reshape((-1, 3)) + for (idx, value, fix) in s_init: + for i in range(len(s[idx])): + s[idx][i].setInitialValue(i == value) + if fix: + s[idx][i].fixValue() + + # 3. Objective + prob = LpProblem("myProblem", LpMinimize) + ################################################################### + # computing the node cost(computing cost and communication cost) # + ################################################################### + obj = 0 + for i in range(node_nums): + obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) + + ############################################# + # computing the edge cost(resharding cost) # + ############################################# + for i in range(len(E)): + obj += lpDot(e[i], r[i]) + + prob += obj + + # 4. Constraints + # (a). 
specified by `cat="Binary"` + + # (b) + ################################################# + # make sure each node only choose one strategy # + ################################################# + for i in range(node_nums): + if s_follow[i] < 0: + prob += lpSum(s[i]) == 1 + + # (c) + ################################################# + # compute memory consumption with liveness set # + ################################################# + if memory_budget > 0: + for liveness_stage in liveness_set: + mem = 0 + for live_variable in liveness_stage.unique_live_vars: + node_index = self.node_index_dict[live_variable.node] + mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) + prob += mem <= memory_budget + + # (d). specified by `cat="Binary"` + + for (idx, (i, j)) in enumerate(E): + if strategies_len[i] == 1 or strategies_len[j] == 1: + continue + + # (e) + prob += lpSum(e[idx]) == 1 + + # (f) + for row in range(len(s[i])): + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] + + # (g) + for col in range(len(s[j])): + R = len(s[i]) # noqa + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] + + # (h) + ###################### + # omit alias set now # + ###################### + + # alias_set = set() + # for (idx, (i, j)) in enumerate(A): + # R = len(s[i]) # noqa + # C = len(s[j]) # noqa + # if (i, j) in alias_set: + # raise ValueError(f"Duplicated edges: {(i, j)}") + + # alias_set.add((i, j)) + # alias_set.add((j, i)) + + # for row in range(len(s[i])): + # for col in range(len(s[j])): + # if v[idx][row * C + col] > 0.5: + # prob += s[i][row] + s[j][col] <= 1 + + verbose = True + + msg = verbose + time_limit = 600 + assert "COIN_CMD" in pulp.listSolvers( + onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + + solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) + # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) + prob.solve(solver) + + status = prob.status + objective = pulp.value(prob.objective) + objective = float(objective) if objective is not None else -1.0 + if verbose: + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" + f"Time: {time.time() - tic}") + print(f"#nodes: {num_nodes}, #edges: {num_edges}") + + if prob.status in [pulp.LpStatusInfeasible]: + raise RuntimeError("Cannot run the function under the given memory budget. " + "Please increase the memory budget.") + + # Get and check results + s_val = np.full((node_nums,), -1, dtype=np.int32) + for i in range(node_nums): + s_val[i] = get_non_zero_index(s[i]) + + e_val = np.full((len(E),), -1, dtype=np.int32) + for (idx, (i, j)) in enumerate(E): + e_val[idx] = get_non_zero_index(e[idx]) + i_spec_index = e_val[idx] // len(s[j]) + j_spec_index = e_val[idx] % len(s[j]) + assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" + assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" + if verbose and r[idx][e_val[idx]] > 0: + print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") + + self.last_s_val = s_val + self.last_objective = objective + + if objective > INFINITY_COST: + warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") + + return s_val, e_val, objective, status + + def call_solver_serialized_args(self): + """ + Call the solver with serialized arguments and handle python errors. Additionally, + we could give a serious of solutions with different memory budget. 
+ """ + if self.solution_numbers == 1: + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + + return ret + + origin_memory_budget = self.memory_budget + memory_budget_list = [ + origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers) + ] + ret_list = [] + for memory_budget in memory_budget_list: + self.memory_budget = memory_budget + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + ret_list.append(ret) + + return ret_list diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py index 98cc43976..546a30978 100644 --- a/colossalai/auto_parallel/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/solver/strategies_constructor.py @@ -1,7 +1,7 @@ from torch.fx import Graph, Node from colossalai.tensor.sharding_spec import ShardingSpec -from .sharding_strategy import ShardingStrategy, StrategiesVector -from .conv_handler import ConvHandler +from . import ShardingStrategy, StrategiesVector +from .op_handler import * from .constants import * from copy import deepcopy import math @@ -175,6 +175,58 @@ class StrategiesConstructor: input_shardings=[input_sharding_spec]) strategies_vector.append(sharding_strategy) + # BatchNormNd module + elif submod_type in BATCHNORM_MODULE_OP: + # bn1 call_module bn1 (conv1,) + # print(node, node.op, node.target, node.args) + # create sharding strategy for element-wise module + # input_node = strategies_vector.predecessor_nodes[0] + norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector, + self.shape_consistency_manager) + norm_handler.register_strategy() + # for strategy in norm_handler.strategies_vector: + # print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') + # assert False + + # MaxPool module + elif submod_type in POOL_MODULE_OP: + # create sharding strategy for element-wise module + assert len(strategies_vector.predecessor_nodes + ) == 1, f'Temporally, we just support single input element-wise op.' + input_node = strategies_vector.predecessor_nodes[0] + # For element-wise module, we keep the sharding spec of output node same as + # the input. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for element-wise module. + sharding_spec_checklist = [] + for strategy in input_node.strategies_vector: + # It looks a little bit confusing, the input of the processing node + # is the output of the input_node. + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, + ShardingSpec), f'The input node should NOT be a tuple of tensor.' 
+ if input_sharding_spec in sharding_spec_checklist: + continue + + sharding_spec_checklist.append(input_sharding_spec) + dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) + output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + + name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' + + # TODO: use meta_info_prop to profile memory cost and compute cost + compute_cost = node._meta_data.numel() + memory_cost = 0 + resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + strategies_vector.append(sharding_strategy) + # other module else: raise RuntimeError(f'{submod_type} module is NOT supported now.') @@ -203,7 +255,7 @@ class StrategiesConstructor: # TODO: integrate element-wise func and module together # create sharding strategy for element-wise function assert len(strategies_vector.predecessor_nodes - ) == 1, f'Temporally, we just support single input element-wise op.' + ) == 1, f'Temporally, we just support single input element-wise op, node name is {node}.' input_node = strategies_vector.predecessor_nodes[0] # For element-wise function, we keep the sharding spec of output node same as # the input. Therefore, the different strategies of input node with same @@ -349,6 +401,13 @@ class StrategiesConstructor: memory_cost = 0 resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, input_sharding_specs) + + # clear the resharding cost for the output node + # TODO: we may remove this in final version + if True: + for prev_node, resharding_cost_list in resharding_costs.items(): + resharding_costs[prev_node] = [0] * len(resharding_cost_list) + sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost, diff --git a/tests/test_auto_parallel/test_batch_norm_handler.py b/tests/test_auto_parallel/test_batch_norm_handler.py new file mode 100644 index 000000000..8174680b3 --- /dev/null +++ b/tests/test_auto_parallel/test_batch_norm_handler.py @@ -0,0 +1,119 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh + + +class BNModel(nn.Module): + + def __init__(self, c): + super().__init__() + self.bn = nn.BatchNorm2d(c) + + def forward(self, x): + x = x * 2 + x = self.bn(x) + return x + + +def test_bn_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 16, 64, 64)) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + model = BNModel(16) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = 
call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {}) + # return bn + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + # [x, mul, bn, output] + nodes = [node for node in gm.graph.nodes] + + # find the sharding strategies for the input node of the bn node + # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] + strategies_vector_for_input = StrategiesVector(nodes[1]) + sharding_option = (None, 0, 1) + for first_sharding_index in sharding_option: + for second_sharding_index in sharding_option: + if first_sharding_index is not None and second_sharding_index == first_sharding_index: + continue + if first_sharding_index is None: + first_dim_spec = _DimSpec([]) + else: + first_dim_spec = _DimSpec([first_sharding_index]) + + if second_sharding_index is None: + second_dim_spec = _DimSpec([]) + else: + second_dim_spec = _DimSpec([second_sharding_index]) + + replica_dim_spec = _DimSpec([]) + sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec] + sharding_spec = ShardingSpec(device_mesh=device_mesh, + entire_shape=entire_shape, + sharding_sequence=sharding_sequence) + strategy_name = str(sharding_spec.sharding_sequence) + sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec) + strategies_vector_for_input.append(sharding_strategy) + setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) + + # generate bn strategy + strategies_vector = StrategiesVector(node=nodes[2]) + bn_handler = BatchNormHandler(node=nodes[2], + device_mesh=device_mesh, + strategies_vector=strategies_vector, + shape_consistency_manager=shape_consistency_manager) + bn_handler.register_strategy() + # ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01', + # 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN'] + strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector] + + # RS = RS x S and strategies based on it, such as + # SS = RS x S + assert 'RS0 = RS0 x S0' in strategy_name_list + assert 'S1S0 = RS0 x S0' in strategy_name_list + assert 'RS1 = RS1 x S1' in strategy_name_list + assert 'S0S1 = RS1 x S1' in strategy_name_list + + # RR = RR x R and strategies based on it, such as + # SR = SR x R + assert 'RR = RR x R' in strategy_name_list + assert 'S0R = RR x R' in strategy_name_list + assert 'S1R = RR x R' in strategy_name_list + assert 'S01R = RR x R' in strategy_name_list + + # RS01 = RS01 x S01 + assert 'RS01 = RS01 x S01' in strategy_name_list + + # SR = SR x R WITH SYNC_BN + assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list + assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list + + # SS = SS x S WITH SYNC_BN + assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list + assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list + + # S01R = S01R x R WITH SYNC_BN + assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list + + +if __name__ == '__main__': + test_bn_handler() diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_conv_handler.py index 45eb87e3b..50b9cfc46 100644 --- 
a/tests/test_auto_parallel/test_conv_handler.py +++ b/tests/test_auto_parallel/test_conv_handler.py @@ -6,7 +6,7 @@ import pytest from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.conv_handler import ConvHandler +from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_cost_graph.py b/tests/test_auto_parallel/test_cost_graph.py index 30e3ece3b..7d8232867 100644 --- a/tests/test_auto_parallel/test_cost_graph.py +++ b/tests/test_auto_parallel/test_cost_graph.py @@ -6,8 +6,6 @@ import pytest from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor diff --git a/tests/test_auto_parallel/test_dot_handler.py b/tests/test_auto_parallel/test_dot_handler.py index 4cc41178d..df503646e 100644 --- a/tests/test_auto_parallel/test_dot_handler.py +++ b/tests/test_auto_parallel/test_dot_handler.py @@ -6,7 +6,7 @@ import pytest from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.dot_handler import DotHandler +from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_liveness_analysis.py b/tests/test_auto_parallel/test_liveness_analysis.py index 36039382f..f54441729 100644 --- a/tests/test_auto_parallel/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_liveness_analysis.py @@ -32,21 +32,21 @@ def test_liveness_analysis(): gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) graph_analyser = GraphAnalyser(gm) - liveness_dict = graph_analyser.liveness_analysis() - stage_count = len(liveness_dict) + liveness_list = graph_analyser.liveness_analysis() + stage_count = len(liveness_list) - # 8 stages including input and output - assert stage_count == 8 + # if a LiveStage is covered by another LiveStage, we just keep the larger one. 
+ assert stage_count == 1 # a variable named `relu` must exist # and this live var must have inplace = True - assert liveness_dict[5].all_live_vars.exists('relu') - relu_var = liveness_dict[5].all_live_vars.get('relu') + assert liveness_list[0].all_live_vars.exists('relu') + relu_var = liveness_list[0].all_live_vars.get('relu') assert relu_var.is_inplace # the unique vars must be fewer than the all vars since in-place ops exist - all_live_vars = liveness_dict[7].all_live_vars - unique_live_vars = liveness_dict[7].unique_live_vars + all_live_vars = liveness_list[0].all_live_vars + unique_live_vars = liveness_list[0].unique_live_vars assert len(unique_live_vars) + 1 == len(all_live_vars) diff --git a/tests/test_auto_parallel/test_solver.py b/tests/test_auto_parallel/test_solver.py new file mode 100644 index 000000000..56b1052a3 --- /dev/null +++ b/tests/test_auto_parallel/test_solver.py @@ -0,0 +1,78 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.solver.cost_graph import CostGraph +from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser +from copy import deepcopy +from colossalai.auto_parallel.solver import Solver + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3) + self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3) + self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3) + self.relu = nn.ReLU() + + def forward(self, x): + x = x * 2 + x = self.conv1(x) + x = self.conv2(x) + x = x / 2 + x = self.conv3(x) + x = self.relu(x) + return x + + +@pytest.mark.skip("for higher testing speed") +def test_solver(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 16, 64, 64)) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {}) + # %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {}) + # %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {}) + # %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {}) + # return relu + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + solver_options = {'fast_mode': True} + strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) + strategies_constructor.build_strategies_and_cost() + + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + cost_graph.simplify_graph() + graph_analyser = GraphAnalyser(gm) + solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + ret = solver.call_solver_serialized_args() + + # [ 0 0 
13 13 13 13 13 0] + strategies_combination_list = ret[0] + assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR' + + +if __name__ == '__main__': + test_solver() diff --git a/tests/test_auto_parallel/test_strategies_constructor.py b/tests/test_auto_parallel/test_strategies_constructor.py index 41a7e8bd7..37769d3c6 100644 --- a/tests/test_auto_parallel/test_strategies_constructor.py +++ b/tests/test_auto_parallel/test_strategies_constructor.py @@ -6,7 +6,7 @@ import pytest from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST +from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh
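
A minimal standalone sketch of the per-device memory cost formula that the new `OperatorHandler._generate_memory_cost` applies to both the activation and the weight (numel divided by the product of the mesh sizes along the sharded axes, times the element size). The helper name `memory_cost_per_device` and the example shapes, dtype and mesh below are illustrative assumptions, not code from the patch.

# Sketch only: mirrors the formula in `_generate_memory_cost`; names and values
# below are made up for illustration.
import math
import torch


def memory_cost_per_device(output_shape, weight_shape, dtype,
                           mesh_shape, output_partition, weight_partition):
    # A dim_partition dict maps a tensor dim -> list of device-mesh axes sharding it,
    # e.g. {0: [0, 1]} shards dim 0 over both axes of a 2D mesh.
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

    def sharded_bytes(shape, partition):
        sharding_size = 1
        for mesh_dims in partition.values():
            for mesh_dim in mesh_dims:
                sharding_size *= mesh_shape[mesh_dim]
        # each device holds 1 / sharding_size of the tensor
        return math.prod(shape) / sharding_size * size_per_elem_bytes

    activation_memory_cost = sharded_bytes(output_shape, output_partition)
    weight_memory_cost = sharded_bytes(weight_shape, weight_partition)
    total_memory_cost = activation_memory_cost + weight_memory_cost
    return total_memory_cost, activation_memory_cost, weight_memory_cost


# e.g. the 'S01R = S01R x RR' strategy on a (2, 2) mesh: the output is sharded
# along dim 0 over both mesh axes, while the weight stays replicated.
total, act, weight = memory_cost_per_device(
    output_shape=(8, 32), weight_shape=(16, 32), dtype=torch.float32,
    mesh_shape=(2, 2), output_partition={0: [0, 1]}, weight_partition={})
print(total, act, weight)  # activation bytes are split 4 ways, weight bytes counted in full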
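
For a single edge, the ILP that `Solver._call_solver_serialized_args` builds with pulp reduces to the toy problem below: one binary vector per node selects a strategy, an edge vector selects a strategy pair, and the row/column constraints tie the edge choice to the node choices. This is a sketch under assumptions: the costs are made-up numbers, and pulp's bundled CBC backend is used here instead of the `COIN_CMD` solver the patch configures.

# Sketch only: a two-node, one-edge version of the solver's ILP formulation.
import pulp
from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot

node_costs = [[10, 40], [30, 5]]   # compute + communication cost per strategy of each node
edge_costs = [0, 100, 100, 0]      # resharding cost for every (strategy_i, strategy_j) pair

# one binary strategy-selection vector per node, one per edge
s = [LpVariable.matrix(f"s[{i}]", (range(2),), cat="Binary") for i in range(2)]
e = LpVariable.matrix("e[0,1]", (range(4),), cat="Binary")

prob = LpProblem("toy_auto_sharding", LpMinimize)
# objective: node costs plus the resharding cost of the chosen strategy pair
prob += lpDot(s[0], node_costs[0]) + lpDot(s[1], node_costs[1]) + lpDot(e, edge_costs)

# each node picks exactly one strategy, and the edge variable must agree with both endpoints
prob += lpSum(s[0]) == 1
prob += lpSum(s[1]) == 1
prob += lpSum(e) == 1
for row in range(2):
    prob += lpSum(e[row * 2 + col] for col in range(2)) <= s[0][row]
for col in range(2):
    prob += lpSum(e[row * 2 + col] for row in range(2)) <= s[1][col]

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print([pulp.value(v) for v in s[0]], [pulp.value(v) for v in s[1]])
# both nodes pick strategy 0: 10 + 30 + 0 = 40 is the cheapest combination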