From 536560ccc088870f377aaf454b22eed35e942b62 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 14 Dec 2022 16:09:53 +0800 Subject: [PATCH] [autoparallel] implement softmax handler (#2132) --- .../tensor_shard/node_handler/__init__.py | 3 +- .../node_handler/softmax_handler.py | 55 ++++++ .../node_handler/strategy/__init__.py | 3 +- .../strategy/softmax_generator.py | 104 ++++++++++ .../node_handler/unary_elementwise_handler.py | 2 - .../test_node_handler/test_softmax_handler.py | 186 ++++++++++++++++++ 6 files changed, 349 insertions(+), 4 deletions(-) create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 014f3b50b..b4ba3b7cd 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -15,6 +15,7 @@ from .output_handler import OuputHandler from .placeholder_handler import PlacehodlerHandler from .registry import operator_registry from .reshape_handler import ReshapeHandler +from .softmax_handler import SoftmaxHandler from .sum_handler import SumHandler from .tensor_constructor_handler import TensorConstructorHandler from .unary_elementwise_handler import UnaryElementwiseHandler @@ -26,5 +27,5 @@ __all__ = [ 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', - 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler' + 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py new file mode 100644 index 000000000..743a1f90e --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py @@ -0,0 +1,55 @@ +from typing import Dict, List + +import torch + +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import SoftmaxGenerator, StrategyGenerator + +__all__ = ['SoftmaxHandler'] + + +@operator_registry.register(torch.nn.Softmax) +@operator_registry.register(torch.nn.functional.softmax) +class SoftmaxHandler(NodeHandler): + """ + A SoftmaxHandler which deals with the sharding strategies for + torch.nn.Softmax or torch.nn.functional.softmax. 
+ """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + softmax_dim = self.node.kwargs['dim'] + + num_dims = self.node.args[0]._meta_data.dim() + # recover negative value to positive + if softmax_dim < 0: + softmax_dim += num_dims + + physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "softmax_dim": physical_dim_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py index f52b3e1d8..8d25475f9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -15,6 +15,7 @@ from .normal_pooling_generator import NormalPoolStrategyGenerator from .output_generator import OutputGenerator from .placeholder_generator import PlaceholderGenerator from .reshape_generator import ReshapeGenerator +from .softmax_generator import SoftmaxGenerator from .strategy_generator import StrategyGenerator from .sum_generator import SumGenerator from .tensor_constructor_generator import TensorConstructorGenerator @@ -27,5 +28,5 @@ __all__ = [ 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', - 'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator' + 'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py new file mode 100644 index 000000000..a1ebadd04 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py @@ -0,0 +1,104 @@ +import copy +import operator +from functools import reduce +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern + +__all__ = ['SoftmaxGenerator'] + + 
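The class defined below follows each sharding strategy of the predecessor node and removes any sharding on the softmax dimension, since softmax normalizes across that dimension and it therefore has to be replicated. As a rough standalone sketch of that single adjustment (plain dicts in place of ColossalAI's sharding specs; drop_softmax_dim is an illustrative name, not part of the library):

    import copy

    def drop_softmax_dim(dim_partition_dict, softmax_dim):
        # dim_partition_dict maps tensor dim -> list of mesh axes, e.g. {0: [0], 3: [1]}
        # means tensor dim 0 is sharded over mesh axis 0 and tensor dim 3 over mesh axis 1.
        new_dict = copy.deepcopy(dim_partition_dict)
        # softmax normalizes over softmax_dim, so that dim must end up replicated
        new_dict.pop(softmax_dim, None)
        return new_dict

    # a predecessor strategy that shards dim 0 and dim 3 on a 2D mesh:
    assert drop_softmax_dim({0: [0], 3: [1]}, softmax_dim=3) == {0: [0]}
    assert drop_softmax_dim({0: [0], 3: [1]}, softmax_dim=1) == {0: [0], 3: [1]}
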
+class SoftmaxGenerator(FollowingStrategyGenerator): + """ + SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + ''' + Compute the computation cost per device with this specific strategy. + ''' + sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + input_size_product = reduce(operator.mul, sharded_input_shape) + output_size_product = reduce(operator.mul, sharded_output_shape) + + forward_compute_cost = output_size_product * 2 + backward_compute_cost = input_size_product + total_compute_cost = forward_compute_cost + backward_compute_cost + compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + softmax_dim = self.op_data['softmax_dim'].data + + if softmax_dim in dim_partition_dict_for_input: + recover_dims = dim_partition_dict_for_input.pop(softmax_dim) + + dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. 
+ name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index 4c9d355c3..bda160906 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -16,8 +16,6 @@ __all__ = ['UnaryElementwiseHandler'] @operator_registry.register(torch.nn.ReLU) @operator_registry.register(torch.nn.Tanh) @operator_registry.register(torch.tanh) -# TODO: softmax need to be relocated -@operator_registry.register(torch.nn.functional.softmax) @operator_registry.register(torch.nn.modules.dropout.Dropout) @operator_registry.register(torch.Tensor.contiguous) @operator_registry.register(torch.nn.functional.dropout) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py new file mode 100644 index 000000000..b5e8e3277 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -0,0 +1,186 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + +class LinearSplitModel(nn.Module): + + def __init__(self, softmax_dim): + super().__init__() + self.softmax_dim = softmax_dim + + def forward(self, input, other): + linear_node = F.linear(input, other, bias=None) + softmax_node = F.softmax(linear_node, self.softmax_dim) + return softmax_node + + +def check_split_handler(rank, softmax_dim, model_cls, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = model_cls(softmax_dim=softmax_dim).cuda() + + input = torch.rand(8, 16, 64, 32).to('cuda') + other = torch.rand(64, 32).to('cuda') + # index of linear node in computation graph + node_index = 2 + # total number of linear strategies + strategy_number = 23 + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + numerical_test_for_node_strategy(model=model, + device_mesh=device_mesh, + node_index=node_index, + 
strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=['input', 'other'], + node_type='following') + tracer = ColoTracer() + + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) + # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) + # return split + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 16, 64, 32).to('meta'), + "other": torch.rand(64, 32).to('meta'), + }) + + gm = ColoGraphModule(model, graph) + + previous_mod_node = list(graph.nodes)[2] + split_node = list(graph.nodes)[3] + split_strategies_vector = StrategiesVector(split_node) + previous_strategies_vector = StrategiesVector(previous_mod_node) + + # build handler + assert len(previous_strategies_vector) == 0 + linear_handler = LinearFunctionHandler(node=previous_mod_node, + device_mesh=device_mesh, + strategies_vector=previous_strategies_vector) + linear_handler.register_strategy(compute_resharding_cost=False) + setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + + softmax_handler = SoftmaxHandler(node=split_node, + device_mesh=device_mesh, + strategies_vector=split_strategies_vector) + + softmax_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = softmax_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "linear" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['softmax_dim'].name == "softmax_dim" + assert mapping['softmax_dim'].data == softmax_dim + assert mapping['softmax_dim'].type == OperationDataType.ARG + + assert mapping['output'].name == "softmax" + assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. 
+ assert len(split_strategies_vector) == len(previous_strategies_vector) + strategy_name_list = [strategy.name for strategy in split_strategies_vector] + + if softmax_dim == 0: + assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + if softmax_dim == 1: + assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +@parameterize('softmax_dim', [0, 1, 2, 3]) +@parameterize('model_cls', [LinearSplitModel]) +def test_split_handler(softmax_dim, model_cls): + world_size = 4 + run_func = 
partial(check_split_handler, + softmax_dim=softmax_dim, + model_cls=model_cls, + world_size=world_size, + port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_split_handler()
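
The strategy names asserted in the test above follow the pattern 'input sharding sequence -> output sharding sequence_index', where both sequences are the predecessor's output sharding with the softmax dimension forced back to R. A rough standalone sketch of that naming rule, assuming a plain list representation such as ['S0', 'R', 'R', 'S1'] rather than ColossalAI's ShardingSpec (expected_strategy_name is an illustrative helper, not part of the library):

    def expected_strategy_name(predecessor_output_sequence, softmax_dim, index, ndim=4):
        # predecessor_output_sequence is the sharding sequence of the predecessor node's
        # output, i.e. the softmax input; negative dims are normalized as in the handler.
        if softmax_dim < 0:
            softmax_dim += ndim
        seq = list(predecessor_output_sequence)
        seq[softmax_dim] = 'R'    # the softmax dim cannot stay sharded
        rendered = '[' + ', '.join(seq) + ']'
        return f'{rendered} -> {rendered}_{index}'

    # matches the first expected entry for softmax_dim == 1 in the test:
    assert expected_strategy_name(['S0', 'R', 'R', 'S1'], softmax_dim=1, index=0) \
        == '[S0, R, R, S1] -> [S0, R, R, S1]_0'

For softmax_dim == 0 the same predecessor sequence yields '[R, R, R, S1] -> [R, R, R, S1]_0', which is why the parameterizations in the test differ only in which dimension is cleared back to R.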