[autoparallel] add pooling handler (#1690)

* [autoparallel] add pooling handler

* polish code
YuliangLiu0306 2022-10-13 13:42:13 +08:00 committed by GitHub
parent 319d654f79
commit 56088e6d98
4 changed files with 213 additions and 1 deletion


@@ -0,0 +1,40 @@
import torch
from .node_handler import ModuleHandler
from ..sharding_strategy import OperationDataType, OperationData
from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry
__all__ = ['NormPoolingHandler']
@operator_registry.register(torch.nn.MaxPool1d)
@operator_registry.register(torch.nn.MaxPool2d)
@operator_registry.register(torch.nn.MaxPool3d)
@operator_registry.register(torch.nn.AvgPool1d)
@operator_registry.register(torch.nn.AvgPool2d)
@operator_registry.register(torch.nn.AvgPool3d)
class NormPoolingHandler(ModuleHandler):
"""
A NormPoolingHandler which deals with the sharding strategies for normal pooling modules such as nn.MaxPoolXd and nn.AvgPoolXd.
"""
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# map the physical input activation, the pooling kernel size and the module output to OperationData
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_weight_operand, "output": physical_output}
return mapping


@@ -7,10 +7,11 @@ from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator
from .layer_norm_generator import LayerNormGenerator
from .where_generator import WhereGenerator
from .reshape_generator import ReshapeGenerator
from .normal_pooling_generator import NormalPoolStrategyGenerator
__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
'TensorTupleStrategyGenerator', 'LayerNormGenerator', "WhereGenerator", 'ReshapeGenerator', 'NormalPoolStrategyGenerator'
]


@@ -0,0 +1,117 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
class NormalPoolStrategyGenerator(StrategyGenerator_V2):
"""
NormalPoolStrategyGenerator is a generic class to generate strategies for pooling operations such as MaxPoolXd.
We call these "normal" pooling ops because AvgPoolXd and MaxPoolXd both take a kernel-sized window of elements
from the input and reduce it, with the reduction depending on the operation type.
"""
def validate(self) -> bool:
'''
In the sanity check, we make sure the input data has the correct number of dimensions.
For Pool1d, the dim of the input data should be 3 ([N, C, L]).
For Pool2d, the dim of the input data should be 4 ([N, C, H, W]).
For Pool3d, the dim of the input data should be 5 ([N, C, D, H, W]).
'''
input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4, 5), 'We expect the dim of the input fed into a Pool op to be in the range [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
'''
Compute the computation cost per device with this specific strategy.
'''
# TODO: the compute cost needs to be divided by TFLOPS; for now it only records the computation size.
# 1D: Lout * N * C * kernel
# 2D: (Hout * Wout) * N * C * kernel
# 3D: (Hout * Wout * Dout) * N * C * kernel
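# Worked example (illustrative, shapes borrowed from the unit test below): an unsharded MaxPool2d
# with kernel_size=4 maps a [4, 4, 64, 64] input to a [4, 4, 16, 16] output, so the forward cost is
# 4*4*16*16 * (4*4) = 65536 and the backward cost is 4*4*64*64 * (4*4) = 1048576 per device.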
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
kernel_size = self.op_data["other"].data
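# kernel_size may be a single int (e.g. MaxPool2d(4)); normalize it to one entry per spatial dim
# so that the product below is well defined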
if isinstance(kernel_size, int):
kernel_size = [kernel_size] * (len(sharded_output_shape) - 2)
kernel_size_product = reduce(operator.mul, kernel_size)
output_size_product = reduce(operator.mul, sharded_output_shape)
input_size_product = reduce(operator.mul, sharded_input_shape)
forward_compute_cost = output_size_product * kernel_size_product
backward_compute_cost = input_size_product * kernel_size_product
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
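# e.g. (illustrative, assuming fp32 and no sharding) a [4, 4, 64, 64] input costs 65536 * 4 B = 256 KiB
# of activation memory and a [4, 4, 16, 16] output costs 16 KiB, so fwd is roughly 272 KiB and bwd
# (the input gradient) roughly 256 KiB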
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
communication_action_mapping = {}
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def enumerate_all_possible_batch_dimensions_dim_partition(self, mesh_dim_0, mesh_dim_1):
dim_partition_list = []
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, 2))
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, 2))
dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, 2))
# append {} for non_split case
dim_partition_list.append({})
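# for the 2-D device mesh used in the unit test below this yields 9 candidate partitions in total
# (roughly: 2 per single mesh axis, 4 mixed 2-D splits, plus the non-split case)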
return dim_partition_list
def generate(self) -> List[ShardingStrategy_V2]:
strategy_list = []
dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
for dim_partition in dim_partition_list:
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list


@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0
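# list(graph.nodes)[1] below is the call_module node of the pooling layer (%_0 in the graph above)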
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
pool_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(pool_mod_node)
# build handler
handler = NormPoolingHandler(node=pool_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64])
assert mapping['output'].name == "_0"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy()
strategy_name_list = [val.name for val in strategies_vector]
assert len(strategy_name_list) == 9
if __name__ == '__main__':
test_norm_pool_handler()
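
As a standalone sanity check on the shapes asserted above (plain PyTorch only, independent of the handler code in this commit), the pooled output shape can be reproduced directly:

import torch
import torch.nn as nn

# MaxPool2d(4, padding=1): stride defaults to the kernel size, so the spatial size becomes
# floor((64 + 2*1 - 4) / 4) + 1 = 16
pool = nn.MaxPool2d(4, padding=1)
out = pool(torch.rand(4, 4, 64, 64))
assert out.shape == torch.Size([4, 4, 16, 16])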