mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add layernorm handler (#1629)
parent
bf77d3ab65
commit
0c703189b9
|
@ -94,7 +94,43 @@ def exception_handler(func):
|
|||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
except AssertionError as e:
|
||||
warnings.warn(f'{e}')
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
|
||||
dim_partition_list = []
|
||||
# enumerate all the 2D sharding cases
|
||||
for i in range(dim_size):
|
||||
for j in range(i + 1, dim_size):
|
||||
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
|
||||
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
|
||||
dim_partition_list.append(dim_partition_dict_0)
|
||||
dim_partition_list.append(dim_partition_dict_1)
|
||||
for i in range(dim_size):
|
||||
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
|
||||
dim_partition_list.append(dim_partition_dict_flatten)
|
||||
|
||||
return dim_partition_list
|
||||
|
||||
|
||||
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
|
||||
dim_partition_list = []
|
||||
# enumerate all the 1D sharding cases
|
||||
for i in range(dim_size):
|
||||
dim_partition_dict_0 = {i: [mesh_dim_0]}
|
||||
dim_partition_list.append(dim_partition_dict_0)
|
||||
|
||||
return dim_partition_list
|
||||
|
||||
|
||||
def generate_sharding_size(dim_partition_dict, device_mesh):
|
||||
total_sharding_size = 1
|
||||
for mesh_dim_list in dim_partition_dict.values():
|
||||
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
|
||||
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
|
||||
total_sharding_size *= sharding_size
|
||||
|
||||
return total_sharding_size
|
||||
|
|
|
@ -3,7 +3,8 @@ import operator
|
|||
|
||||
__all__ = [
|
||||
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
|
||||
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP'
|
||||
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
|
||||
'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP'
|
||||
]
|
||||
|
||||
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
|
||||
|
@ -11,7 +12,18 @@ ELEMENTWISE_FUNC_OP = [
|
|||
torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
|
||||
torch.nn.functional.dropout, torch.flatten
|
||||
]
|
||||
RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
|
||||
ELEMENTWISE_METHOD_OP = [
|
||||
torch.Tensor.to,
|
||||
torch.Tensor.type,
|
||||
]
|
||||
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
|
||||
RESHAPE_METHOD_OP = [
|
||||
torch.Tensor.view,
|
||||
torch.Tensor.unsqueeze,
|
||||
torch.Tensor.split,
|
||||
torch.Tensor.permute,
|
||||
torch.Tensor.transpose,
|
||||
]
|
||||
BCAST_FUNC_OP = [
|
||||
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
|
||||
operator.mul, operator.floordiv, operator.truediv, torch.matmul
|
||||
|
@ -23,9 +35,11 @@ CONV_MODULE_OP = [
|
|||
CONV_FUNC_OP = [
|
||||
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
|
||||
]
|
||||
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
|
||||
LINEAR_MODULE_OP = [torch.nn.Linear]
|
||||
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
|
||||
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
|
||||
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
|
||||
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
|
||||
NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
|||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
from colossalai.auto_parallel.solver._utils import exception_handler
|
||||
from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
|
||||
|
||||
__all__ = ['BcastOpHandler']
|
||||
|
||||
|
@ -110,45 +110,19 @@ class BcastOpHandler(OperatorHandler):
|
|||
|
||||
return sharding_spec_list
|
||||
|
||||
def _enumerate_all_possible_2d_sharding(self, mesh_dim_0, mesh_dim_1, dim_size):
|
||||
dim_partition_list = []
|
||||
# enumerate all the 2D sharding cases
|
||||
for i in range(dim_size):
|
||||
for j in range(i + 1, dim_size):
|
||||
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
|
||||
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
|
||||
dim_partition_list.append(dim_partition_dict_0)
|
||||
dim_partition_list.append(dim_partition_dict_1)
|
||||
for i in range(dim_size):
|
||||
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
|
||||
dim_partition_list.append(dim_partition_dict_flatten)
|
||||
|
||||
# sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list)
|
||||
return dim_partition_list
|
||||
|
||||
def _enumerate_all_possible_1d_sharding(self, mesh_dim_0, dim_size):
|
||||
dim_partition_list = []
|
||||
# enumerate all the 1D sharding cases
|
||||
for i in range(dim_size):
|
||||
dim_partition_dict_0 = {i: [mesh_dim_0]}
|
||||
dim_partition_list.append(dim_partition_dict_0)
|
||||
|
||||
# sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list)
|
||||
return dim_partition_list
|
||||
|
||||
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
|
||||
# use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity.
|
||||
|
||||
output_dim_partition_list = []
|
||||
dim_size = self.output_data.dim()
|
||||
# enumerate all the 2D sharding cases
|
||||
sharding_list_2d = self._enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
|
||||
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
|
||||
output_dim_partition_list.extend(sharding_list_2d)
|
||||
|
||||
# enumerate all the 1D sharding cases
|
||||
sharding_list_1d_on_dim_0 = self._enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
|
||||
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
|
||||
output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
|
||||
sharding_list_1d_on_dim_1 = self._enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
|
||||
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
|
||||
output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
|
||||
|
||||
# add empty dict for fully replicated case
|
||||
|
@ -545,15 +519,13 @@ class BcastOpHandler(OperatorHandler):
|
|||
dim_size = self.output_data.dim() - 2
|
||||
|
||||
# Both device mesh axises are uesd on batch dimensions
|
||||
dim_partition_dicts_2d = self._enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1],
|
||||
dim_size)
|
||||
dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
|
||||
for dim_partition_dict in dim_partition_dicts_2d:
|
||||
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
|
||||
|
||||
# Only one device mesh axis is uesd on batch dimensions
|
||||
for mesh_dim_index in [0, 1]:
|
||||
dim_partition_dicts_1d = self._enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index],
|
||||
dim_size)
|
||||
dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
|
||||
for dim_partition_dict in dim_partition_dicts_1d:
|
||||
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
|
||||
self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
|
||||
|
|
|
@ -0,0 +1,233 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size
|
||||
|
||||
__all__ = ['LayerNormHandler']
|
||||
|
||||
|
||||
class LayerNormHandler(OperatorHandler):
|
||||
"""
|
||||
A OperatorHandler which deals with the sharding strategies of normalization.
|
||||
|
||||
Note: To keep the math consistency, LayerNorm do not allow shards on hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_data = self.predecessor_node[0]._meta_data
|
||||
self.weight = self.module_named_parameters['weight']
|
||||
self.bias = self.module_named_parameters['bias']
|
||||
self.output_data = self.node._meta_data
|
||||
|
||||
def _generate_compute_cost(self, total_sharding_size):
|
||||
'''
|
||||
Compute the computation cost per device with this specific strategy.
|
||||
|
||||
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
|
||||
|
||||
Argument:
|
||||
bs(int): Batch size of the input data.
|
||||
channel_in(int): The channel dimension of input data.
|
||||
|
||||
Return:
|
||||
compute_cost(float): Computation cost per device with this specific strategy
|
||||
'''
|
||||
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
|
||||
# TODO: a constant coefficient need to be added.
|
||||
|
||||
norm_kernel_size = self.weight.shape
|
||||
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
|
||||
input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
|
||||
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
|
||||
norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
|
||||
forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
# To compute gradient of on norm kernel element requires input_batch_product times computation, so
|
||||
# the total cost is input_batch_product * norm_kernel_product
|
||||
backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
|
||||
compute_cost = forward_compute_cost + backward_compute_cost
|
||||
return compute_cost
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
Argument:
|
||||
sharding_size_forward(int): The forward activation will be divided
|
||||
into sharding_size_forward number partions.
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
be divided into sharding_size_backward_activation number partions.
|
||||
sharding_size_weight(int): The backward weight will be divided
|
||||
into sharding_size_weight number partions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel_output = self.output_data.numel()
|
||||
# this operation will not change the shape of input
|
||||
numel_input = numel_output
|
||||
numel_weight = self.weight.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
|
||||
|
||||
def _generate_strategy_with_dim_partition(self, dim_partition):
|
||||
dim_partition_dict_for_input = dim_partition
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = dim_partition
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
|
||||
# compute the computation cost of this strategy
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
|
||||
sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
|
||||
sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
|
||||
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
total_mesh_dim_list = []
|
||||
for mesh_dim_list in dim_partition.values():
|
||||
total_mesh_dim_list.extend(mesh_dim_list)
|
||||
|
||||
# This strategy do not need to do all_reduce operation for activation
|
||||
communication_cost_forward_activation = 0
|
||||
communication_cost_backward_activation = 0
|
||||
if len(total_mesh_dim_list) == 1:
|
||||
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight,
|
||||
total_mesh_dim_list[0])
|
||||
else:
|
||||
assert len(total_mesh_dim_list) == 2, f'temporally we just support 2d device mesh.'
|
||||
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
memory_cost_backward_weight, 0)
|
||||
communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
def split_input_batch_single_mesh_dim(self, mesh_dim_0):
|
||||
batch_dimension_length = self.input_data.dim() - self.weight.dim()
|
||||
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
|
||||
for dim_partition in dim_partition_list:
|
||||
self._generate_strategy_with_dim_partition(dim_partition)
|
||||
|
||||
@exception_handler
|
||||
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
batch_dimension_length = self.input_data.dim() - self.weight.dim()
|
||||
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
|
||||
for dim_partition in dim_partition_list:
|
||||
self._generate_strategy_with_dim_partition(dim_partition)
|
||||
|
||||
@exception_handler
|
||||
def non_split(self):
|
||||
name = f'RR = RR x R'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
total_sharding_size = 1
|
||||
# compute the computation cost of this strategy
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy do not need to do all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
|
||||
|
||||
Example:
|
||||
norm_handler = BatchNormHandler(node, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
norm_handler.register_strategy()
|
||||
for strategy in norm_handler.strategies_vector:
|
||||
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
|
||||
|
||||
Output:
|
||||
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
|
||||
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
|
||||
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
|
||||
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
|
||||
'''
|
||||
|
||||
# SR = SR x R with single mesh dim on batch dimensions
|
||||
self.split_input_batch_single_mesh_dim(0)
|
||||
self.split_input_batch_single_mesh_dim(1)
|
||||
|
||||
# SR = SR x R with both mesh dims on batch dimensions
|
||||
self.split_input_batch_both_mesh_dim(0, 1)
|
||||
|
||||
# RR = RR x R
|
||||
self.non_split()
|
||||
|
||||
return self.strategies_vector
|
|
@ -1,5 +1,6 @@
|
|||
from torch.fx import Graph, Node
|
||||
from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
|
||||
from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
|
@ -216,6 +217,15 @@ class StrategiesConstructor:
|
|||
input_shardings=[input_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# embedding module
|
||||
elif submod_type in EMBEDDING_MODULE_OP:
|
||||
embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
|
||||
embedding_handler.register_strategy()
|
||||
|
||||
# layernorm module
|
||||
elif submod_type in LAYERNORM_MODULE_OP:
|
||||
layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
|
||||
layernorm_handler.register_strategy()
|
||||
# other module
|
||||
else:
|
||||
raise RuntimeError(f'{submod_type} module is NOT supported now.')
|
||||
|
@ -349,35 +359,72 @@ class StrategiesConstructor:
|
|||
elif target == operator.getitem:
|
||||
index = node.args[1]
|
||||
input_tensor_node = strategies_vector.predecessor_nodes[0]
|
||||
for strategy in input_tensor_node.strategies_vector:
|
||||
input_sharding_spec = input_tensor_node.output_sharding_spec[index]
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
|
||||
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_tensor_node] = [
|
||||
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_tensor_node.output_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
if isinstance(input_tensor_node, torch.Tensor):
|
||||
for strategy in input_tensor_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec[index]
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
|
||||
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_tensor_node] = [
|
||||
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(
|
||||
name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_tensor_node.output_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# torch.arange function
|
||||
elif target == torch.arange:
|
||||
name = f'FULLY REPLICATED ARANGE'
|
||||
entire_shape_output = node._meta_data.shape
|
||||
dim_partition_dict_for_output = {}
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
memory_cost = node._meta_data.numel()
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=0,
|
||||
memory_cost=memory_cost)
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# op list to be processed to support gpt2
|
||||
elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
|
||||
torch.nn.functional.softmax, torch.pow, torch.tanh):
|
||||
pass
|
||||
# other function
|
||||
else:
|
||||
raise RuntimeError(f'{target} function is NOT supported now.')
|
||||
|
||||
# call_method node
|
||||
if node.op == 'call_method':
|
||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||
if method in (torch.Tensor.size, torch.Tensor.contiguous):
|
||||
pass
|
||||
elif method in ELEMENTWISE_METHOD_OP:
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
unary_elementwise_handler.register_strategy()
|
||||
|
||||
elif method in RESHAPE_METHOD_OP:
|
||||
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
|
||||
reshape_handler.register_strategy()
|
||||
|
||||
else:
|
||||
raise RuntimeError(f'{method} function is NOT supported now.')
|
||||
|
||||
# output node
|
||||
if node.op == 'output':
|
||||
if self.solver_options.fast:
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
from colossalai.auto_parallel.solver import sharding_strategy
|
||||
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
class LNModel(nn.Module):
|
||||
|
||||
def __init__(self, c):
|
||||
super().__init__()
|
||||
self.ln = nn.LayerNorm(c)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.ln(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_bn_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 4, 128))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = LNModel(128)
|
||||
input_sample = {'x': torch.rand(4, 4, 128).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {})
|
||||
# return ln
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
# [x, mul, ln, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {})
|
||||
sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input)
|
||||
strategies_vector_for_input = StrategiesVector(nodes[1])
|
||||
strategies_vector_for_input.append(sharding_strategy_for_input)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
# generate bn strategy
|
||||
strategies_vector = StrategiesVector(node=nodes[2])
|
||||
ln_handler = LayerNormHandler(
|
||||
node=nodes[2],
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=strategies_vector,
|
||||
)
|
||||
ln_handler.register_strategy()
|
||||
# ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]',
|
||||
# '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R']
|
||||
strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector]
|
||||
|
||||
assert len(strategy_name_list) == 9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_bn_handler()
|
Loading…
Reference in New Issue