[autoparallel] add sum handler (#2101)

pull/2103/head
YuliangLiu0306 2022-12-08 17:02:54 +08:00 committed by GitHub
parent e4705ba4e2
commit d3d4630495
5 changed files with 433 additions and 2 deletions

colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

@@ -15,6 +15,7 @@ from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
 from .registry import operator_registry
 from .reshape_handler import ReshapeHandler
+from .sum_handler import SumHandler
 from .tensor_constructor_handler import TensorConstructorHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
 from .where_handler import WhereHandler
@@ -25,5 +26,5 @@ __all__ = [
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler'
 ]

colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py

@@ -16,6 +16,7 @@ from .output_generator import OutputGenerator
 from .placeholder_generator import PlaceholderGenerator
 from .reshape_generator import ReshapeGenerator
 from .strategy_generator import StrategyGenerator
+from .sum_generator import SumGenerator
 from .tensor_constructor_generator import TensorConstructorGenerator
 from .unary_elementwise_generator import UnaryElementwiseGenerator
 from .where_generator import WhereGenerator
@@ -26,5 +27,5 @@ __all__ = [
     'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
     'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
     'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
-    'TensorConstructorGenerator', 'EmbeddingStrategyGenerator'
+    'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator'
 ]

colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py

@@ -0,0 +1,113 @@
import copy
import operator
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    MemoryCost,
    ShardingStrategy,
    TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
    check_keep_sharding_status,
    detect_reshape_mapping,
    infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['SumGenerator']


class SumGenerator(FollowingStrategyGenerator):
    """
    SumGenerator deals with the sharding strategies of torch.sum op.
    """

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
        input_size_product = reduce(operator.mul, sharded_input_shape)
        output_size_product = reduce(operator.mul, sharded_output_shape)
        compute_cost = TrainCycleItem(fwd=input_size_product,
                                      bwd=output_size_product,
                                      total=input_size_product + output_size_product)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")

        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
            sum_dims, sum_mapping_dict = self.op_data['sum_info'].data

            # TODO: a better way to handle the distributed sum is to sum the local data
            # on each device first, then all-reduce among all the shard groups
            recover_dims = []
            dim_partition_dict_for_output = {}
            for dim in dim_partition_dict_for_input:
                if dim in sum_dims:
                    recover_dims.append(dim)
                elif dim in sum_mapping_dict:
                    dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
                else:
                    raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')

            # remove shardings on the reduced dims from the input partition
            for dim in recover_dims:
                dim_partition_dict_for_input.pop(dim)

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # append the index to the name to pass the duplicate check: we keep identical
            # strategies under different names for node merging. This does not enlarge the
            # search space, because the solver merges this node into other nodes and
            # creates no new variable for it.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list
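
To make the partition bookkeeping above concrete, here is a minimal standalone sketch of the propagation performed by collate_strategies, using plain dicts and hypothetical example values rather than ColossalAI objects:

    # Minimal sketch of the dim-partition propagation in collate_strategies.
    # All values are hypothetical example data.
    sum_dims = (0, 2)                            # dims reduced by torch.sum
    sum_mapping_dict = {1: 0, 3: 1}              # surviving input dim -> output dim
    dim_partition_for_input = {0: [0], 3: [1]}   # input dim -> device-mesh axes

    recover_dims = []
    dim_partition_for_output = {}
    for dim, mesh_axes in dim_partition_for_input.items():
        if dim in sum_dims:
            recover_dims.append(dim)             # a shard on a reduced dim cannot survive
        else:
            dim_partition_for_output[sum_mapping_dict[dim]] = mesh_axes

    for dim in recover_dims:                     # drop shardings on the reduced dims
        dim_partition_for_input.pop(dim)

    # input [S0, R, R, S1] becomes input [R, R, R, S1] -> output [R, S1]
    assert dim_partition_for_input == {3: [1]}
    assert dim_partition_for_output == {1: [1]}

The resulting pair matches strategy names such as '[R, R, R, S1] -> [R, S1]' asserted in the test file below.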

colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py

@@ -0,0 +1,81 @@
from typing import Dict, List

import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator

__all__ = ['SumHandler']


@operator_registry.register(torch.Tensor.sum)
@operator_registry.register(torch.sum)
class SumHandler(NodeHandler):
    """
    A SumHandler which deals with the sharding strategies for torch.sum or torch.Tensor.sum.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(SumGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        if len(self.node.args) > 1:
            sum_dims = self.node.args[1]
        else:
            sum_dims = tuple(range(self.node.args[0]._meta_data.dim()))

        if isinstance(sum_dims, int):
            sum_dims = (sum_dims,)

        # normalize negative dims to positive ones; sum_dims may be a tuple,
        # so build a new tuple instead of mutating in place
        num_dims = self.node.args[0]._meta_data.dim()
        sum_dims = tuple(dim + num_dims if dim < 0 else dim for dim in sum_dims)

        # map the input dims to the output dims
        # for example:
        #   input: torch.rand(2, 3, 4, 5)
        #   sum: torch.sum(input, (0, 2))
        #   sum_mapping_dict = {1: 0, 3: 1}
        # sum_mapping_dict[1] = 0 means input dim 1 becomes output dim 0
        # sum_mapping_dict[3] = 1 means input dim 3 becomes output dim 1
        sum_mapping_dict = {}
        if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']:
            for i in range(num_dims):
                sum_mapping_dict.update({i: i})
        else:
            output_index = 0
            for i in range(num_dims):
                if i not in sum_dims:
                    sum_mapping_dict.update({i: output_index})
                    output_index += 1
            assert output_index == self.node._meta_data.dim()

        sum_info = (sum_dims, sum_mapping_dict)
        physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info)

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "sum_info": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
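
As a quick sanity check of the bookkeeping above, the following standalone snippet reproduces the mapping from the comment in get_operation_data_mapping, using hypothetical example values and no handler machinery:

    import torch

    # Worked example of the sum_dims / sum_mapping_dict computation above.
    input_meta = torch.rand(2, 3, 4, 5)
    sum_dims = (0, -2)                  # -2 normalizes to dim 2
    num_dims = input_meta.dim()
    sum_dims = tuple(d + num_dims if d < 0 else d for d in sum_dims)
    assert sum_dims == (0, 2)

    keepdim = False
    sum_mapping_dict = {}
    if keepdim:
        sum_mapping_dict = {i: i for i in range(num_dims)}
    else:
        output_index = 0
        for i in range(num_dims):
            if i not in sum_dims:
                sum_mapping_dict[i] = output_index
                output_index += 1

    # input dims 1 and 3 survive and become output dims 0 and 1
    assert sum_mapping_dict == {1: 0, 3: 1}
    assert torch.sum(input_meta, sum_dims).shape == torch.Size([3, 5])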

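For readers unfamiliar with the @operator_registry.register(...) decorators above: the pattern is a plain mapping from a torch callable to its handler class. A hypothetical minimal re-implementation (an illustrative sketch, not ColossalAI's actual registry code) could look like:

    # Hypothetical minimal operator registry, illustrating the decorator
    # pattern used by @operator_registry.register above.
    class OperatorRegistry:

        def __init__(self):
            self._handlers = {}

        def register(self, source):
            # returns a class decorator that records source -> handler class
            def wrapper(handler_cls):
                self._handlers[source] = handler_cls
                return handler_cls
            return wrapper

        def get(self, source):
            return self._handlers[source]

    registry = OperatorRegistry()

    @registry.register(sum)        # stands in for torch.sum
    class DemoSumHandler:
        pass

    assert registry.get(sum) is DemoSumHandler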
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py

@@ -0,0 +1,235 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


class LinearSumModel(nn.Module):

    def __init__(self, sum_dims, keepdim):
        super().__init__()
        self.sum_dims = sum_dims
        self.keepdim = keepdim

    def forward(self, input, other):
        linear_node = nn.functional.linear(input, other, bias=None)
        if self.sum_dims is not None:
            sum_node = torch.sum(linear_node, self.sum_dims, keepdim=self.keepdim)
        else:
            sum_node = torch.sum(linear_node, keepdim=self.keepdim)
        return sum_node


def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    input = torch.rand(8, 16, 64, 32).to('cuda')
    other = torch.rand(64, 32).to('cuda')

    # index of the linear node in the computation graph
    node_index = 2
    # total number of linear strategies
    strategy_number = 24

    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=[input, other],
                                     meta_arg_names=['input', 'other'],
                                     node_type='following')

    tracer = ColoTracer()

    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
    #     %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {})
    #     return sum_1
    graph = tracer.trace(model,
                         meta_args={
                             "input": torch.rand(8, 16, 64, 32).to('meta'),
                             "other": torch.rand(64, 32).to('meta'),
                         })
    gm = ColoGraphModule(model, graph)

    previous_mod_node = list(graph.nodes)[2]
    sum_node = list(graph.nodes)[3]
    sum_strategies_vector = StrategiesVector(sum_node)
    previous_strategies_vector = StrategiesVector(previous_mod_node)

    # build handler
    assert len(previous_strategies_vector) == 0
    linear_handler = LinearFunctionHandler(node=previous_mod_node,
                                           device_mesh=device_mesh,
                                           strategies_vector=previous_strategies_vector)
    linear_handler.register_strategy(compute_resharding_cost=False)
    setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)

    sum_handler = SumHandler(node=sum_node, device_mesh=device_mesh, strategies_vector=sum_strategies_vector)
    sum_handler.register_strategy(compute_resharding_cost=False)

    # the sum handler is a following-strategy handler, so its number of strategies
    # equals that of the predecessor node
    assert len(sum_strategies_vector) == len(previous_strategies_vector)
    strategy_name_list = [strategy.name for strategy in sum_strategies_vector]

    # check operation data mapping
    mapping = sum_handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.data is not None

    assert mapping['input'].name == "linear"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])

    assert mapping['output'].name == "sum_1"
    sum_node_shape = torch.empty([8, 16, 64, 64]).sum(sum_dims, keepdim=keepdim).shape
    assert mapping['output'].logical_shape == sum_node_shape
    assert mapping['output'].type == OperationDataType.OUTPUT

    # check strategy name
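    # notation used in the names below: R = replicated, S0 / S1 = sharded along
    # device-mesh axis 0 / 1, S01 = sharded along both axes; the trailing _<index>
    # only disambiguates duplicate names (see SumGenerator.collate_strategies)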
    if sum_dims == (0, 2) and keepdim == False:
        assert '[R, R, R, S1] -> [R, S1]_0' in strategy_name_list
        assert '[R, S0, R, S1] -> [S0, S1]_1' in strategy_name_list
        assert '[R, R, R, S1] -> [R, S1]_2' in strategy_name_list
        assert '[R, R, R, S0] -> [R, S0]_3' in strategy_name_list
        assert '[R, S1, R, S0] -> [S1, S0]_4' in strategy_name_list
        assert '[R, R, R, S0] -> [R, S0]_5' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_6' in strategy_name_list
        assert '[R, S0, R, R] -> [S0, R]_7' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_9' in strategy_name_list
        assert '[R, S1, R, R] -> [S1, R]_10' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_11' in strategy_name_list
        assert '[R, R, R, S1] -> [R, S1]_12' in strategy_name_list
        assert '[R, R, R, S0] -> [R, S0]_13' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_14' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_15' in strategy_name_list
        assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list
        assert '[R, R, R, S1] -> [R, S1]_17' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_18' in strategy_name_list
        assert '[R, S01, R, R] -> [S01, R]_19' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_21' in strategy_name_list
        assert '[R, R, R, S01] -> [R, S01]_22' in strategy_name_list
        assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list

    if sum_dims == (0, 2) and keepdim == True:
        assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list
        assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_2' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list
        assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_5' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list
        assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
        assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_11' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
        assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
        assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list

    if sum_dims == 1 and keepdim == False:
        assert '[S0, R, R, S1] -> [S0, R, S1]_0' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, S1]_1' in strategy_name_list
        assert '[R, R, S0, S1] -> [R, S0, S1]_2' in strategy_name_list
        assert '[S1, R, R, S0] -> [S1, R, S0]_3' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, S0]_4' in strategy_name_list
        assert '[R, R, S1, S0] -> [R, S1, S0]_5' in strategy_name_list
        assert '[S0, R, R, R] -> [S0, R, R]_6' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list
        assert '[R, R, S0, R] -> [R, S0, R]_8' in strategy_name_list
        assert '[S1, R, R, R] -> [S1, R, R]_9' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R]_10' in strategy_name_list
        assert '[R, R, S1, R] -> [R, S1, R]_11' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, S0]_13' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R]_14' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R]_15' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, S0]_16' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, S1]_17' in strategy_name_list
        assert '[S01, R, R, R] -> [S01, R, R]_18' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R]_19' in strategy_name_list
        assert '[R, R, S01, R] -> [R, S01, R]_20' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list
        assert '[R, R, R, S01] -> [R, R, S01]_22' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list

    if sum_dims == 1 and keepdim == True:
        assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list
        assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
        assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list
        assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
        assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
        assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
        assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list
        assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
        assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
        assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
        assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
        assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
        assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
        assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@parameterize('sum_dims', [(0, 2), 1])
@parameterize('keepdim', [False, True])
def test_sum_handler(sum_dims, keepdim):
    world_size = 4
    run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sum_handler()
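
Note: as the decorators above indicate, this test is skipped unless the AUTO_PARALLEL environment flag is set, and it spawns four NCCL processes on a 2 x 2 device mesh, so it needs a machine with four GPUs (presumably invoked as something like AUTO_PARALLEL=1 pytest tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py).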