diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index 4b676d153..05e7615d8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -1,3 +1,4 @@
+from .addmm_handler import ADDMMFunctionHandler
 from .batch_norm_handler import BatchNormModuleHandler
 from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
@@ -18,5 +19,5 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
new file mode 100644
index 000000000..da0d199c5
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
@@ -0,0 +1,91 @@
+from typing import Dict, List, Union
+
+import torch
+
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
+
+from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
+from .node_handler import NodeHandler
+from .registry import operator_registry
+from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
+
+__all__ = ['ADDMMFunctionHandler']
+
+
+@operator_registry.register(torch.addmm)
+@operator_registry.register(torch.Tensor.addmm)
+class ADDMMFunctionHandler(NodeHandler):
+    """
+    This is a NodeHandler class which deals with the addmm operation in PyTorch, i.e. `torch.addmm` and
+    `torch.Tensor.addmm`, which compute `beta * input + alpha * (mat1 @ mat2)`. The bias `input` is broadcast
+    against the matmul result, so its logical shape is the output shape and is recovered in `post_process`.
+    """
+
+    def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType:
+        if isinstance(tensor, torch.nn.parameter.Parameter):
+            data_type = OperationDataType.PARAM
+        else:
+            data_type = OperationDataType.ARG
+        return data_type
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+
+        # input operand
+        input_data = self.node.args[1]._meta_data
+        physical_input_operand = OperationData(name=str(self.node.args[1]),
+                                               type=self._infer_op_data_type(input_data),
+                                               data=input_data)
+
+        # other operand
+        other_data = self.node.args[2]._meta_data
+        physical_other_operand = OperationData(name=str(self.node.args[2]),
+                                               type=self._infer_op_data_type(other_data),
+                                               data=other_data)
+        # bias operand, whose logical shape is the broadcast output shape
+        bias_logical_shape = self.node._meta_data.shape
+        bias_data = self.node.args[0]._meta_data
+        physical_bias_operand = OperationData(name=str(self.node.args[0]),
+                                              type=self._infer_op_data_type(bias_data),
+                                              data=bias_data,
+                                              logical_shape=bias_logical_shape)
+
+        # output
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {
+            "input": physical_input_operand,
+            "other": physical_other_operand,
+            "output": physical_output,
+            'bias': physical_bias_operand
+        }
+
+        return mapping
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(
+            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
+        return generators
+
+    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
+        # convert bias from its logical sharding spec to its physical sharding spec
+        op_data_mapping = self.get_operation_data_mapping()
+
+        bias_op_data = op_data_mapping['bias']
+        bias_physical_shape = bias_op_data.data.shape
+        bias_logical_shape = bias_op_data.logical_shape
+        bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
+        bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+            bias_sharding_spec, bias_logical_shape, bias_physical_shape)
+        strategy.sharding_specs[bias_op_data] = bias_sharding_spec
+
+        if len(removed_dims) > 0:
+            comm_action = comm_actions_for_oprands(node=self.node,
+                                                   removed_dims=removed_dims,
+                                                   op_data=bias_op_data,
+                                                   sharding_spec=bias_sharding_spec)
+            strategy.communication_actions[bias_op_data] = comm_action
+
+        return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index 5aa769981..942f6d31b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -140,7 +140,8 @@ class LinearModuleHandler(ModuleHandler):
     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
-        generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
+        generators.append(
+            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
         return generators

     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -199,7 +200,8 @@ class LinearFunctionHandler(NodeHandler):
     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
-        generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
+        generators.append(
+            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
         return generators

     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index ba3e03976..d3f9fd01d 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -363,7 +363,8 @@ class MatMulHandler(NodeHandler):
         elif self.matmul_type == MatMulType.MV:
             generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
         elif self.matmul_type == MatMulType.MM:
-            generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
+            generators.append(
+                LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
         return generators

     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index b12e9c08d..043bb8654 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -209,6 +209,10 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):

 class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):

+    def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'):
+        super().__init__(operation_data_mapping, device_mesh)
+        self.linear_projection_type = linear_projection_type
+
     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         # C = AB
         # C: [M, N], A: [M, P], B: [P, N]
@@ -272,14 +276,21 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             "other": {
                 -1: [mesh_dim_1]
             },
-            "bias": {
-                -1: [mesh_dim_1]
-            },
             "output": {
                 0: [mesh_dim_0],
                 -1: [mesh_dim_1]
             },
         }
+
+        # linear bias only has one dimension, but addmm bias has the same dimensions
+        # as the output logically.
+        if self.linear_projection_type == 'linear':
+            dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
+        elif self.linear_projection_type == 'addmm':
+            dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
+        else:
+            raise ValueError('Unsupported linear projection type')
+
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

         # set communication action
@@ -293,13 +304,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):

         if self.is_param('other'):
             other_comm_action = self.get_communication_action(
-                sharding_spec_mapping["output"],
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
         else:
             other_comm_action = self.get_communication_action(
-                sharding_spec_mapping["output"],
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.BEFORE,
@@ -308,7 +319,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         communication_action_mapping['input'] = input_comm_action
         communication_action_mapping['other'] = other_comm_action

-        if self.has_bias:
+        # we only add allreduce comm action for linear bias, because
+        # allreduce comm action for addmm bias will be considered in post processing
+        if self.has_bias and self.linear_projection_type == 'linear':
             if self.is_param('bias'):
                 bias_comm_action = self.get_communication_action(
                     sharding_spec_mapping["bias"],
@@ -347,6 +360,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
                 0: [mesh_dim_0]
             },
         }
+
+        # linear bias only has one dimension, but addmm bias has the same dimensions
+        # as the output logically.
+        if self.linear_projection_type == 'linear':
+            dim_partition_dict_mapping['bias'] = {}
+        elif self.linear_projection_type == 'addmm':
+            dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
+        else:
+            raise ValueError('Unsupported linear projection type')
+
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

         # get communication action mapping
@@ -360,13 +383,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):

         if self.is_param('other'):
             other_comm_action = self.get_communication_action(
-                sharding_spec_mapping["output"],
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
         else:
             other_comm_action = self.get_communication_action(
-                sharding_spec_mapping["output"],
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.BEFORE,
@@ -375,7 +398,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         communication_action_mapping['other'] = other_comm_action
         communication_action_mapping['output'] = output_comm_action

-        if self.has_bias:
+        # we only add allreduce comm action for linear bias, because
+        # allreduce comm action for addmm bias will be considered in post processing
+        if self.has_bias and self.linear_projection_type == 'linear':
             if self.is_param('bias'):
                 bias_comm_action = self.get_communication_action(
                     sharding_spec_mapping["bias"],
@@ -415,6 +440,10 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
                 -1: [mesh_dim_1]
             },
         }
+
+        # We don't have to do anything special for bias here, because
+        # the bias is already the same sharding spec as the output.
+
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

         # get communication actions
@@ -451,7 +480,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             "bias": {},
             "output": {},
         }
-
+        # We don't have to do anything special for bias here, because
+        # the bias is already the same sharding spec as the output.
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

         # get communication action
@@ -484,7 +514,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
                 -1: [mesh_dim]
             },
         }
-
+        # We don't have to do anything special for bias here, because
+        # the bias is already the same sharding spec as the output.
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

         # get communication actions
@@ -515,6 +546,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
                 0: [mesh_dim_0, mesh_dim_1]
             },
         }
+
+        # linear bias only has one dimension, but addmm bias has the same dimensions
+        # as the output logically.
+        if self.linear_projection_type == 'linear':
+            dim_partition_dict_mapping['bias'] = {}
+        elif self.linear_projection_type == 'addmm':
+            dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]}
+        else:
+            raise ValueError('Unsupported linear projection type')
+
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

         # get communication action
@@ -534,7 +575,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
                 arg_index=1)
         communication_action_mapping['other'] = other_comm_action

-        if self.has_bias:
+        # we only add allreduce comm action for linear bias, because
+        # allreduce comm action for addmm bias will be considered in post processing
+        if self.has_bias and self.linear_projection_type == 'linear':
             if self.is_param('bias'):
                 bias_comm_action = self.get_communication_action(
                     sharding_spec=sharding_spec_mapping['bias'],
@@ -568,6 +611,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             "bias": {},
             "output": {},
         }
+
+        # We don't have to do anything special for bias here, because
+        # the bias is already the same sharding spec as the output.
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

         # get communication action
@@ -600,6 +646,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
                 -1: [mesh_dim_0, mesh_dim_1]
             },
         }
+
+        # We don't have to do anything special for bias here, because
+        # the bias is already the same sharding spec as the output.
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

         # get communication action
@@ -626,10 +675,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         assert input_data.data.dim() > 0 and other_data.data.dim() == 2
         assert other_data.logical_shape[0] == input_data.logical_shape[-1]

-        # check if bias has the same a valid dim
-        has_bias = "bias" in self.op_data
-
-        if has_bias:
+        if self.has_bias:
             bias_data = self.op_data['bias']
             assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]

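Note: the bias branches added to LinearProjectionStrategyGenerator encode a simple fact of the underlying algebra: an F.linear-style bias is 1D and only follows the column sharding of the output, while an addmm-style bias logically has the full output shape and must follow both the row and the column sharding. A small sketch in plain PyTorch corresponding to the 'S0S1 = S0R x RS1' strategy (the two nested loops merely emulate a 2x2 device mesh; they are not ColossalAI APIs):

    import torch

    M, K, N = 4, 8, 16
    m1, m2 = torch.rand(M, K), torch.rand(K, N)
    bias_1d = torch.rand(N)        # linear-style bias
    bias_2d = torch.rand(M, N)     # addmm-style bias, logically the same shape as the output
    out_1d = m1 @ m2 + bias_1d
    out_2d = m1 @ m2 + bias_2d

    for i in range(2):        # mesh dim 0 shards the row dimension of m1 and of the output
        for j in range(2):    # mesh dim 1 shards the column dimension of m2 and of the output
            rows = slice(i * M // 2, (i + 1) * M // 2)
            cols = slice(j * N // 2, (j + 1) * N // 2)
            local = m1[rows] @ m2[:, cols]
            # 1D bias: only the column shard is needed, i.e. dim partition {-1: [mesh_dim_1]}
            assert torch.allclose(local + bias_1d[cols], out_1d[rows, cols])
            # 2D bias: row and column shards are needed, i.e. {0: [mesh_dim_0], -1: [mesh_dim_1]}
            assert torch.allclose(local + bias_2d[rows, cols], out_2d[rows, cols])

The allreduce for the addmm bias is deliberately not emitted by the generator; it is attached in ADDMMFunctionHandler.post_process once the broadcast dimensions are known.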
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
index aba254a80..042b92c58 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
@@ -72,11 +72,21 @@ def torch_linear(input, mat2, bias=None, *, out=None):
 def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
     if out is not None:
         raise ValueError("Don't support in-place abs for MetaTensor analysis")
-    batch_size, n, m = mat1.shape
+    _, n, _ = mat1.shape
     _, _, p = mat2.shape
     return torch.empty(n, p, device="meta")


+@meta_patched_function.register(torch.addmm)
+@meta_patched_function.register(torch.Tensor.addmm)
+def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
+    if out is not None:
+        raise ValueError("Don't support in-place addmm for MetaTensor analysis")
+    n, _ = mat1.shape
+    _, p = mat2.shape
+    return torch.empty(n, p, device="meta")
+
+
 @meta_patched_function.register(torch.var_mean)
 def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
     assert out is None, 'saving to out is not supported yet'
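Note: the torch_addmm meta patch above only has to reproduce the output shape of `torch.addmm` so that tracing with meta tensors can infer shapes without allocating storage. A short sketch of that shape logic in plain PyTorch, independent of the tracer:

    import torch

    mat1 = torch.empty(4, 8, device="meta")
    mat2 = torch.empty(8, 16, device="meta")
    n, _ = mat1.shape
    _, p = mat2.shape
    out = torch.empty(n, p, device="meta")    # what the patched function returns
    assert out.shape == torch.Size([4, 16])
    # matches eager execution on real tensors
    assert torch.addmm(torch.rand(16), torch.rand(4, 8), torch.rand(8, 16)).shape == out.shape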
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
new file mode 100644
index 000000000..e8d3a95a7
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
@@ -0,0 +1,153 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import ADDMMFunctionHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+    StrategiesVector,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
+
+
+class AddmmModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, m1, m2):
+        x = torch.addmm(input, m1, m2)
+        return x
+
+
+def check_linear_function_handler(rank, input_shape, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = AddmmModel().cuda()
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    input = torch.rand(input_shape).cuda()
+    m1 = torch.rand(4, 8).cuda()
+    m2 = torch.rand(8, 16).cuda()
+    # the index of addmm node in computation graph
+    node_index = 3
+    # strategy number of addmm node
+    strategy_number = 10
+    # construct input args
+    input_args = [input, m1, m2]
+    # construct meta arg names
+    meta_arg_names = ['input', 'm1', 'm2']
+    numerical_test_for_node_strategy(model=model,
+                                     device_mesh=device_mesh,
+                                     node_index=node_index,
+                                     strategy_number=strategy_number,
+                                     input_args=input_args,
+                                     meta_arg_names=meta_arg_names)
+
+    tracer = ColoTracer()
+    graph = tracer.trace(model,
+                         meta_args={
+                             "input": torch.rand(input_shape).to('meta'),
+                             'm1': torch.rand(4, 8).to('meta'),
+                             'm2': torch.rand(8, 16).to('meta'),
+                         })
+    gm = ColoGraphModule(model, graph)
+    # [input_1, m1, m2, addmm, output]
+    node_list = list(graph.nodes)
+    addmm_node = node_list[3]
+    strategies_vector = StrategiesVector(addmm_node)
+
+    # build handler
+    handler = ADDMMFunctionHandler(node=addmm_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+
+    handler.register_strategy(compute_resharding_cost=False)
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # check operation data mapping
+    mapping = handler.get_operation_data_mapping()
+
+    assert mapping['input'].name == "m1"
+    assert mapping['input'].data.shape == torch.Size([4, 8])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([4, 8])
+
+    assert mapping['other'].name == "m2"
+    assert mapping['other'].data.shape == torch.Size([8, 16])
+    assert mapping['other'].type == OperationDataType.ARG
+    assert mapping['other'].logical_shape == torch.Size([8, 16])
+
+    assert mapping['bias'].name == "input_1"
+    assert mapping['bias'].data.shape == torch.Size(input_shape)
+    assert mapping['bias'].type == OperationDataType.ARG
+    assert mapping['bias'].logical_shape == torch.Size([4, 16])
+
+    assert mapping['output'].name == "addmm"
+    assert mapping['output'].data.shape == torch.Size([4, 16])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    # one strategy will be converted to different physical sharding spec
+    assert len(strategy_name_list) > 8
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
+
+    for strategy in strategies_vector:
+        strategy: ShardingStrategy
+        input_sharding_spec = strategy.get_sharding_spec_by_name('m1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('m2')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('addmm')
+        bias_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+
+        # make sure the sharding matches across different operation data
+        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
+        assert weight_sharding_spec.sharding_sequence[0] == input_sharding_spec.sharding_sequence[1]
+        assert weight_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[1]
+        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+
+
+@parameterize('input_shape', [(16,), (4, 16)])
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_addmm_handler(input_shape):
+    world_size = 4
+    run_func_function = partial(check_linear_function_handler,
+                                input_shape=input_shape,
+                                world_size=world_size,
+                                port=free_port())
+    mp.spawn(run_func_function, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_addmm_handler()
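Note: the test indexes the addmm node with `node_list[3]` and looks up sharding specs by the names 'input_1', 'm1', 'm2' and 'addmm'. That ordering comes from the traced graph: three placeholders, then the `call_function` node for `torch.addmm`, then the output. A quick way to inspect this ordering with vanilla torch.fx (shown only as an illustration; the test itself uses ColoTracer with meta arguments):

    import torch
    from torch.fx import symbolic_trace

    class AddmmModel(torch.nn.Module):

        def forward(self, input, m1, m2):
            return torch.addmm(input, m1, m2)

    gm = symbolic_trace(AddmmModel())
    for node in gm.graph.nodes:
        print(node.op, node.name, node.target)
    # three 'placeholder' nodes, one 'call_function' node targeting torch.addmm, one 'output' node;
    # hence node_list[3] is the addmm node in the test above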