From eac1b793717e6f117b6da378196f9d271f1c86f8 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:33:01 +0800 Subject: [PATCH] [autoparallel] add bcast op handler (#1600) * [autoparallel] add bcast op handler * polish code * add more BCAST FUNC OP * polish code * add exception handler * polish --- colossalai/auto_parallel/solver/_utils.py | 24 +++ colossalai/auto_parallel/solver/constants.py | 11 +- .../solver/op_handler/__init__.py | 3 +- .../solver/op_handler/batch_norm_handler.py | 8 +- .../solver/op_handler/bcast_op_handler.py | 164 ++++++++++++++++++ .../solver/op_handler/conv_handler.py | 59 +++---- .../solver/op_handler/dot_handler.py | 10 ++ .../solver/op_handler/reshape_handler.py | 1 + .../auto_parallel/solver/sharding_strategy.py | 3 + .../solver/strategies_constructor.py | 17 +- .../test_auto_parallel/test_bcast_handler.py | 71 ++++++++ 11 files changed, 322 insertions(+), 49 deletions(-) create mode 100644 colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py create mode 100644 tests/test_auto_parallel/test_bcast_handler.py diff --git a/colossalai/auto_parallel/solver/_utils.py b/colossalai/auto_parallel/solver/_utils.py index c62455cbe..c9f85fb01 100644 --- a/colossalai/auto_parallel/solver/_utils.py +++ b/colossalai/auto_parallel/solver/_utils.py @@ -4,6 +4,10 @@ from torch.fx.node import Node from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.device.device_mesh import DeviceMesh from typing import Union, Dict, List, Optional +import warnings +from functools import reduce +import functools +import operator def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, @@ -29,6 +33,11 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic raise TypeError( f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' ) + for dim_index, sharding_index_list in dim_partition_dict.items(): + sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] + sharding_size = reduce(operator.mul, sharding_list, 1) + assert shape[ + dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) return sharding_spec @@ -74,3 +83,18 @@ def generate_resharding_costs(nodes: List[Node], resharding_cost = total_resharding_cost * size_per_elem_bytes resharding_costs[input_node].append(resharding_cost) return resharding_costs + + +def exception_handler(func): + """ + A function wrapper which executes the function with a specified seed. 
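+    It executes the wrapped function inside a try/except block and converts any exception raised
+    during strategy registration into a warning, so that a single invalid sharding strategy does
+    not abort the construction of the remaining strategies. The return value of the wrapped
+    function is discarded; the handlers mutate self.strategies_vector instead of returning.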
+ """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + func(*args, **kwargs) + except Exception as e: + warnings.warn(f'{e}') + + return wrapper diff --git a/colossalai/auto_parallel/solver/constants.py b/colossalai/auto_parallel/solver/constants.py index 3360f9425..487c555e5 100644 --- a/colossalai/auto_parallel/solver/constants.py +++ b/colossalai/auto_parallel/solver/constants.py @@ -3,16 +3,19 @@ import operator __all__ = [ 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', - 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP' + 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP' ] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] -# TODO: flatten should not be added into this group ELEMENTWISE_FUNC_OP = [ - torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, - operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout + torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu, + torch.nn.functional.dropout, torch.flatten ] RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape] +BCAST_FUNC_OP = [ + torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, + operator.mul, operator.floordiv, operator.truediv +] CONV_MODULE_OP = [ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d diff --git a/colossalai/auto_parallel/solver/op_handler/__init__.py b/colossalai/auto_parallel/solver/op_handler/__init__.py index 1f31ca45a..f51bfd739 100644 --- a/colossalai/auto_parallel/solver/op_handler/__init__.py +++ b/colossalai/auto_parallel/solver/op_handler/__init__.py @@ -3,5 +3,6 @@ from .dot_handler import DotHandler from .conv_handler import ConvHandler from .batch_norm_handler import BatchNormHandler from .reshape_handler import ReshapeHandler +from .bcast_op_handler import BcastOpHandler -__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler'] \ No newline at end of file +__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler'] \ No newline at end of file diff --git a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py index ae343b03a..207f66107 100644 --- a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py @@ -1,9 +1,9 @@ import operator from functools import reduce -import warnings import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler +from colossalai.auto_parallel.solver._utils import exception_handler __all__ = ['BatchNormHandler'] @@ -110,6 +110,7 @@ class BatchNormHandler(OperatorHandler): return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation + @exception_handler def split_input_channel(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' @@ -184,6 +185,7 @@ class BatchNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) + @exception_handler def split_input_channel_1d(self, mesh_dim_0, 
mesh_dim_1): name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' @@ -224,6 +226,7 @@ class BatchNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) + @exception_handler def non_split(self, mesh_dim_0, mesh_dim_1): name = f'RR = RR x R' @@ -319,6 +322,7 @@ class BatchNormHandler(OperatorHandler): new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) self.strategies_vector.append(new_sharding_strategy) + @exception_handler def split_input_batch(self, mesh_dim_0): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' @@ -359,6 +363,7 @@ class BatchNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) + @exception_handler def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' @@ -399,6 +404,7 @@ class BatchNormHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) + @exception_handler def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' diff --git a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py new file mode 100644 index 000000000..e3e697302 --- /dev/null +++ b/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py @@ -0,0 +1,164 @@ +import operator +from functools import reduce +import warnings +import torch +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from .operator_handler import OperatorHandler +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec +from copy import deepcopy +from typing import Dict, List +from colossalai.auto_parallel.solver._utils import exception_handler + +__all__ = ['BcastOpHandler'] + + +class BcastOpHandler(OperatorHandler): + """ + An OperatorHandler which deals with the sharding strategies of broadcast operators(such as operator.add). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert len(self.predecessor_node) == 2 + self.lhs_data = self.predecessor_node[0]._meta_data + self.rhs_data = self.predecessor_node[1]._meta_data + self.lhs = self.predecessor_node[0] + self.rhs = self.predecessor_node[1] + self.output_data = self.node._meta_data + + def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + shape = list(input_.shape) + + # padding the shape to the same length as output_data + while len(shape) < self.output_data.dim(): + shape.insert(0, 1) + shape = torch.Size(shape) + + # if the sharding happens on a size one dimension, we should record it as R. + processed_dim_partition_dict = deepcopy(dim_partition_dict) + for dim_index, _ in dim_partition_dict.items(): + if shape[dim_index] == 1: + processed_dim_partition_dict.pop(dim_index) + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=shape, + dim_partition_dict=processed_dim_partition_dict) + + return sharding_spec + + def _generate_resharding_costs(self, sharding_specs): + # The resharding_cost of weight is counted due to sharing weight cases. 
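+        # Both the forward and the backward conversion costs are accumulated for every candidate
+        # strategy of each predecessor node, then scaled by the element size to express the cost in bytes.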
+ dtype = self.node._meta_data.dtype + nodes = self.predecessor_node + resharding_costs = {} + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() + + # shape consistency manager is a singleton class + shape_consistency_manager = ShapeConsistencyManager() + + for input_node, input_spec in zip(nodes, sharding_specs): + resharding_costs[input_node] = [] + for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + # if the input shape is smaller than the target input, we will fill the input to the same length as target. + # Then, use the padded input sharding spec to compute the resharding cost. + if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape): + new_entire_shape = list(input_sharding_spec.entire_shape) + while len(new_entire_shape) < len(input_spec.entire_shape): + new_entire_shape.insert(0, 1) + new_entire_shape = torch.Size(new_entire_shape) + new_device_mesh = input_sharding_spec.device_mesh + new_dim_partition_dict = input_sharding_spec.dim_partition_dict + input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh, + entire_shape=new_entire_shape, + dim_partition_dict=new_dim_partition_dict) + + # compute the resharding cost during forward phase + _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency( + input_sharding_spec, input_spec) + + _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency( + input_spec, input_sharding_spec) + total_resharding_cost = resharding_cost_forward + resharding_cost_backward + + # we need multiply the size of elem dtype to get correct communication cost + resharding_cost = total_resharding_cost * size_per_elem_bytes + resharding_costs[input_node].append(resharding_cost) + + return resharding_costs + + def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): + # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity. 
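+        # For a d-dimensional output this proposes C(d, 2) * 2 two-dimensional shardings, d * 3
+        # one-dimensional shardings (S_i on mesh_dim_0, on mesh_dim_1, and on the flattened
+        # [mesh_dim_0, mesh_dim_1] axis) and one fully replicated spec, then removes duplicated
+        # sharding sequences; e.g. a 4-dimensional output yields at most 25 candidate specs.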
+ + output_sharding_spec_list = [] + output_dim_partition_list = [] + + # enumerate all the 2D sharding cases + for i in range(self.output_data.dim()): + for j in range(i + 1, self.output_data.dim()): + dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} + dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} + output_dim_partition_list.append(dim_partition_dict_0) + output_dim_partition_list.append(dim_partition_dict_1) + + # enumerate all the 1D sharding cases + for i in range(self.output_data.dim()): + dim_partition_dict_0 = {i: [mesh_dim_0]} + dim_partition_dict_1 = {i: [mesh_dim_1]} + dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} + output_dim_partition_list.append(dim_partition_dict_0) + output_dim_partition_list.append(dim_partition_dict_1) + output_dim_partition_list.append(dim_partition_dict_flatten) + + # add empty dict for fully replicated case + output_dim_partition_list.append({}) + check_duplicated_list = [] + for output_dim_partition_dict in output_dim_partition_list: + output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict) + sharding_seq = output_sharding_spec.sharding_sequence + if sharding_seq not in check_duplicated_list: + check_duplicated_list.append(sharding_seq) + output_sharding_spec_list.append(output_sharding_spec) + + return output_sharding_spec_list + + def _generate_compute_cost(self, *args, **kwargs): + return super()._generate_compute_cost(*args, **kwargs) + + @exception_handler + def _register_strategy(self, output_sharding_spec): + dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict + sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input) + sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input) + + name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}' + dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict + + # generate resharding cost for this strategy + resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs]) + + # compute the computation cost of this strategy + sharding_dims = [] + for mesh_dims in dim_partition_dict_for_output.values(): + for mesh_dim in mesh_dims: + sharding_dims.append(self.device_mesh.shape[mesh_dim]) + sharding_size = reduce(operator.mul, sharding_dims, 1) + memory_cost = self.output_data.numel() / sharding_size + compute_cost = memory_cost + communication_cost = 0 + + sharding_strategies = ShardingStrategy(name, + output_sharding_spec=output_sharding_spec, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs)) + + self.strategies_vector.append(sharding_strategies) + + def register_strategy(self) -> StrategiesVector: + output_sharding_specs = self._enumerate_all_possible_output(0, 1) + for output_sharding_spec in output_sharding_specs: + self._register_strategy(output_sharding_spec) diff --git a/colossalai/auto_parallel/solver/op_handler/conv_handler.py b/colossalai/auto_parallel/solver/op_handler/conv_handler.py index 8f062e7fe..028561755 100644 --- a/colossalai/auto_parallel/solver/op_handler/conv_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/conv_handler.py @@ -4,13 +4,14 @@ import warnings import torch from colossalai.auto_parallel.solver.sharding_strategy import 
ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler +from colossalai.auto_parallel.solver._utils import exception_handler __all__ = ['ConvHandler'] class ConvHandler(OperatorHandler): """ - A OperatorHandler which deals with the sharding strategies of Convolution. + An OperatorHandler which deals with the sharding strategies of Convolution. """ def __init__(self, *args, **kwargs): @@ -104,6 +105,7 @@ class ConvHandler(OperatorHandler): return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight + @exception_handler def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -151,6 +153,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_input_batch(self, mesh_dim_0): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' @@ -196,6 +199,7 @@ class ConvHandler(OperatorHandler): self.strategies_vector.append(sharding_strategies) + @exception_handler def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' @@ -241,6 +245,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' @@ -283,6 +288,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_input_in_channel_weight_in_channel(self, mesh_dim_0): name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' @@ -325,6 +331,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_weight_out_channel(self, mesh_dim_0): name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' @@ -367,6 +374,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def non_split(self): name = f'RR = RR x RR' @@ -407,6 +415,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' @@ -454,6 +463,7 @@ class ConvHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' @@ -554,48 +564,24 @@ class ConvHandler(OperatorHandler): RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]} ''' # SS = SR x RS - try: - self.split_input_batch_weight_out_channel(0, 1) - except Exception as e: - warnings.warn(f'{e}') - try: - 
self.split_input_batch_weight_out_channel(1, 0) - except Exception as e: - warnings.warn(f'{e}') + self.split_input_batch_weight_out_channel(0, 1) + self.split_input_batch_weight_out_channel(1, 0) # SR = SR x RR self.split_input_batch(0) self.split_input_batch(1) # SR = SS x SR - try: - self.split_input_both_dim_weight_in_channel(0, 1) - except Exception as e: - warnings.warn(f'{e}') - try: - self.split_input_both_dim_weight_in_channel(1, 0) - except Exception as e: - warnings.warn(f'{e}') + self.split_input_both_dim_weight_in_channel(0, 1) + self.split_input_both_dim_weight_in_channel(1, 0) # RS = RS x SS - try: - self.split_input_in_channel_weight_both_channel(0, 1) - except Exception as e: - warnings.warn(f'{e}') - try: - self.split_input_in_channel_weight_both_channel(1, 0) - except Exception as e: - warnings.warn(f'{e}') + self.split_input_in_channel_weight_both_channel(0, 1) + self.split_input_in_channel_weight_both_channel(1, 0) # RR = RS x SR - try: - self.split_input_in_channel_weight_in_channel(0) - except Exception as e: - warnings.warn(f'{e}') - try: - self.split_input_in_channel_weight_in_channel(1) - except Exception as e: - warnings.warn(f'{e}') + self.split_input_in_channel_weight_in_channel(0) + self.split_input_in_channel_weight_in_channel(1) # RS = RR x RS self.split_weight_out_channel(0) @@ -608,12 +594,7 @@ class ConvHandler(OperatorHandler): self.split_1d_parallel_on_input_batch(0, 1) # RR = RS01 x S01R - try: - self.split_1d_parallel_on_in_channel(0, 1) - except Exception as e: - warnings.warn(f'{e}') - - # print(f'strategies num is :{len(self.strategies_vector)}') + self.split_1d_parallel_on_in_channel(0, 1) return self.strategies_vector diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler.py b/colossalai/auto_parallel/solver/op_handler/dot_handler.py index f29772705..8243b6457 100644 --- a/colossalai/auto_parallel/solver/op_handler/dot_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/dot_handler.py @@ -6,10 +6,12 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, from .operator_handler import OperatorHandler from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP from functools import reduce +from colossalai.auto_parallel.solver._utils import exception_handler from enum import Enum from .strategy_generator import StrategyGenerator, IntermediateStrategy from typing import List + __all__ = ['DotHandler'] @@ -414,6 +416,7 @@ class DotHandler(OperatorHandler): compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 return compute_cost + @exception_handler def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): # handle case SS = SR x RS name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' @@ -452,6 +455,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # handle the case SR = SS x SR name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' @@ -488,6 +492,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' @@ -521,6 +526,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, 
sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def recompute_split_both_contract(self, mesh_dim): name = f'RR = RS{mesh_dim} x S{mesh_dim}R' @@ -554,6 +560,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_rhs_space_only(self, mesh_dim): name = f'RS{mesh_dim} = RR x RS{mesh_dim}' @@ -587,6 +594,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' @@ -620,6 +628,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' @@ -653,6 +662,7 @@ class DotHandler(OperatorHandler): input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) self.strategies_vector.append(sharding_strategies) + @exception_handler def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' diff --git a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py b/colossalai/auto_parallel/solver/op_handler/reshape_handler.py index cb9e4e6ea..19b99ad77 100644 --- a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/reshape_handler.py @@ -20,6 +20,7 @@ class ReshapeHandler(OperatorHandler): return super()._generate_compute_cost(*args, **kwargs) def register_strategy(self): + # TODO: add strategies with more output sharding specs other than only fully replicated. input_node = self.strategies_vector.predecessor_nodes[0] # For reshape function, to keep the computing correctness we keep the sharding # spec of input is fully replicated. In addition, we will keep the output in diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index 725ef6892..8e34f6e18 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -70,6 +70,9 @@ class StrategiesVector(list): # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. if self.node.target in ELEMENTWISE_FUNC_OP: merge_label = True + # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case. + if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1: + merge_label = True # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated. 
if self.node.target in RESHAPE_FUNC_OP: merge_label = True diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py index 3291790da..93d0c636e 100644 --- a/colossalai/auto_parallel/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/solver/strategies_constructor.py @@ -1,4 +1,5 @@ from torch.fx import Graph, Node +from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager @@ -52,6 +53,7 @@ class StrategiesConstructor: def build_strategies_and_cost(self): for node in self.nodes: strategies_vector = StrategiesVector(node) + input_nodes_len = len(strategies_vector.predecessor_nodes) # placeholder node if node.op == 'placeholder': # For placeholder nodes, if solver_options.fast is True, we just let them in @@ -165,6 +167,9 @@ class StrategiesConstructor: # MaxPool module elif submod_type in POOL_MODULE_OP: + # TODO: add sharding constraints on image dimension + # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension + # create sharding strategy for element-wise module assert len(strategies_vector.predecessor_nodes ) == 1, f'Temporally, we just support single input element-wise op.' @@ -230,7 +235,7 @@ class StrategiesConstructor: reshape_handler.register_strategy() # element-wise function - elif target in ELEMENTWISE_FUNC_OP: + elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1): # TODO: integrate element-wise func and module together # create sharding strategy for element-wise function assert len(strategies_vector.predecessor_nodes @@ -271,6 +276,11 @@ class StrategiesConstructor: input_shardings=[input_sharding_spec]) strategies_vector.append(sharding_strategy) + # bcast op + elif target in BCAST_FUNC_OP: + bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector) + bcast_op_handler.register_strategy() + # torch.var_mean elif target == torch.var_mean: dim = node.kwargs['dim'] @@ -383,9 +393,8 @@ class StrategiesConstructor: # clear the resharding cost for the output node # TODO: we may remove this in final version - if True: - for prev_node, resharding_cost_list in resharding_costs.items(): - resharding_costs[prev_node] = [0] * len(resharding_cost_list) + for prev_node, resharding_cost_list in resharding_costs.items(): + resharding_costs[prev_node] = [0] * len(resharding_cost_list) sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, diff --git a/tests/test_auto_parallel/test_bcast_handler.py b/tests/test_auto_parallel/test_bcast_handler.py new file mode 100644 index 000000000..023d3ac15 --- /dev/null +++ b/tests/test_auto_parallel/test_bcast_handler.py @@ -0,0 +1,71 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.device.device_mesh import DeviceMesh + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2) + + def forward(self, x): + x1 = self.conv1(x) 
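+        # x1 + 1 broadcasts a Python scalar, so the add node has a single tensor predecessor and
+        # falls back to the element-wise case (it is merged with conv1's strategies); x1 + x3 below
+        # is a genuine two-tensor broadcast add handled by BcastOpHandler.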
+ x2 = x1 + 1 + x1 = torch.reshape(x1, [1, -1, 64, 1]) + x3 = self.conv2(x1) + x3 = torch.reshape(x3, [4, 1, 64, -1]) + x = x1 + x3 + + return x + + +def test_conv_handler(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {}) + # %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {}) + # %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {}) + # %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {}) + # %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {}) + # %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {}) + # return add_1 + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + # [x, conv1, add, reshape, conv2, reshape_1, add_1, output] + nodes = [node for node in gm.graph.nodes] + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + + strategies_constructor.build_strategies_and_cost() + strategy_map = strategies_constructor.strategy_map + # check a tensor add with a scalar case + conv1_strategies = strategy_map[nodes[1]] + add_strategies = strategy_map[nodes[2]] + add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies] + for strategy in conv1_strategies: + assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list + + # check two tensors element-wise add case + add_1_strategies = strategy_map[nodes[6]] + assert len(add_1_strategies) == 25 + + +if __name__ == '__main__': + test_conv_handler()
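
The exception_handler decorator introduced in _utils.py is the main refactoring of this patch: the try/except blocks previously scattered through ConvHandler.register_strategy are removed in favour of decorating each strategy method, and the same decorator is applied in the batch-norm, dot and bcast handlers. Below is a minimal, self-contained sketch of the pattern (register_strategy and its arguments are illustrative, not part of the patch); it shows how an exception raised for an invalid sharding degrades to a warning instead of stopping the caller's loop.

import functools
import warnings


def exception_handler(func):
    """Run ``func`` and turn any exception it raises into a warning (the return value is discarded)."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            warnings.warn(f'{e}')

    return wrapper


@exception_handler
def register_strategy(dim_size, num_partitions):
    # mirrors the assertion added to generate_sharding_spec in _utils.py
    assert dim_size % num_partitions == 0, \
        f'we cannot shard a dimension of size {dim_size} into {num_partitions} partitions.'
    print(f'registered: {dim_size} split into {num_partitions} partitions')


register_strategy(64, 4)    # registers normally
register_strategy(7, 4)     # the assertion fails, but only a warning is emitted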
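BcastOpHandler._generate_sharding_spec left-pads the shape of the smaller operand with size-1 dimensions and then drops any requested sharding that lands on a size-1 dimension, recording it as R instead. A short standalone sketch of that shape rule follows (pad_to_output_rank is a made-up name for illustration, not an API in the patch).

import torch


def pad_to_output_rank(shape, output_ndim):
    """Left-pad a broadcast operand's shape with size-1 dimensions, numpy/PyTorch broadcasting style."""
    shape = list(shape)
    while len(shape) < output_ndim:
        shape.insert(0, 1)
    return torch.Size(shape)


# A [64] operand broadcast against a 4-dimensional output is treated as [1, 1, 1, 64]; any sharding
# requested on one of the padded size-1 dimensions would be dropped (kept replicated).
print(pad_to_output_rank(torch.Size([64]), 4))    # torch.Size([1, 1, 1, 64])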
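The assertion of 25 strategies at the end of test_bcast_handler.py follows directly from _enumerate_all_possible_output: for a d-dimensional output on a 2D device mesh it proposes C(d, 2) * 2 two-dimensional shardings, 3d one-dimensional shardings (including the flattened [mesh_dim_0, mesh_dim_1] axis) and one fully replicated spec, deduplicated afterwards by sharding sequence. A pure-Python sketch of the enumeration, with no colossalai imports and an illustrative function name:

def enumerate_output_partitions(ndim, mesh_dim_0=0, mesh_dim_1=1):
    """Enumerate candidate dim_partition_dicts for an ndim-dimensional output tensor."""
    partitions = []
    # 2D shardings: shard two different tensor dims on the two mesh dims, in both orders
    for i in range(ndim):
        for j in range(i + 1, ndim):
            partitions.append({i: [mesh_dim_0], j: [mesh_dim_1]})
            partitions.append({i: [mesh_dim_1], j: [mesh_dim_0]})
    # 1D shardings: one tensor dim on mesh_dim_0, on mesh_dim_1, or on the flattened 1D mesh
    for i in range(ndim):
        partitions.append({i: [mesh_dim_0]})
        partitions.append({i: [mesh_dim_1]})
        partitions.append({i: [mesh_dim_0, mesh_dim_1]})
    # fully replicated
    partitions.append({})
    return partitions


print(len(enumerate_output_partitions(4)))    # 25, matching the assertion on add_1_strategies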