[autoparallel] add batch norm handler v2 (#1666)

pull/1669/head
YuliangLiu0306 2022-09-29 11:02:49 +08:00 committed by GitHub
parent 9708638ded
commit 746f8f979d
3 changed files with 424 additions and 0 deletions

@@ -0,0 +1,45 @@
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import BatchNormStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry
__all__ = ['BatchNormModuleHandler']
@operator_registry.register(torch.nn.BatchNorm1d)
@operator_registry.register(torch.nn.BatchNorm2d)
@operator_registry.register(torch.nn.BatchNorm3d)
class BatchNormModuleHandler(ModuleHandler):
"""
A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
"""
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'],
logical_shape=self.named_parameters['weight'].shape)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.named_parameters['bias'] is not None:
physical_bias_operand = OperationData(name="bias",
type=OperationDataType.PARAM,
data=self.named_parameters['bias'])
mapping['bias'] = physical_bias_operand
return mapping
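
For orientation, here is a minimal usage sketch of the handler above. It is an annotation, not part of the diff; it simply mirrors the unit test at the end of this PR (same meta-traced nn.BatchNorm2d(16) module and 2 x 2 device mesh), so every API call shown also appears there.

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler_v2 import BatchNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import StrategiesVector

# trace a BatchNorm2d module on meta tensors and build a 2 x 2 device mesh
model = nn.Sequential(nn.BatchNorm2d(16).to('meta'))
graph = ColoTracer().trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
gm = ColoGraphModule(model, graph)
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# wrap the BN node in the new handler
bn_node = list(graph.nodes)[1]
handler = BatchNormModuleHandler(node=bn_node,
                                 device_mesh=device_mesh,
                                 strategies_vector=StrategiesVector(bn_node))

# keys: "input" (ARG), "other" (weight, PARAM), "bias" (PARAM), "output" (OUTPUT)
mapping = handler.get_operation_data_mapping()
# populates the strategies_vector with the nine strategies listed further down
strategies_vector = handler.register_strategy()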

@@ -0,0 +1,291 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['BatchNormStrategyGenerator']
class BatchNormStrategyGenerator(StrategyGenerator_V2):
"""
A StrategyGenerator which deals with the sharding strategies of batch normalization.
To keep the math consistency, there are two way to do BatchNorm if the input
shards on batch dimension:
1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help
us to keep the computing correctness.
In this generator, both methods will be considered.
"""
@property
def has_bias(self):
return 'bias' in self.op_data
    def validate(self) -> bool:
        '''
        In the sanity check, we need to make sure the input data has the correct number of dimensions.
        For BatchNorm1d, the input should be 3-dimensional ([N, C, L]).
        For BatchNorm2d, the input should be 4-dimensional ([N, C, H, W]).
        For BatchNorm3d, the input should be 5-dimensional ([N, C, D, H, W]).
        '''
        input_op_data = self.op_data['input']
        assert input_op_data.dim() in (3, 4, 5), \
            'We suppose the dim of input fed into the batch norm op should be in the range of [3, 5].'
    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        '''
        Compute the computation cost per device with this specific strategy.

        Note: compute_cost needs to be divided by TFLOPS; for now it just records the computation size.
        '''
        # TODO: a constant coefficient needs to be added.
        # 1D: (L) * N * Cin
        # 2D: (H * W) * N * Cin
        # 3D: (H * W * D) * N * Cin
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
        if self.has_bias:
            # bias add is an element-wise operation, so its cost equals the product of the output shape.
            bias_compute_cost = reduce(operator.mul, sharded_output_shape)

        input_product = reduce(operator.mul, sharded_input_shape, 1)
        forward_compute_cost = input_product
        backward_activation_compute_cost = input_product
        backward_weight_compute_cost = input_product
        backward_compute_cost = backward_weight_compute_cost + backward_activation_compute_cost

        if self.has_bias:
            forward_compute_cost += bias_compute_cost
            backward_compute_cost += bias_compute_cost

        total_compute_cost = forward_compute_cost + backward_compute_cost
        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
        return compute_cost
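
    # Illustrative cost walk-through (annotation, not part of the original diff):
    # assuming the logical input [4, 16, 64, 64] from the unit test below, sharded
    # S0S1 on a 2 x 2 mesh, the per-device input/output shape is [2, 8, 64, 64], so
    #   input_product         = 2 * 8 * 64 * 64 = 65536
    #   forward_compute_cost  = 65536  + 65536 (bias add) = 131072
    #   backward_compute_cost = 131072 + 65536 (bias add) = 196608
    #   total_compute_cost    = 131072 + 196608           = 327680
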
    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'other': self._compute_size_in_bytes(strategy, "other"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        if self.has_bias:
            bias_size = self._compute_size_in_bytes(strategy, "bias")
            forward_size_mapping['bias'] = bias_size

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")

        # compute fwd cost incurred
        # fwd_cost = input + other + bias + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost
    def split_input_channel(self, mesh_dim_0):
        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0]
            },
            "output": {
                1: [mesh_dim_0]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
    def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                1: [mesh_dim_0, mesh_dim_1]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
    def non_split(self):
        name = 'RR = RR x R'
        dim_partition_dict_mapping = {
            "input": {},
            "other": {},
            "output": {},
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need to do communication for weight and bias.
        # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
        # with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0)

        communication_action_mapping = {"output": output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
    def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need to do communication for the gradients of weight and bias.
        # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
        # with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])

        communication_action_mapping = {"output": output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0],
                1: [mesh_dim_1],
            },
            "other": {
                0: [mesh_dim_1],
            },
            "output": {
                0: [mesh_dim_0],
                1: [mesh_dim_1],
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {
                0: [mesh_dim_1],
            }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need to do communication for the gradients of weight and bias.
        # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
        # with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0])

        communication_action_mapping = {"output": output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
    def generate(self):
        '''
        Generate every possible strategy for a BatchNorm node and record them all in the strategies_vector.
        '''
        strategy_list = []

        # RS = RS x S
        strategy_list.append(self.split_input_channel(0))
        strategy_list.append(self.split_input_channel(1))

        # RR = RR x R
        strategy_list.append(self.non_split())

        # RS01 = RS01 x S01
        strategy_list.append(self.split_input_channel_1d(0, 1))

        # SR = SR x R WITH SYNC_BN
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        # SS = SS x S WITH SYNC_BN
        strategy_list.append(self.split_input_both_dim(0, 1))
        strategy_list.append(self.split_input_both_dim(1, 0))

        # S01R = S01R x R WITH SYNC_BN
        strategy_list.append(self.split_input_batch_1d(0, 1))

        return strategy_list
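
For a 2 x 2 device mesh, generate() above therefore produces nine strategies. The list below is an annotation (not part of the diff) and simply mirrors the names asserted in the unit test that follows:

# Strategy names produced by BatchNormStrategyGenerator.generate() on a 2 x 2 mesh
expected_strategy_names = [
    'RS0 = RS0 x S0', 'RS1 = RS1 x S1',                                  # split_input_channel
    'RR = RR x R',                                                       # non_split
    'RS01 = RS01 x S01',                                                 # split_input_channel_1d
    'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN',          # split_input_batch
    'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN',    # split_input_both_dim
    'S01R = S01R x R WITH SYNC_BN',                                      # split_input_batch_1d
]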

@@ -0,0 +1,88 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler_v2 import BatchNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
def test_bn_module_handler():
    model = nn.Sequential(nn.BatchNorm2d(16).to('meta'))
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    #     return _0
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    bn_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(bn_mod_node)

    # build handler
    handler = BatchNormModuleHandler(node=bn_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([16])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([16])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy()
    #[ 'S01R = S01R x R WITH SYNC_BN']
    strategy_name_list = [val.name for val in strategies_vector]

    # RS = RS x S
    assert 'RS0 = RS0 x S0' in strategy_name_list
    assert 'RS1 = RS1 x S1' in strategy_name_list

    # RR = RR x R
    assert 'RR = RR x R' in strategy_name_list

    # RS01 = RS01 x S01
    assert 'RS01 = RS01 x S01' in strategy_name_list

    # SR = SR x R WITH SYNC_BN
    assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
    assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list

    # SS = SS x S WITH SYNC_BN
    assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
    assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list

    # S01R = S01R x R WITH SYNC_BN
    assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list


if __name__ == '__main__':
    test_bn_module_handler()