diff --git a/colossalai/auto_parallel/solver/_utils.py b/colossalai/auto_parallel/solver/_utils.py index 54c9269a4..c62455cbe 100644 --- a/colossalai/auto_parallel/solver/_utils.py +++ b/colossalai/auto_parallel/solver/_utils.py @@ -1,8 +1,9 @@ +from colossalai.tensor.shape_consistency import ShapeConsistencyManager import torch from torch.fx.node import Node from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.device.device_mesh import DeviceMesh -from typing import Union, Dict, List +from typing import Union, Dict, List, Optional def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, @@ -31,3 +32,45 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) return sharding_spec + + +def generate_resharding_costs(nodes: List[Node], + sharding_specs: List[ShardingSpec], + count_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None): + ''' + Compute the resharding costs with this specific strategy. + + Argument: + nodes (List[Node]): a list of input nodes whose resharding costs will be computed. + sharding_specs (List[ShardingSpec]): a list of ShardingSpec objects, one for each input node. + count_backward (Optional[bool]): whether to include the resharding cost of the backward pass, default is True. False can be used for inference. + dtype (Optional[torch.dtype]): the data type used to derive the per-element size for the communication cost, default is None. + ''' + # The resharding_cost of weight is counted due to sharing weight cases. + resharding_costs = {} + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # shape consistency manager is a singleton class + shape_consistency_manager = ShapeConsistencyManager() + + for input_node, input_spec in zip(nodes, sharding_specs): + resharding_costs[input_node] = [] + for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensors.'
+ # compute the resharding cost during forward phase + _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec) + + if count_backward: + # in the backward phase, the gradient with target_spec should be converted back into input_sharding_spec + _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency( + input_spec, input_sharding_spec) + total_resharding_cost = resharding_cost_forward + resharding_cost_backward + else: + total_resharding_cost = resharding_cost_forward + + # multiply by the element size of the dtype to get the correct communication cost + resharding_cost = total_resharding_cost * size_per_elem_bytes + resharding_costs[input_node].append(resharding_cost) + return resharding_costs diff --git a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py index 8e6b1a7c0..ae343b03a 100644 --- a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py @@ -4,7 +4,6 @@ import warnings import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler -from .._utils import generate_sharding_spec __all__ = ['BatchNormHandler'] @@ -115,15 +114,13 @@ class BatchNormHandler(OperatorHandler): name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -156,8 +153,7 @@ class BatchNormHandler(OperatorHandler): new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]} - new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # the computation cost is all the same new_compute_cost = compute_cost @@ -192,15 +188,13 @@ class BatchNormHandler(OperatorHandler): name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight =
self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -234,15 +228,13 @@ class BatchNormHandler(OperatorHandler): name = f'RR = RR x R' dim_partition_dict_for_input = {} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -273,8 +265,7 @@ class BatchNormHandler(OperatorHandler): def _construct_batch_sharding_strategies(mesh_dim_list, new_name): dim_partition_dict_for_output = {0: mesh_dim_list} - new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # the computation cost is all the same new_compute_cost = compute_cost @@ -332,15 +323,13 @@ class BatchNormHandler(OperatorHandler): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -374,15 +363,13 @@ class BatchNormHandler(OperatorHandler): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, 
dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -416,15 +403,13 @@ class BatchNormHandler(OperatorHandler): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -459,7 +444,7 @@ class BatchNormHandler(OperatorHandler): Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. 
Example: - norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector, - self.shape_consistency_manager) + norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector) norm_handler.register_strategy() for strategy in norm_handler.strategies_vector: diff --git a/colossalai/auto_parallel/solver/op_handler/conv_handler.py b/colossalai/auto_parallel/solver/op_handler/conv_handler.py index d41817652..8f062e7fe 100644 --- a/colossalai/auto_parallel/solver/op_handler/conv_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/conv_handler.py @@ -4,7 +4,6 @@ import warnings import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler -from .._utils import generate_sharding_spec __all__ = ['ConvHandler'] @@ -109,15 +108,13 @@ class ConvHandler(OperatorHandler): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {1: [mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -158,15 +155,13 @@ class ConvHandler(OperatorHandler): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -205,15 +200,13 @@ class ConvHandler(OperatorHandler): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh,
dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -252,15 +245,13 @@ class ConvHandler(OperatorHandler): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_1]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -296,15 +287,13 @@ class ConvHandler(OperatorHandler): name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -340,15 +329,13 @@ class ConvHandler(OperatorHandler): name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' dim_partition_dict_for_input = {} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {1: [mesh_dim_0]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = 
self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -384,15 +371,13 @@ class ConvHandler(OperatorHandler): name = f'RR = RR x RR' dim_partition_dict_for_input = {} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -426,15 +411,13 @@ class ConvHandler(OperatorHandler): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -475,15 +458,13 @@ class ConvHandler(OperatorHandler): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler.py 
b/colossalai/auto_parallel/solver/op_handler/dot_handler.py index 9fa99f748..26791df46 100644 --- a/colossalai/auto_parallel/solver/op_handler/dot_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/dot_handler.py @@ -3,7 +3,6 @@ import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler from functools import reduce -from .._utils import generate_sharding_spec __all__ = ['DotHandler'] @@ -29,16 +28,14 @@ class DotHandler(OperatorHandler): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) # linear layer weight is transposed during init dim_partition_dict_for_weight = {0: [mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -69,17 +66,15 @@ class DotHandler(OperatorHandler): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) # since weight of the linear layer is transposed # the actual dim to be sharded is 1 dim_partition_dict_for_weight = {1: [mesh_dim_0]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -106,15 +101,13 @@ class DotHandler(OperatorHandler): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_1]} - sharding_spec_for_ouput = 
generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -141,15 +134,13 @@ class DotHandler(OperatorHandler): name = f'RR = RS{mesh_dim} x S{mesh_dim}R' dim_partition_dict_for_input = {1: [mesh_dim]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {1: [mesh_dim]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -176,15 +167,13 @@ class DotHandler(OperatorHandler): name = f'RS{mesh_dim} = RR x RS{mesh_dim}' dim_partition_dict_for_input = {} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim]} - sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -211,15 +200,13 @@ class DotHandler(OperatorHandler): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -246,15 +233,13 @@ class DotHandler(OperatorHandler): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' 
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -281,15 +266,13 @@ class DotHandler(OperatorHandler): name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' dim_partition_dict_for_input = {} - sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, - dim_partition_dict_for_input) + sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) + sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, - dim_partition_dict_for_output) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) diff --git a/colossalai/auto_parallel/solver/op_handler/operator_handler.py b/colossalai/auto_parallel/solver/op_handler/operator_handler.py index dc397514e..44b4d8217 100644 --- a/colossalai/auto_parallel/solver/op_handler/operator_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/operator_handler.py @@ -7,6 +7,7 @@ from typing import Dict, List from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec +from .._utils import generate_resharding_costs, generate_sharding_spec from ..sharding_strategy import StrategiesVector @@ -17,24 +18,24 @@ class OperatorHandler(ABC): ''' The OperatorHandler is an abstract class used to generate every possible strategies for an operator node. - Argument: - input_node(Node): the input node in node argument list. - input_index(int): the index of input node in the node argument list. - weight(torch.Tensor): Weight of the node. - output_node(Node): Output_node is the output of the node. - device_mesh(DeviceMesh): A logical view of a physical mesh. - strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. - shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs. + Args: + node (Node): the input node in node argument list. + device_mesh (DeviceMesh): A logical view of a physical mesh. 
+ strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. + handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference. ''' - def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, - shape_consistency_manager: ShapeConsistencyManager): + def __init__(self, + node: Node, + device_mesh: DeviceMesh, + strategies_vector: StrategiesVector, + handle_backward: bool = True): self.node = node self.predecessor_node = list(node._input_nodes.keys()) self.successor_node = list(node.users.keys()) self.device_mesh = device_mesh self.strategies_vector = strategies_vector - self.shape_consistency_manager = shape_consistency_manager + self.handle_backward = handle_backward # find the module and its parameters associated with this node # this can be used to compute the compute/communication/sharding cost @@ -102,35 +103,23 @@ class OperatorHandler(ABC): return total_memory_cost, activation_memory_cost, weight_memory_cost - def _generate_resharding_costs(self, sharding_spec_for_input): - ''' - Compute the resharding costs with this specific strategy. - - Note: The resharding_cost of weight is NOT counted. - - Argument: - resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary. - Resharding_cost[i][j] means the cost of i-th argument in the output node argument list - with j-th strategy in its strategies_vector transforms to sharding spec wanted in this - strategy. - sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node. - ''' + def _generate_resharding_costs(self, sharding_specs): # The resharding_cost of weight is counted due to sharing weight cases. - resharding_costs = {} dtype = self.node._meta_data.dtype - size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input): - resharding_costs[input_node] = [] - for strategy in input_node.strategies_vector: - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' - # compute the resharding cost during forward phase - _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency( - input_sharding_spec, input_spec) - # In backward phase, we should convert grad with target_spec into input_sharding_spec - _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency( - input_spec, input_sharding_spec) - # we need multiply the size of elem dtype to get correct communication cost - resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes - resharding_costs[input_node].append(resharding_cost) - return resharding_costs + nodes = self.predecessor_node + return generate_resharding_costs(nodes=nodes, + sharding_specs=sharding_specs, + count_backward=self.handle_backward, + dtype=dtype) + + def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + return generate_sharding_spec(input_=input_, + device_mesh=self.device_mesh, + dim_partition_dict=dim_partition_dict) + + @abstractmethod + def _generate_compute_cost(self, *args, **kwargs): + """ + Compute the flops involved in the node. 
+ """ + pass diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py index 6343e201c..101be664e 100644 --- a/colossalai/auto_parallel/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/solver/strategies_constructor.py @@ -11,7 +11,7 @@ import math import torch import operator from typing import Dict, List -from ._utils import generate_sharding_spec +from ._utils import generate_sharding_spec, generate_resharding_costs class StrategiesConstructor: @@ -21,12 +21,10 @@ class StrategiesConstructor: Args: graph (Graph): a Graph object used for analysis and strategy generation. device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. - shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent. solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. """ - def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager, - solver_options: SolverOptions): + def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): self.graph = graph assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' self.root_module = self.graph.owning_module @@ -34,27 +32,8 @@ class StrategiesConstructor: self.device_mesh = device_mesh self.leaf_strategies = [] self.strategy_map = {} - self.shape_consistency_manager = shape_consistency_manager self.solver_options = solver_options - def _generate_resharding_costs(self, input_nodes, target_sharding_specs): - ''' - Compute the resharding costs with this specific strategy. - - Argument: - sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node. - ''' - resharding_costs = {} - for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs): - resharding_costs[input_node] = [] - for strategy in input_node.strategies_vector: - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' - _, _, resharding_cost = self.shape_consistency_manager.shape_consistency( - input_sharding_spec, target_sharding_spec) - resharding_costs[input_node].append(resharding_cost) - return resharding_costs - def remove_duplicated_strategy(self, strategies_vector): ''' In build_strategies_and_cost method, we may produce some duplicated strategies. 
@@ -120,14 +99,13 @@ class StrategiesConstructor: # conv module if submod_type in CONV_MODULE_OP: # use ConvHandler to create sharding strategies for conv module node - conv_handler = ConvHandler(node, self.device_mesh, strategies_vector, - self.shape_consistency_manager) + conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) conv_handler.register_strategy() # linear module elif submod_type in LINEAR_MODULE_OP: # use DotHandler to create sharding strategies for linear module node - dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager) + dot_handler = DotHandler(node, self.device_mesh, strategies_vector) dot_handler.register_strategy() # element-wise module @@ -158,8 +136,8 @@ class StrategiesConstructor: # TODO: use meta_info_prop to profile memory cost and compute cost compute_cost = node._meta_data.numel() memory_cost = 0 - resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) # to prevent the resharding happening, set their resharding cost to inf. resharding_costs[input_node] = [ @@ -214,8 +192,8 @@ class StrategiesConstructor: # TODO: use meta_info_prop to profile memory cost and compute cost compute_cost = node._meta_data.numel() memory_cost = 0 - resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) sharding_strategy = ShardingStrategy(name, output_sharding_spec, @@ -275,8 +253,8 @@ class StrategiesConstructor: compute_cost = node._meta_data.numel() memory_cost = 0 - resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) # to prevent the resharding happening, set their resharding cost to inf. resharding_costs[input_node] = [ @@ -317,8 +295,8 @@ class StrategiesConstructor: # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. compute_cost = 0 memory_cost = 0 - resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, - [new_input_sharding_spec]) + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [new_input_sharding_spec]) sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), compute_cost=compute_cost, memory_cost=memory_cost, @@ -335,8 +313,8 @@ class StrategiesConstructor: # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. compute_cost = 0 memory_cost = 0 - resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), compute_cost=compute_cost, memory_cost=memory_cost, @@ -360,8 +338,8 @@ class StrategiesConstructor: # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. 
compute_cost = 0 memory_cost = 0 - resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) # to prevent the resharding happening, set their resharding cost to inf. resharding_costs[input_tensor_node] = [ cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node] @@ -397,8 +375,8 @@ class StrategiesConstructor: output_sharding_spec = input_sharding_specs # TODO: use meta_info_prop to profile memory cost memory_cost = 0 - resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, - input_sharding_specs) + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + input_sharding_specs) # clear the resharding cost for the output node # TODO: we may remove this in final version diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py index f4d3276e2..8ca335119 100644 --- a/colossalai/context/singleton_meta.py +++ b/colossalai/context/singleton_meta.py @@ -15,4 +15,7 @@ class SingletonMeta(type): if cls not in cls._instances: instance = super().__call__(*args, **kwargs) cls._instances[cls] = instance + else: + assert len(args) == 0 and len( + kwargs) == 0, f'{cls.__name__} is a singleton class and an instance has already been created.' return cls._instances[cls] diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index d411918e1..9da935cd9 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -1,15 +1,22 @@ import torch +from dataclasses import dataclass from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator from enum import Enum from copy import deepcopy from typing import Dict, List, Optional, Tuple, Union +from colossalai.context.singleton_meta import SingletonMeta import torch.distributed as dist import math from functools import reduce import operator from torch.distributed import ReduceOp +__all__ = [ 'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions', + 'set_shape_consistency_options' +] + class CollectiveCommPattern(Enum): ALLGATHER = 'all_gather' @@ -152,14 +159,40 @@ class CommSpec: tensor.data = tensor -class ShapeConsistencyManager: +@dataclass +class ShapeConsistencyOptions: + """ + ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency. + """ + # TODO: shape consistency option is not implemented yet + pass + + +def set_shape_consistency_options(options: ShapeConsistencyOptions): + """ + Configure the shape consistency manager via function call.
+ """ + manager = ShapeConsistencyManager() + manager.options = options + - def __init__(self, consistency_option=None): - self.consistency_option = consistency_option +class ShapeConsistencyManager(metaclass=SingletonMeta): + + def __init__(self): + self._options = None self.total_communication_cost = 0 self.total_transform_steps = 0 self.cached_spec_pairs_transform_path = {} + @property + def options(self): + return self._options + + @options.setter + def options(self, options_: ShapeConsistencyOptions): + assert isinstance(options_, ShapeConsistencyOptions) + self._options = options_ + def get_all_all_gather_spec(self, source_spec, orig_cost): ''' Get all valid sharding specs from source_spec with single all-gather operation, and diff --git a/tests/test_auto_parallel/test_batch_norm_handler.py b/tests/test_auto_parallel/test_batch_norm_handler.py index 8174680b3..4869ecbfa 100644 --- a/tests/test_auto_parallel/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_batch_norm_handler.py @@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh @@ -31,7 +30,6 @@ def test_bn_handler(): # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) entire_shape = torch.Size((4, 16, 64, 64)) - shape_consistency_manager = ShapeConsistencyManager() tracer = ColoTracer() model = BNModel(16) @@ -77,10 +75,11 @@ def test_bn_handler(): # generate bn strategy strategies_vector = StrategiesVector(node=nodes[2]) - bn_handler = BatchNormHandler(node=nodes[2], - device_mesh=device_mesh, - strategies_vector=strategies_vector, - shape_consistency_manager=shape_consistency_manager) + bn_handler = BatchNormHandler( + node=nodes[2], + device_mesh=device_mesh, + strategies_vector=strategies_vector, + ) bn_handler.register_strategy() # ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01', # 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN'] diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_conv_handler.py index 50b9cfc46..c66e85883 100644 --- a/tests/test_auto_parallel/test_conv_handler.py +++ b/tests/test_auto_parallel/test_conv_handler.py @@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh @@ -31,7 +30,6 @@ def test_conv_handler(): # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) entire_shape = torch.Size((4, 16, 64, 64)) - shape_consistency_manager = ShapeConsistencyManager() tracer = ColoTracer() model = ConvModel(16, 32) @@ -77,10 +75,11 @@ def test_conv_handler(): # generate conv strategy strategies_vector = StrategiesVector(node=nodes[2]) - conv_handler = 
ConvHandler(node=nodes[2], - device_mesh=device_mesh, - strategies_vector=strategies_vector, - shape_consistency_manager=shape_consistency_manager) + conv_handler = ConvHandler( + node=nodes[2], + device_mesh=device_mesh, + strategies_vector=strategies_vector, + ) conv_handler.register_strategy() # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'] strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector] diff --git a/tests/test_auto_parallel/test_cost_graph.py b/tests/test_auto_parallel/test_cost_graph.py index 1bee5e35f..5b5bcb5d9 100644 --- a/tests/test_auto_parallel/test_cost_graph.py +++ b/tests/test_auto_parallel/test_cost_graph.py @@ -4,10 +4,7 @@ from torch.fx import GraphModule import torch.nn as nn import pytest -from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor from colossalai.auto_parallel.solver.cost_graph import CostGraph @@ -37,7 +34,6 @@ def test_cost_graph(): # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) entire_shape = torch.Size((4, 16, 64, 64)) - shape_consistency_manager = ShapeConsistencyManager() tracer = ColoTracer() model = ConvModel(16, 32) @@ -55,7 +51,7 @@ def test_cost_graph(): gm.recompile() solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() # (x, mul):{(0, 0): 0} diff --git a/tests/test_auto_parallel/test_dot_handler.py b/tests/test_auto_parallel/test_dot_handler.py index df503646e..856e462de 100644 --- a/tests/test_auto_parallel/test_dot_handler.py +++ b/tests/test_auto_parallel/test_dot_handler.py @@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh @@ -31,7 +30,6 @@ def test_dot_handler(): # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) entire_shape = torch.Size((4, 8)) - shape_consistency_manager = ShapeConsistencyManager() tracer = ColoTracer() model = LinearModel(8, 16) @@ -76,10 +74,11 @@ def test_dot_handler(): # generate dot strategy strategies_vector = StrategiesVector(node=nodes[2]) - dot_handler = DotHandler(node=nodes[2], - device_mesh=device_mesh, - strategies_vector=strategies_vector, - shape_consistency_manager=shape_consistency_manager) + dot_handler = DotHandler( + node=nodes[2], + device_mesh=device_mesh, + strategies_vector=strategies_vector, + ) strategies_vector = dot_handler.register_strategy() # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 
'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR'] diff --git a/tests/test_auto_parallel/test_strategies_constructor.py b/tests/test_auto_parallel/test_strategies_constructor.py index 955bf43dd..ce263829a 100644 --- a/tests/test_auto_parallel/test_strategies_constructor.py +++ b/tests/test_auto_parallel/test_strategies_constructor.py @@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor from colossalai.auto_parallel.solver.options import SolverOptions @@ -34,7 +33,6 @@ def test_strategies_constructor(): # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) entire_shape = torch.Size((4, 16, 64, 64)) - shape_consistency_manager = ShapeConsistencyManager() tracer = ColoTracer() model = ConvModel(16, 32) @@ -49,7 +47,7 @@ def test_strategies_constructor(): gm.recompile() solver_options = SolverOptions(fast=True) - strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) assert strategies_constructor.leaf_strategies == [] assert strategies_constructor.strategy_map == {}
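Taken together, these changes make ShapeConsistencyManager a process-wide singleton, so it no longer has to be threaded through StrategiesConstructor or the individual op handlers. A minimal usage sketch of the updated entry points follows; the toy module, mesh shape, and meta-tensor tracing setup are illustrative assumptions borrowed from the updated tests, not part of this patch.

import torch
from torch.fx import GraphModule

from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager


class ConvModel(torch.nn.Module):
    """Toy module used only to illustrate the new entry points."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = torch.nn.Conv2d(c_in, c_out, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

model = ConvModel(16, 32)
tracer = ColoTracer()
# assumed tracing setup: trace with meta tensors so that node._meta_data is populated
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 16, 64, 64).to('meta')})
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# ShapeConsistencyManager is now a singleton: every call site gets the same instance,
# so it no longer needs to be passed to StrategiesConstructor or the op handlers
assert ShapeConsistencyManager() is ShapeConsistencyManager()

solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()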