[autoparallel] add layer norm handler v2 (#1671)

* [autoparallel] add layer norm handler v2

* polish code

* polish code
YuliangLiu0306 2022-10-09 14:23:22 +08:00 committed by GitHub
parent 87c5ad352a
commit 52fda88796
6 changed files with 311 additions and 3 deletions


@@ -7,8 +7,11 @@ from .bcast_op_handler import BcastOpHandler
 from .embedding_handler import EmbeddingHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
 from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
+from .layer_norm_handler_v2 import LayerNormModuleHandler
+from .batch_norm_handler_v2 import BatchNormModuleHandler
 
 __all__ = [
     'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
-    'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler'
+    'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler',
+    'LayerNormModuleHandler', 'BatchNormModuleHandler'
 ]


@@ -0,0 +1,42 @@
import torch
from .node_handler import ModuleHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import LayerNormGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry

__all__ = ['LayerNormModuleHandler']


@operator_registry.register(torch.nn.LayerNorm)
class LayerNormModuleHandler(ModuleHandler):
    """
    A LayerNormModuleHandler which deals with the sharding strategies for the nn.LayerNorm module.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(LayerNormGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use the transposed shape for strategies
        # the strategies will be transformed back to their original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'],
                                               logical_shape=self.named_parameters['weight'].shape)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.named_parameters['bias'] is not None:
            physical_bias_operand = OperationData(name="bias",
                                                  type=OperationDataType.PARAM,
                                                  data=self.named_parameters['bias'])
            mapping['bias'] = physical_bias_operand
        return mapping
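
The `@operator_registry.register(torch.nn.LayerNorm)` decorator ties this handler to the module type it serves, so the right handler can later be looked up for each traced node. As a rough illustration of that decorator-registry pattern only (the class and names below are hypothetical and not ColossalAI's actual registry implementation), a minimal sketch could look like this:

import torch


class OperatorRegistrySketch:
    """Hypothetical stand-in for the real operator_registry used above."""

    def __init__(self):
        self._handlers = {}

    def register(self, source):
        # return a class decorator that records `source -> handler class`
        def wrapper(handler_cls):
            self._handlers[source] = handler_cls
            return handler_cls

        return wrapper

    def get(self, source):
        return self._handlers[source]


registry_sketch = OperatorRegistrySketch()


@registry_sketch.register(torch.nn.LayerNorm)
class DummyLayerNormHandler:
    pass


# the lookup side can now recover the handler class from the module type
assert registry_sketch.get(torch.nn.LayerNorm) is DummyLayerNormHandler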


@@ -2,9 +2,10 @@ from .strategy_generator import StrategyGenerator_V2
 from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
 from .conv_strategy_generator import ConvStrategyGenerator
 from .batch_norm_generator import BatchNormStrategyGenerator
+from .layer_norm_generator import LayerNormGenerator
 
 __all__ = [
     'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
     'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
-    'BatchNormStrategyGenerator'
+    'BatchNormStrategyGenerator', 'LayerNormGenerator'
 ]


@@ -0,0 +1,187 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy

__all__ = ['LayerNormGenerator']


class LayerNormGenerator(StrategyGenerator_V2):
    """
    LayerNormGenerator is a generic class to generate strategies for the LayerNorm operation.
    The operation data is defined as `output = input x other + bias`.
    """

    @property
    def has_bias(self):
        return 'bias' in self.op_data

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        '''
        Compute the computation cost per device with this specific strategy.

        Note: the compute_cost needs to be divided by TFLOPS; for now it only reflects the computation size.
        '''
        # TODO: the compute_cost needs to be divided by TFLOPS, for now it only reflects the computation size.
        # TODO: a constant coefficient needs to be added.
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
        if self.has_bias:
            # bias add is an element-wise operation, so the cost is equal to the product of the output shape.
            bias_compute_cost = reduce(operator.mul, sharded_weight_shape)
        # in the LayerNorm context, batch dimensions are all the dimensions that do not join the normalization.
        input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)]
        input_batch_product = reduce(operator.mul, input_batch_shape, 1)
        norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)
        forward_compute_cost = input_batch_product * norm_kernel_product
        backward_activation_compute_cost = input_batch_product * norm_kernel_product
        # computing the gradient of one norm kernel element requires input_batch_product operations, so
        # the total cost is input_batch_product * norm_kernel_product
        backward_weight_compute_cost = input_batch_product * norm_kernel_product
        backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost

        if self.has_bias:
            forward_compute_cost += bias_compute_cost
            backward_compute_cost += bias_compute_cost

        total_compute_cost = forward_compute_cost + backward_compute_cost
        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
        return compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        '''
        Compute the memory cost per device with this specific strategy.
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'other': self._compute_size_in_bytes(strategy, "other"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        if self.has_bias:
            bias_size = self._compute_size_in_bytes(strategy, "bias")
            forward_size_mapping['bias'] = bias_size

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")

        # compute fwd cost incurred
        # fwd_cost = input + other + bias + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def _generate_strategy_with_dim_partition(self, dim_partition):
        dim_partition_dict_mapping = {
            "input": dim_partition,
            "other": {},
            "output": dim_partition,
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
        name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence} x {sharding_spec_mapping["other"].sharding_sequence}'

        total_mesh_dim_list = []
        for mesh_dim_list in dim_partition.values():
            total_mesh_dim_list.extend(mesh_dim_list)

        communication_action_mapping = {}
        other_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["other"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=total_mesh_dim_list)
        communication_action_mapping["other"] = other_comm_spec

        if self.has_bias:
            bias_comm_spec = self.get_communication_spec(
                sharding_spec=sharding_spec_mapping["bias"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=total_mesh_dim_list)
            communication_action_mapping["bias"] = bias_comm_spec

        strategy = self.get_sharding_strategy(name=name,
                                              sharding_spec_mapping=sharding_spec_mapping,
                                              communication_action_mapping=communication_action_mapping)
        return strategy

    def split_input_batch_single_mesh_dim(self, mesh_dim_0, batch_dimension_length):
        strategy_list = []
        dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
        for dim_partition in dim_partition_list:
            strategy = self._generate_strategy_with_dim_partition(dim_partition)
            strategy_list.append(strategy)
        return strategy_list

    def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimension_length):
        strategy_list = []
        dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
        for dim_partition in dim_partition_list:
            strategy = self._generate_strategy_with_dim_partition(dim_partition)
            strategy_list.append(strategy)
        return strategy_list

    def non_split(self):
        name = f'RR = RR x R'
        dim_partition_dict_mapping = {
            "input": {},
            "other": {},
            "output": {},
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
        communication_action_mapping = {}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self):
        '''
        Generate all possible strategies for a LayerNorm node and record them in the strategies_vector.
        '''
        strategy_list = []
        input_data_dim = len(self.op_data["input"].logical_shape)
        weight_data_dim = len(self.op_data["other"].logical_shape)
        # in the LayerNorm context, batch dimensions are all the dimensions that do not join the normalization.
        batch_dimension_length = input_data_dim - weight_data_dim

        # SR = SR x R with a single mesh dim on the batch dimensions
        strategy_list.extend(self.split_input_batch_single_mesh_dim(0, batch_dimension_length))
        strategy_list.extend(self.split_input_batch_single_mesh_dim(1, batch_dimension_length))

        # SR = SR x R with both mesh dims on the batch dimensions
        strategy_list.extend(self.split_input_batch_both_mesh_dim(0, 1, batch_dimension_length))

        # RR = RR x R
        strategy_list.append(self.non_split())

        # update meta info on cost
        for strategy in strategy_list:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategy_list
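
To make the cost bookkeeping above concrete, here is a small worked example that recomputes the arithmetic by hand rather than calling the generator. It assumes nn.LayerNorm(16) on a [4, 16] input with the '[S0, R] = [S0, R] x [R]' strategy on a mesh axis of size 2, so each device sees a [2, 16] input shard and the full [16] weight and bias, and it assumes _compute_size_in_bytes counts elements times 4 bytes for fp32:

from functools import reduce
import operator

sharded_input_shape = (2, 16)     # batch dim 4 split across 2 devices
sharded_weight_shape = (16,)      # normalized_shape, kept replicated

# compute cost, mirroring update_compute_cost
input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)]    # (2,)
input_batch_product = reduce(operator.mul, input_batch_shape, 1)        # 2
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)     # 16
bias_compute_cost = reduce(operator.mul, sharded_weight_shape)          # 16

fwd_compute = input_batch_product * norm_kernel_product + bias_compute_cost       # 32 + 16 = 48
bwd_compute = 2 * input_batch_product * norm_kernel_product + bias_compute_cost   # 64 + 16 = 80
assert (fwd_compute, bwd_compute) == (48, 80)

# memory cost, mirroring update_memory_cost (fp32, 4 bytes per element)
elem_bytes = 4
input_bytes = 2 * 16 * elem_bytes      # 128, counted as activation
output_bytes = 2 * 16 * elem_bytes     # 128, counted as activation
weight_bytes = 16 * elem_bytes         # 64, counted as parameter
bias_bytes = 16 * elem_bytes           # 64, counted as parameter
fwd_activation = input_bytes + output_bytes        # 256
fwd_parameter = weight_bytes + bias_bytes          # 128
assert (fwd_activation, fwd_parameter) == (256, 128)

These are the per-device numbers behind the '[S0, R] = [S0, R] x [R]' strategy on the 2 x 2 device mesh used by the new unit test below; the test also checks '[S1, R] = [S1, R] x [R]', '[S01, R] = [S01, R] x [R]' and 'RR = RR x R'.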


@@ -59,7 +59,6 @@ def test_bn_module_handler():
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy()
-    #[ 'S01R = S01R x R WITH SYNC_BN']
     strategy_name_list = [val.name for val in strategies_vector]
 
     # RS = RS x S


@@ -0,0 +1,76 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.layer_norm_handler_v2 import LayerNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


def test_ln_module_handler():
    model = nn.Sequential(nn.LayerNorm(16).to('meta'))
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    #     return _0
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    ln_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(ln_mod_node)

    # build handler
    handler = LayerNormModuleHandler(node=ln_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 16])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([16])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([16])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 16])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy()
    strategy_name_list = [val.name for val in strategies_vector]

    # SR = SR x R
    assert '[S0, R] = [S0, R] x [R]' in strategy_name_list
    assert '[S1, R] = [S1, R] x [R]' in strategy_name_list

    # RR = RR x R
    assert 'RR = RR x R' in strategy_name_list

    # S01R = S01R x R
    assert '[S01, R] = [S01, R] x [R]' in strategy_name_list


if __name__ == '__main__':
    test_ln_module_handler()
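
One more illustrative check, not part of the test above, of how generate() derives the sharding candidates: batch_dimension_length is simply the input rank minus the rank of normalized_shape, so with a multi-dimensional normalized_shape only the leading dimensions remain candidates for S0/S1/S01 sharding.

import torch
import torch.nn as nn

ln = nn.LayerNorm([16, 8])
x = torch.rand(4, 16, 8)

# mirrors generate(): input rank minus weight rank
batch_dimension_length = x.dim() - ln.weight.dim()    # 3 - 2 = 1
assert batch_dimension_length == 1
# only dim 0 of the input can be sharded here; dims 1 and 2 join the normalization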