mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add batch norm handler v2 (#1666)
parent 9708638ded
commit 746f8f979d
@@ -0,0 +1,45 @@
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import BatchNormStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry

__all__ = ['BatchNormModuleHandler']


@operator_registry.register(torch.nn.BatchNorm1d)
@operator_registry.register(torch.nn.BatchNorm2d)
@operator_registry.register(torch.nn.BatchNorm3d)
class BatchNormModuleHandler(ModuleHandler):
    """
    A BatchNormModuleHandler which deals with the sharding strategies for the nn.BatchNormXd modules.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use the transposed shape for strategies
        # the strategies will be transformed back to their original shapes in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'],
                                               logical_shape=self.named_parameters['weight'].shape)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.named_parameters['bias'] is not None:
            physical_bias_operand = OperationData(name="bias",
                                                  type=OperationDataType.PARAM,
                                                  data=self.named_parameters['bias'])
            mapping['bias'] = physical_bias_operand
        return mapping
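
For orientation, a short sketch (not part of this commit) of what get_operation_data_mapping is expected to return for the nn.BatchNorm2d(16) case exercised by the unit test at the bottom of this page, where the traced input has shape [4, 16, 64, 64] and the node names "input_1" / "_0" come from the traced graph:

# 'input'  -> OperationData(name='input_1', type=ARG,    data shape [4, 16, 64, 64])
# 'other'  -> OperationData(name='weight',  type=PARAM,  data shape [16])
# 'bias'   -> OperationData(name='bias',    type=PARAM,  data shape [16])
# 'output' -> OperationData(name='_0',      type=OUTPUT, data shape [4, 16, 64, 64])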
@@ -0,0 +1,291 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
from .._utils import exception_handler
import copy

__all__ = ['BatchNormStrategyGenerator']


class BatchNormStrategyGenerator(StrategyGenerator_V2):
    """
    A StrategyGenerator which deals with the sharding strategies of batch normalization.

    To keep mathematical consistency, there are two ways to do BatchNorm if the input
    is sharded along the batch dimension:
    1. We gather the input partitions along the batch dimension, then do the normal BatchNorm.
    2. We do SyncBatchNorm on each input partition separately; the SyncBN op keeps the
       computation correct.
    In this generator, both methods are considered.
    """

    @property
    def has_bias(self):
        return 'bias' in self.op_data

    def validate(self) -> bool:
        '''
        As a sanity check, we need to make sure the input data has the correct number of dimensions.
        For BatchNorm1d, the input should have 3 dims ([N, C, L]).
        For BatchNorm2d, the input should have 4 dims ([N, C, H, W]).
        For BatchNorm3d, the input should have 5 dims ([N, C, H, W, D]).
        '''
        input_op_data = self.op_data['input']
        assert input_op_data.dim() in (3, 4, 5), \
            'We assume the dim of the input fed into the BatchNorm op to be in the range [3, 5].'

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        '''
        Compute the computation cost per device with this specific strategy.

        Note: compute_cost needs to be divided by TFLOPS; for now it just records the computation size.
        '''
        # TODO: a constant coefficient needs to be added.
        # 1D: (L) * N * Cin
        # 2D: (H * W) * N * Cin
        # 3D: (H * W * D) * N * Cin
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
        if self.has_bias:
            # bias add is an element-wise operation, so its cost equals the product of the output shape.
            bias_compute_cost = reduce(operator.mul, sharded_output_shape)
        input_product = reduce(operator.mul, sharded_input_shape, 1)
        forward_compute_cost = input_product
        backward_activation_compute_cost = input_product
        backward_weight_compute_cost = input_product
        backward_compute_cost = backward_weight_compute_cost + backward_activation_compute_cost
        if self.has_bias:
            forward_compute_cost += bias_compute_cost
            backward_compute_cost += bias_compute_cost
        total_compute_cost = forward_compute_cost + backward_compute_cost
        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
        return compute_cost
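
    # Worked example (illustrative note, not part of the original commit): with a logical
    # input of shape [4, 16, 64, 64] sharded S0S1 on a 2 x 2 mesh, the per-device sharded
    # input/output shape is [2, 8, 64, 64], so
    #   forward_compute_cost  = 2 * 8 * 64 * 64 = 65536   (+ 65536 for the bias add)
    #   backward_compute_cost = 2 * 65536       = 131072  (+ 65536 for the bias add)
    # and the returned TrainCycleItem records fwd, bwd and their sum.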

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'other': self._compute_size_in_bytes(strategy, "other"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        if self.has_bias:
            bias_size = self._compute_size_in_bytes(strategy, "bias")
            forward_size_mapping['bias'] = bias_size

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")
        # compute the forward cost incurred
        # fwd_cost = input + other + bias + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute the backward cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute the total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def split_input_channel(self, mesh_dim_0):
        strategy_list = []
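        # Strategy names below follow the '<output spec> = <input spec> x <other spec>' convention,
        # where R means replicated and S{i} means sharded along mesh dimension i.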
        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0]
            },
            "output": {
                1: [mesh_dim_0]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                1: [mesh_dim_0, mesh_dim_1]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def non_split(self):
        name = 'RR = RR x R'
        dim_partition_dict_mapping = {
            "input": {},
            "other": {},
            "output": {},
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need any communication for the weight and bias.
        # TODO: the communication happens internally inside the SyncBN operation. We need to replace the BN
        # operation with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0)

        communication_action_mapping = {"output": output_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need any communication for the gradients of weight and bias.
        # TODO: the communication happens internally inside the SyncBN operation. We need to replace the BN
        # operation with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])

        communication_action_mapping = {"output": output_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0],
                1: [mesh_dim_1],
            },
            "other": {
                0: [mesh_dim_1],
            },
            "output": {
                0: [mesh_dim_0],
                1: [mesh_dim_1],
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {
                0: [mesh_dim_1],
            }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need any communication for the gradients of weight and bias.
        # TODO: the communication happens internally inside the SyncBN operation. We need to replace the BN
        # operation with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0])

        communication_action_mapping = {"output": output_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self):
        '''
        Generate every possible strategy for a BatchNorm node and record all of them in the strategies_vector.
        '''

        strategy_list = []
        # RS = RS x S
        strategy_list.append(self.split_input_channel(0))
        strategy_list.append(self.split_input_channel(1))

        # RR = RR x R
        strategy_list.append(self.non_split())

        # RS01 = RS01 x S01
        strategy_list.append(self.split_input_channel_1d(0, 1))

        # SR = SR x R WITH SYNC_BN
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        # SS = SS x S WITH SYNC_BN
        strategy_list.append(self.split_input_both_dim(0, 1))
        strategy_list.append(self.split_input_both_dim(1, 0))

        # S01R = S01R x R WITH SYNC_BN
        strategy_list.append(self.split_input_batch_1d(0, 1))

        return strategy_list
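
Below is a small, self-contained sketch (not part of this commit) of the consistency issue described in the generator's docstring: normalizing each batch shard independently does not reproduce BatchNorm over the full batch, which is why the batch-split strategies above either gather the input or rely on SyncBN. It assumes only plain PyTorch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16, 8, 8)
weight, bias = torch.ones(16), torch.zeros(16)

# BatchNorm over the full batch (training mode, no running stats)
full = F.batch_norm(x, None, None, weight, bias, training=True)

# BatchNorm applied independently to two batch shards, then concatenated
shards = [F.batch_norm(part, None, None, weight, bias, training=True) for part in torch.chunk(x, 2, dim=0)]
naive = torch.cat(shards, dim=0)

# the results differ because each shard uses its own mean/variance
print(torch.allclose(full, naive, atol=1e-5))  # expected: False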
@@ -0,0 +1,88 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler_v2 import BatchNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


def test_bn_module_handler():
    model = nn.Sequential(nn.BatchNorm2d(16).to('meta'))
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    #     return _0
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
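    # the four devices above form a 2 x 2 logical mesh, so mesh dimensions 0 and 1 each have size 2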
    bn_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(bn_mod_node)

    # build handler
    handler = BatchNormModuleHandler(node=bn_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([16])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([16])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy()
    # e.g. 'S01R = S01R x R WITH SYNC_BN'
    strategy_name_list = [val.name for val in strategies_vector]

    # RS = RS x S
    assert 'RS0 = RS0 x S0' in strategy_name_list
    assert 'RS1 = RS1 x S1' in strategy_name_list

    # RR = RR x R
    assert 'RR = RR x R' in strategy_name_list

    # RS01 = RS01 x S01
    assert 'RS01 = RS01 x S01' in strategy_name_list

    # SR = SR x R WITH SYNC_BN
    assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
    assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list

    # SS = SS x S WITH SYNC_BN
    assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
    assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list

    # S01R = S01R x R WITH SYNC_BN
    assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list


if __name__ == '__main__':
    test_bn_module_handler()