From d3d4630495a2790dc0960d86ba954ececb5518c5 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 8 Dec 2022 17:02:54 +0800
Subject: [PATCH] [autoparallel] add sum handler (#2101)

---
 .../tensor_shard/node_handler/__init__.py     |   3 +-
 .../node_handler/strategy/__init__.py         |   3 +-
 .../node_handler/strategy/sum_generator.py    | 113 +++++++++
 .../tensor_shard/node_handler/sum_handler.py  |  81 ++++++
 .../test_node_handler/test_sum_handler.py     | 235 ++++++++++++++++++
 5 files changed, 433 insertions(+), 2 deletions(-)
 create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
 create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index c69f73c0b..014f3b50b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -15,6 +15,7 @@ from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
 from .registry import operator_registry
 from .reshape_handler import ReshapeHandler
+from .sum_handler import SumHandler
 from .tensor_constructor_handler import TensorConstructorHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
 from .where_handler import WhereHandler
@@ -25,5 +26,5 @@ __all__ = [
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
index cfd552e44..f52b3e1d8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
@@ -16,6 +16,7 @@ from .output_generator import OutputGenerator
 from .placeholder_generator import PlaceholderGenerator
 from .reshape_generator import ReshapeGenerator
 from .strategy_generator import StrategyGenerator
+from .sum_generator import SumGenerator
 from .tensor_constructor_generator import TensorConstructorGenerator
 from .unary_elementwise_generator import UnaryElementwiseGenerator
 from .where_generator import WhereGenerator
@@ -26,5 +27,5 @@ __all__ = [
     'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
     'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
     'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
-    'TensorConstructorGenerator', 'EmbeddingStrategyGenerator'
+    'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator'
 ]
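The two `__init__.py` hunks above only re-export `SumHandler` and `SumGenerator` from their packages; the actual dispatch from `torch.sum`/`torch.Tensor.sum` to the handler comes from the `operator_registry.register` decorators in `sum_handler.py` further down. As a rough orientation, here is a minimal sketch of that decorator-registration pattern, using a simplified stand-in registry (`SimpleRegistry` is hypothetical and not ColossalAI's actual `Registry` class):

```python
# Minimal sketch of decorator-based operator registration, assuming a simplified
# stand-in registry. SimpleRegistry is hypothetical; ColossalAI's real Registry
# may differ in API and behaviour.
from typing import Callable, Dict


class SimpleRegistry:

    def __init__(self):
        self._store: Dict[Callable, type] = {}

    def register(self, target: Callable):
        # used as a decorator: @registry.register(torch.sum)
        def wrapper(handler_cls: type) -> type:
            self._store[target] = handler_cls
            return handler_cls

        return wrapper

    def get(self, target: Callable) -> type:
        return self._store[target]


registry = SimpleRegistry()


@registry.register(sum)    # stand-in target; the real handler registers torch.sum
class DummySumHandler:
    pass


assert registry.get(sum) is DummySumHandler
```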
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
new file mode 100644
index 000000000..a0fbc58d7
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
@@ -0,0 +1,113 @@
+import copy
+import operator
+from functools import reduce
+from typing import List
+
+from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
+from colossalai.auto_parallel.tensor_shard.utils import (
+    check_keep_sharding_status,
+    detect_reshape_mapping,
+    infer_output_dim_partition_dict,
+)
+from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+__all__ = ['SumGenerator']
+
+
+class SumGenerator(FollowingStrategyGenerator):
+    """
+    SumGenerator deals with the sharding strategies of torch.sum op.
+    """
+
+    def validate(self) -> bool:
+        return super().validate()
+
+    def update_compute_cost(self, strategy: ShardingStrategy):
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+        input_size_product = reduce(operator.mul, sharded_input_shape)
+        output_size_product = reduce(operator.mul, sharded_output_shape)
+
+        compute_cost = TrainCycleItem(fwd=input_size_product,
+                                      bwd=output_size_product,
+                                      total=input_size_product + output_size_product)
+
+        strategy.compute_cost = compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy):
+        '''
+        Compute the memory cost per device with this specific strategy.
+        '''
+        forward_size_mapping = {
+            'input': self._compute_size_in_bytes(strategy, "input"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        backward_size_mapping = copy.deepcopy(forward_size_mapping)
+        backward_size_mapping.pop("output")
+        # compute fwd cost incurred
+        # fwd_cost = input + output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
+        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
+
+        # compute bwd cost incurred
+        # bwd_cost = input_grad
+        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
+        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
+                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        strategy_list = []
+        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
+            dim_partition_dict_mapping = {}
+            communication_action_mapping = {}
+            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
+            dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
+            sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
+
+            # TODO: a better way to handle the distributed sum is to sum the local data on each device first,
+            # and then all-reduce the partial results among the shard groups
+            recover_dims = []
+            dim_partition_dict_for_output = {}
+            for dim in dim_partition_dict_for_input:
+                if dim in sum_dims:
+                    recover_dims.append(dim)
+                elif dim in sum_mapping_dict:
+                    dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
+                else:
+                    raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
+
+            for dim in recover_dims:
+                dim_partition_dict_for_input.pop(dim)
+
+            dim_partition_dict_mapping = {
+                "input": dim_partition_dict_for_input,
+                "output": dim_partition_dict_for_output,
+            }
+            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+            # add index into name to pass the duplicated check
+            # we keep the same strategies under different names for node merging, and this does not increase the search space,
+            # because the solver will merge this node into other nodes and will not create a new variable for it.
+            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
+
+            strategy = self.get_sharding_strategy(name=name,
+                                                  sharding_spec_mapping=sharding_spec_mapping,
+                                                  communication_action_mapping=communication_action_mapping)
+            strategy_list.append(strategy)
+
+        return strategy_list
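The core of `collate_strategies` is the dim-partition transformation: shardings on summed dims are dropped (collected in `recover_dims`), while shardings on surviving dims are re-indexed through `sum_mapping_dict`. A small self-contained sketch of that mapping, using plain dicts instead of `ShardingSpec`/`OperationData` objects (the helper name is made up for illustration):

```python
# Standalone sketch of the dim-partition transformation performed in
# collate_strategies above, assuming plain dicts in place of ShardingSpec objects.
def map_dim_partition_for_sum(dim_partition_for_input, sum_dims, sum_mapping_dict):
    """Drop shardings on summed dims and remap the surviving dims to output dims."""
    dim_partition_for_output = {}
    for dim, mesh_axes in dim_partition_for_input.items():
        if dim in sum_dims:
            # a summed dim disappears, so its sharding is recovered to replicate
            continue
        # a surviving dim keeps its sharding under its new output index
        dim_partition_for_output[sum_mapping_dict[dim]] = mesh_axes
    return dim_partition_for_output


# Example: a [R, S0, R, S1] input summed over dims (0, 2) without keepdim
# becomes a [S0, S1] output, matching the '[R, S0, R, S1] -> [S0, S1]' strategy name.
input_partition = {1: [0], 3: [1]}      # dim 1 sharded on mesh axis 0, dim 3 on mesh axis 1
sum_dims = (0, 2)
sum_mapping_dict = {1: 0, 3: 1}         # input dim 1 -> output dim 0, input dim 3 -> output dim 1
assert map_dim_partition_for_sum(input_partition, sum_dims, sum_mapping_dict) == {0: [0], 1: [1]}
```

As the TODO in the generator notes, simply dropping the sharding on a reduced dim is not yet a true distributed sum; a fuller implementation would sum the local shards on each device and then all-reduce the partial results across the corresponding shard groups.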
+ """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(SumGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + if len(self.node.args) > 1: + sum_dims = self.node.args[1] + else: + sum_dims = tuple(range(self.node.args[0]._meta_data.dim())) + + if isinstance(sum_dims, int): + sum_dims = (sum_dims,) + + # recover negative value to positive + num_dims = self.node.args[0]._meta_data.dim() + for i in range(len(sum_dims)): + if sum_dims[i] < 0: + sum_dims[i] += num_dims + + # mapping the input dims to output dims + # For examples: + # input: torch.rand(2, 3, 4, 5) + # output: torch.sum(input, (0, 2)) + # sum_mapping_dict = {1: 0, 3: 1} + # sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input + # sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input + sum_mapping_dict = {} + if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']: + for i in range(num_dims): + sum_mapping_dict.update({i: i}) + else: + output_index = 0 + for i in range(num_dims): + if i not in sum_dims: + sum_mapping_dict.update({i: output_index}) + output_index += 1 + assert output_index == self.node._meta_data.dim() + + sum_info = (sum_dims, sum_mapping_dict) + physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "sum_info": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py new file mode 100644 index 000000000..5fda4de1a --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -0,0 +1,235 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy + + 
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py
new file mode 100644
index 000000000..5fda4de1a
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py
@@ -0,0 +1,235 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
+
+
+class LinearSumModel(nn.Module):
+
+    def __init__(self, sum_dims, keepdim):
+        super().__init__()
+        self.sum_dims = sum_dims
+        self.keepdim = keepdim
+
+    def forward(self, input, other):
+        linear_node = nn.functional.linear(input, other, bias=None)
+        if self.sum_dims is not None:
+            sum_node = torch.sum(linear_node, self.sum_dims, keepdim=self.keepdim)
+        else:
+            sum_node = torch.sum(linear_node, keepdim=self.keepdim)
+        return sum_node
+
+
+def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    input = torch.rand(8, 16, 64, 32).to('cuda')
+    other = torch.rand(64, 32).to('cuda')
+    # index of the linear node in the computation graph
+    node_index = 2
+    # total number of linear strategies
+    strategy_number = 24
+
+    numerical_test_for_node_strategy(model=model,
+                                     device_mesh=device_mesh,
+                                     node_index=node_index,
+                                     strategy_number=strategy_number,
+                                     input_args=[input, other],
+                                     meta_arg_names=['input', 'other'],
+                                     node_type='following')
+
+    tracer = ColoTracer()
+
+    # graph():
+    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
+    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
+    #     %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {})
+    #     return sum_1
+    graph = tracer.trace(model,
+                         meta_args={
+                             "input": torch.rand(8, 16, 64, 32).to('meta'),
+                             "other": torch.rand(64, 32).to('meta'),
+                         })
+    gm = ColoGraphModule(model, graph)
+
+    previous_mod_node = list(graph.nodes)[2]
+    sum_node = list(graph.nodes)[3]
+    sum_strategies_vector = StrategiesVector(sum_node)
+    previous_strategies_vector = StrategiesVector(previous_mod_node)
+
+    # build handler
+
+    assert len(previous_strategies_vector) == 0
+    linear_handler = LinearFunctionHandler(node=previous_mod_node,
+                                           device_mesh=device_mesh,
+                                           strategies_vector=previous_strategies_vector)
+    linear_handler.register_strategy(compute_resharding_cost=False)
+    setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
+
+    sum_handler = SumHandler(node=sum_node, device_mesh=device_mesh, strategies_vector=sum_strategies_vector)
+
+    sum_handler.register_strategy(compute_resharding_cost=False)
+
+    # the sum handler is a following-strategy handler, so its number of strategies equals that of the predecessor node
+    assert len(sum_strategies_vector) == len(previous_strategies_vector)
+    strategy_name_list = [strategy.name for strategy in sum_strategies_vector]
+
+    # check operation data mapping
+    mapping = sum_handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "linear"
+    assert mapping['input'].data.is_meta
+    assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])
+
+    assert mapping['output'].name == "sum_1"
+    sum_node_shape = torch.empty([8, 16, 64, 64]).sum(sum_dims, keepdim=keepdim).shape
+    assert mapping['output'].logical_shape == sum_node_shape
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    # check strategy name
+    if sum_dims == (0, 2) and keepdim == False:
+        assert '[R, R, R, S1] -> [R, S1]_0' in strategy_name_list
+        assert '[R, S0, R, S1] -> [S0, S1]_1' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, S1]_2' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, S0]_3' in strategy_name_list
+        assert '[R, S1, R, S0] -> [S1, S0]_4' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, S0]_5' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_6' in strategy_name_list
+        assert '[R, S0, R, R] -> [S0, R]_7' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_9' in strategy_name_list
+        assert '[R, S1, R, R] -> [S1, R]_10' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_11' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, S1]_12' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, S0]_13' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_14' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_15' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, S1]_17' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_18' in strategy_name_list
+        assert '[R, S01, R, R] -> [S01, R]_19' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_21' in strategy_name_list
+        assert '[R, R, R, S01] -> [R, S01]_22' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list
+
+    if sum_dims == (0, 2) and keepdim == True:
+        assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list
+        assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, R, S1]_2' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list
+        assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, R, S0]_5' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list
+        assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
+        assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_11' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
+        assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
+        assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list
+
+    if sum_dims == 1 and keepdim == False:
+        assert '[S0, R, R, S1] -> [S0, R, S1]_0' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, S1]_1' in strategy_name_list
+        assert '[R, R, S0, S1] -> [R, S0, S1]_2' in strategy_name_list
+        assert '[S1, R, R, S0] -> [S1, R, S0]_3' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, S0]_4' in strategy_name_list
+        assert '[R, R, S1, S0] -> [R, S1, S0]_5' in strategy_name_list
+        assert '[S0, R, R, R] -> [S0, R, R]_6' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list
+        assert '[R, R, S0, R] -> [R, S0, R]_8' in strategy_name_list
+        assert '[S1, R, R, R] -> [S1, R, R]_9' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R]_10' in strategy_name_list
+        assert '[R, R, S1, R] -> [R, S1, R]_11' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, S0]_13' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R]_14' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R]_15' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, S0]_16' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, S1]_17' in strategy_name_list
+        assert '[S01, R, R, R] -> [S01, R, R]_18' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R]_19' in strategy_name_list
+        assert '[R, R, S01, R] -> [R, S01, R]_20' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list
+        assert '[R, R, R, S01] -> [R, R, S01]_22' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list
+
+    if sum_dims == 1 and keepdim == True:
+        assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list
+        assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
+        assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list
+        assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
+        assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
+        assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
+        assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list
+        assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
+        assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
+        assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
+        assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
+        assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
+        assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
+        assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@parameterize('sum_dims', [(0, 2), 1])
+@parameterize('keepdim', [False, True])
+def test_sum_handler(sum_dims, keepdim):
+    world_size = 4
+    run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_sum_handler()
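The strategy names asserted in the test follow the pattern `'[input sharding sequence] -> [output sharding sequence]_index'`, and the output sequence can be predicted directly from the input sequence, `sum_dims`, and `keepdim`. A rough sketch of that prediction (the helper is made up and only illustrates the naming convention, not the handler's API):

```python
# Sketch of how the expected output sharding sequence in the test's strategy names
# can be derived from an input sharding sequence; helper name is hypothetical.
def expected_sum_sharding(input_seq, sum_dims, keepdim):
    if isinstance(sum_dims, int):
        sum_dims = (sum_dims,)
    output_seq = []
    for dim, spec in enumerate(input_seq):
        if dim in sum_dims:
            # a summed dim loses its sharding; the dim itself survives only when keepdim=True
            if keepdim:
                output_seq.append('R')
        else:
            output_seq.append(spec)
    return output_seq


# matches '[R, S0, R, S1] -> [S0, S1]_1' from the sum_dims=(0, 2), keepdim=False case
assert expected_sum_sharding(['R', 'S0', 'R', 'S1'], (0, 2), keepdim=False) == ['S0', 'S1']
# matches '[R, R, S01, R] -> [R, R, S01, R]_20' from the sum_dims=1, keepdim=True case
assert expected_sum_sharding(['R', 'R', 'S01', 'R'], 1, keepdim=True) == ['R', 'R', 'S01', 'R']
```

The test itself is gated behind `run_on_environment_flag(name='AUTO_PARALLEL')` and spawns `world_size = 4` processes over an NCCL backend, so it presumably needs a node with 4 GPUs and the `AUTO_PARALLEL` environment flag set in order to run.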