mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add layernorm handler (#1629)
parent bf77d3ab65
commit 0c703189b9
@@ -94,7 +94,43 @@ def exception_handler(func):
     def wrapper(*args, **kwargs):
         try:
             func(*args, **kwargs)
-        except Exception as e:
+        except AssertionError as e:
            warnings.warn(f'{e}')

     return wrapper


+def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
+    dim_partition_list = []
+    # enumerate all the 2D sharding cases
+    for i in range(dim_size):
+        for j in range(i + 1, dim_size):
+            dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
+            dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
+            dim_partition_list.append(dim_partition_dict_0)
+            dim_partition_list.append(dim_partition_dict_1)
+    for i in range(dim_size):
+        dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
+        dim_partition_list.append(dim_partition_dict_flatten)
+
+    return dim_partition_list
+
+
+def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
+    dim_partition_list = []
+    # enumerate all the 1D sharding cases
+    for i in range(dim_size):
+        dim_partition_dict_0 = {i: [mesh_dim_0]}
+        dim_partition_list.append(dim_partition_dict_0)
+
+    return dim_partition_list
+
+
+def generate_sharding_size(dim_partition_dict, device_mesh):
+    total_sharding_size = 1
+    for mesh_dim_list in dim_partition_dict.values():
+        mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
+        sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
+        total_sharding_size *= sharding_size
+
+    return total_sharding_size
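
For orientation, a small usage sketch of the helpers added above; the import path is the one used by the bcast handler change further down, and the printed lists are exactly what the loops produce.

# Usage sketch of the new module-level helpers.
from colossalai.auto_parallel.solver._utils import (enumerate_all_possible_1d_sharding,
                                                    enumerate_all_possible_2d_sharding)

# For a 2-dimensional tensor on mesh axes 0 and 1:
print(enumerate_all_possible_2d_sharding(0, 1, 2))
# [{0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]}]
print(enumerate_all_possible_1d_sharding(0, 2))
# [{0: [0]}, {1: [0]}]
# generate_sharding_size({0: [0], 1: [1]}, device_mesh) on a (2, 2) mesh would return 2 * 2 = 4.
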
@@ -3,7 +3,8 @@ import operator

 __all__ = [
     'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
-    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP'
+    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
+    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP'
 ]

 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]

@@ -11,7 +12,18 @@ ELEMENTWISE_FUNC_OP = [
     torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
     torch.nn.functional.dropout, torch.flatten
 ]
-RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
+ELEMENTWISE_METHOD_OP = [
+    torch.Tensor.to,
+    torch.Tensor.type,
+]
+RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
+RESHAPE_METHOD_OP = [
+    torch.Tensor.view,
+    torch.Tensor.unsqueeze,
+    torch.Tensor.split,
+    torch.Tensor.permute,
+    torch.Tensor.transpose,
+]
 BCAST_FUNC_OP = [
     torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
     operator.mul, operator.floordiv, operator.truediv, torch.matmul

@@ -23,9 +35,11 @@ CONV_MODULE_OP = [
 CONV_FUNC_OP = [
     torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
 ]
+EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
 LINEAR_MODULE_OP = [torch.nn.Linear]
 LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
 BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
+LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
 POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
 NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
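
These tables are only membership-tested against the traced target, as the strategies constructor changes later in this diff show; a minimal standalone sketch (values illustrative, not from the commit):

import torch

LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]

# a call_module node is routed to its handler by the type of its submodule
assert type(torch.nn.LayerNorm(128)) in LAYERNORM_MODULE_OP
assert type(torch.nn.Embedding(10, 8)) in EMBEDDING_MODULE_OP
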
@@ -8,7 +8,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 from copy import deepcopy
 from typing import Dict, List
-from colossalai.auto_parallel.solver._utils import exception_handler
+from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding

 __all__ = ['BcastOpHandler']

@@ -110,45 +110,19 @@ class BcastOpHandler(OperatorHandler):

         return sharding_spec_list

-    def _enumerate_all_possible_2d_sharding(self, mesh_dim_0, mesh_dim_1, dim_size):
-        dim_partition_list = []
-        # enumerate all the 2D sharding cases
-        for i in range(dim_size):
-            for j in range(i + 1, dim_size):
-                dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
-                dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
-                dim_partition_list.append(dim_partition_dict_0)
-                dim_partition_list.append(dim_partition_dict_1)
-        for i in range(dim_size):
-            dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
-            dim_partition_list.append(dim_partition_dict_flatten)
-
-        # sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list)
-        return dim_partition_list
-
-    def _enumerate_all_possible_1d_sharding(self, mesh_dim_0, dim_size):
-        dim_partition_list = []
-        # enumerate all the 1D sharding cases
-        for i in range(dim_size):
-            dim_partition_dict_0 = {i: [mesh_dim_0]}
-            dim_partition_list.append(dim_partition_dict_0)
-
-        # sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list)
-        return dim_partition_list
-
     def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
         # use mesh_dim_0, mesh_dim_1 instead of the constants 0, 1 here for N-D device mesh scalability.

         output_dim_partition_list = []
         dim_size = self.output_data.dim()
         # enumerate all the 2D sharding cases
-        sharding_list_2d = self._enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
+        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
         output_dim_partition_list.extend(sharding_list_2d)

         # enumerate all the 1D sharding cases
-        sharding_list_1d_on_dim_0 = self._enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
+        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
         output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
-        sharding_list_1d_on_dim_1 = self._enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
+        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
         output_dim_partition_list.extend(sharding_list_1d_on_dim_1)

         # add empty dict for fully replicated case

@@ -545,15 +519,13 @@ class BcastOpHandler(OperatorHandler):
         dim_size = self.output_data.dim() - 2

         # Both device mesh axes are used on batch dimensions
-        dim_partition_dicts_2d = self._enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1],
-                                                                          dim_size)
+        dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
         for dim_partition_dict in dim_partition_dicts_2d:
             self._registry_no_split_strategies_for_matmul(dim_partition_dict)

         # Only one device mesh axis is used on batch dimensions
         for mesh_dim_index in [0, 1]:
-            dim_partition_dicts_1d = self._enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index],
-                                                                              dim_size)
+            dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
             for dim_partition_dict in dim_partition_dicts_1d:
                 self._registry_no_split_strategies_for_matmul(dim_partition_dict)
                 self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
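
As a quick sanity check of the enumeration these helpers drive (a sketch; the count follows directly from the loops in enumerate_all_possible_2d_sharding and enumerate_all_possible_1d_sharding):

# For a dim_size-D output, _enumerate_all_possible_output collects:
#   2D cases: dim_size * (dim_size - 1) ordered pairs + dim_size flattened [mesh_dim_0, mesh_dim_1] splits
#   1D cases: dim_size per mesh axis (two axes)
#   plus one empty dict for the fully replicated case
def expected_output_candidates(dim_size: int) -> int:
    return dim_size * (dim_size - 1) + dim_size + 2 * dim_size + 1    # == (dim_size + 1) ** 2

assert expected_output_candidates(3) == 16
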
@@ -0,0 +1,233 @@
+import operator
+from functools import reduce
+import torch
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler
+from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size
+
+__all__ = ['LayerNormHandler']
+
+
+class LayerNormHandler(OperatorHandler):
+    """
+    An OperatorHandler which deals with the sharding strategies of normalization.
+
+    Note: To keep the math consistent, LayerNorm does not allow sharding on the hidden dimension.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.bias = self.module_named_parameters['bias']
+        self.output_data = self.node._meta_data
+
+    def _generate_compute_cost(self, total_sharding_size):
+        '''
+        Compute the computation cost per device with this specific strategy.
+
+        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
+
+        Argument:
+            total_sharding_size(int): The total number of shards of the input, i.e. the product of the
+                sizes of the used device mesh dimensions.
+
+        Return:
+            compute_cost(float): Computation cost per device with this specific strategy
+        '''
+        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
+        # TODO: a constant coefficient needs to be added.
+
+        norm_kernel_size = self.weight.shape
+        # in the LayerNorm context, batch dimensions mean all the dimensions that do not join the normalization.
+        input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
+        input_batch_product = reduce(operator.mul, input_batch_shape, 1)
+        norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
+        forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
+        backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
+        # Computing the gradient of one norm kernel element requires input_batch_product computations, so
+        # the total cost is input_batch_product * norm_kernel_product
+        backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
+        backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
+        compute_cost = forward_compute_cost + backward_compute_cost
+        return compute_cost
+
+    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
+        '''
+        Compute the memory cost per device with this specific strategy.
+
+        Argument:
+            sharding_size_forward(int): The forward activation will be divided
+                into sharding_size_forward partitions.
+            sharding_size_backward_activation(int): The backward activation will
+                be divided into sharding_size_backward_activation partitions.
+            sharding_size_weight(int): The backward weight will be divided
+                into sharding_size_weight partitions.
+
+        Return:
+            memory_cost(Tuple[float]): Memory cost per device with this
+                specific strategy; the first element of this tuple is the forward
+                memory cost, and the second element is the backward memory cost.
+            memory_cost_forward(float): Memory cost of the forward activation per
+                device with this specific strategy.
+            memory_cost_backward_activation(float): Memory cost of the backward activation
+                per device with this specific strategy.
+        '''
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel_output = self.output_data.numel()
+        # this operation will not change the shape of input
+        numel_input = numel_output
+        numel_weight = self.weight.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+        # forward memory_cost
+        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
+        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
+        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
+
+        # backward memory_cost
+        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
+        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
+        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
+
+        # memory_cost pair
+        memory_cost = (memory_cost_forward, memory_cost_backward)
+
+        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
+
+    def _generate_strategy_with_dim_partition(self, dim_partition):
+        dim_partition_dict_for_input = dim_partition
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = dim_partition
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+
+        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
+        # generate resharding cost for this strategy
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+
+        total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
+        # compute the computation cost of this strategy
+        compute_cost = self._generate_compute_cost(total_sharding_size)
+
+        # compute the memory cost of this strategy
+        sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
+        sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
+        sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
+        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
+                                                                                    sharding_size_backward_activation,
+                                                                                    sharding_size_weight)
+
+        total_mesh_dim_list = []
+        for mesh_dim_list in dim_partition.values():
+            total_mesh_dim_list.extend(mesh_dim_list)
+
+        # This strategy does not need an all_reduce operation for the activation
+        communication_cost_forward_activation = 0
+        communication_cost_backward_activation = 0
+        if len(total_mesh_dim_list) == 1:
+            communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight,
+                                                                                  total_mesh_dim_list[0])
+        else:
+            assert len(total_mesh_dim_list) == 2, f'temporarily we just support a 2d device mesh.'
+            communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+                memory_cost_backward_weight, 0)
+        communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight
+
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+
+        self.strategies_vector.append(sharding_strategies)
+
+    @exception_handler
+    def split_input_batch_single_mesh_dim(self, mesh_dim_0):
+        batch_dimension_length = self.input_data.dim() - self.weight.dim()
+        dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
+        for dim_partition in dim_partition_list:
+            self._generate_strategy_with_dim_partition(dim_partition)
+
+    @exception_handler
+    def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
+        batch_dimension_length = self.input_data.dim() - self.weight.dim()
+        dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
+        for dim_partition in dim_partition_list:
+            self._generate_strategy_with_dim_partition(dim_partition)
+
+    @exception_handler
+    def non_split(self):
+        name = f'RR = RR x R'
+
+        dim_partition_dict_for_input = {}
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {}
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+
+        # generate resharding cost for this strategy
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+
+        total_sharding_size = 1
+        # compute the computation cost of this strategy
+        compute_cost = self._generate_compute_cost(total_sharding_size)
+
+        # compute the memory cost of this strategy
+        sharding_size_forward = 1
+        sharding_size_backward_activation = 1
+        sharding_size_weight = 1
+        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                          sharding_size_weight)
+
+        # This strategy does not need an all_reduce operation
+        communication_cost = 0
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+
+        self.strategies_vector.append(sharding_strategies)
+
+    def register_strategy(self) -> StrategiesVector:
+        '''
+        Generate every possible strategy for a LayerNorm node, and record all strategies into the strategies_vector.
+
+        Example:
+            norm_handler = LayerNormHandler(node, strategies_vector,
+                                            self.shape_consistency_manager)
+            norm_handler.register_strategy()
+            for strategy in norm_handler.strategies_vector:
+                print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
+
+        Output:
+            RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
+            RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
+            RR = RR x R, computation_cost: 262144, memory_cost: 1048576
+            RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
+        '''
+
+        # SR = SR x R with single mesh dim on batch dimensions
+        self.split_input_batch_single_mesh_dim(0)
+        self.split_input_batch_single_mesh_dim(1)
+
+        # SR = SR x R with both mesh dims on batch dimensions
+        self.split_input_batch_both_mesh_dim(0, 1)
+
+        # RR = RR x R
+        self.non_split()
+
+        return self.strategies_vector
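
To make the cost formulas concrete, a worked example (a sketch only; it assumes the test scenario further below: input of shape (4, 4, 128), nn.LayerNorm(128), float32 data, and a 1D split of batch dim 0 over a mesh axis of size 2):

input_batch_product = 4 * 4        # dimensions that do not join the normalization
norm_kernel_product = 128
total_sharding_size = 2            # dim_partition {0: [mesh_dim_0]} on a mesh axis of size 2

forward_compute = input_batch_product * norm_kernel_product / total_sharding_size         # 1024.0
backward_compute = 2 * forward_compute                                                    # activation + weight grads, 2048.0
compute_cost = forward_compute + backward_compute                                         # 3072.0

bytes_per_elem = 4                 # float32
numel_output = 4 * 4 * 128
numel_weight = 128
memory_cost_forward = numel_output * bytes_per_elem / 2 + numel_weight * bytes_per_elem   # 4096.0 + 512 = 4608.0
memory_cost_backward = numel_output * bytes_per_elem / 2 + numel_weight * bytes_per_elem  # 4608.0
# only the weight gradient (numel_weight * bytes_per_elem = 512 bytes here) is all-reduced,
# over the single mesh axis used by the split, via device_mesh.all_reduce_cost(...)
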
@@ -1,5 +1,6 @@
 from torch.fx import Graph, Node
 from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
+from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager

@@ -216,6 +217,15 @@ class StrategiesConstructor:
                                                          input_shardings=[input_sharding_spec])
                     strategies_vector.append(sharding_strategy)

+                # embedding module
+                elif submod_type in EMBEDDING_MODULE_OP:
+                    embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
+                    embedding_handler.register_strategy()
+
+                # layernorm module
+                elif submod_type in LAYERNORM_MODULE_OP:
+                    layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
+                    layernorm_handler.register_strategy()
+
                 # other module
                 else:
                     raise RuntimeError(f'{submod_type} module is NOT supported now.')

@@ -349,8 +359,9 @@ class StrategiesConstructor:
             elif target == operator.getitem:
                 index = node.args[1]
                 input_tensor_node = strategies_vector.predecessor_nodes[0]
+                if isinstance(input_tensor_node, torch.Tensor):
                 for strategy in input_tensor_node.strategies_vector:
-                    input_sharding_spec = input_tensor_node.output_sharding_spec[index]
+                    input_sharding_spec = strategy.output_sharding_spec[index]
                     assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
                     dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
                     entire_shape_output = deepcopy(input_sharding_spec.entire_shape)

@@ -366,7 +377,8 @@ class StrategiesConstructor:
                     resharding_costs[input_tensor_node] = [
                         cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
                     ]
-                    sharding_strategy = ShardingStrategy(name,
+                    sharding_strategy = ShardingStrategy(
+                        name,
                                                          output_sharding_spec,
                                                          compute_cost=compute_cost,
                                                          memory_cost=memory_cost,

@@ -374,10 +386,45 @@ class StrategiesConstructor:
                                                          input_shardings=[input_tensor_node.output_sharding_spec])
                     strategies_vector.append(sharding_strategy)

+                # torch.arange function
+                elif target == torch.arange:
+                    name = f'FULLY REPLICATED ARANGE'
+                    entire_shape_output = node._meta_data.shape
+                    dim_partition_dict_for_output = {}
+                    output_sharding_spec = ShardingSpec(self.device_mesh,
+                                                        entire_shape_output,
+                                                        dim_partition_dict=dim_partition_dict_for_output)
+                    memory_cost = node._meta_data.numel()
+                    sharding_strategy = ShardingStrategy(name,
+                                                         output_sharding_spec,
+                                                         compute_cost=0,
+                                                         memory_cost=memory_cost)
+                    strategies_vector.append(sharding_strategy)
+
+                # op list to be processed to support gpt2
+                elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
+                                torch.nn.functional.softmax, torch.pow, torch.tanh):
+                    pass
+
                 # other function
                 else:
                     raise RuntimeError(f'{target} function is NOT supported now.')

+            # call_method node
+            if node.op == 'call_method':
+                method = getattr(node.args[0]._meta_data.__class__, node.target)
+                if method in (torch.Tensor.size, torch.Tensor.contiguous):
+                    pass
+                elif method in ELEMENTWISE_METHOD_OP:
+                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
+                    unary_elementwise_handler.register_strategy()
+
+                elif method in RESHAPE_METHOD_OP:
+                    reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
+                    reshape_handler.register_strategy()
+
+                else:
+                    raise RuntimeError(f'{method} function is NOT supported now.')
+
             # output node
             if node.op == 'output':
                 if self.solver_options.fast:
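
A minimal sketch of how the new call_method branch resolves the bound method before dispatch (standalone, outside torch.fx; the tensor and method name are stand-ins for node.args[0]._meta_data and node.target):

import torch

RESHAPE_METHOD_OP = [torch.Tensor.view, torch.Tensor.unsqueeze, torch.Tensor.split,
                     torch.Tensor.permute, torch.Tensor.transpose]

meta_tensor = torch.empty(4, 4, 128, device='meta')     # stand-in for node.args[0]._meta_data
method = getattr(meta_tensor.__class__, 'permute')      # node.target is the method name string
assert method in RESHAPE_METHOD_OP                      # such a node is routed to ReshapeHandler
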
@@ -0,0 +1,70 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+from colossalai.auto_parallel.solver import sharding_strategy
+
+from colossalai.fx.proxy import ColoProxy
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class LNModel(nn.Module):
+
+    def __init__(self, c):
+        super().__init__()
+        self.ln = nn.LayerNorm(c)
+
+    def forward(self, x):
+        x = x * 2
+        x = self.ln(x)
+        return x
+
+
+def test_ln_handler():
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    entire_shape = torch.Size((4, 4, 128))
+
+    tracer = ColoTracer()
+    model = LNModel(128)
+    input_sample = {'x': torch.rand(4, 4, 128).to('meta')}
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
+    #     %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {})
+    #     return ln
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    # [x, mul, ln, output]
+    nodes = [node for node in gm.graph.nodes]
+    sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {})
+    sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input)
+    strategies_vector_for_input = StrategiesVector(nodes[1])
+    strategies_vector_for_input.append(sharding_strategy_for_input)
+    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
+
+    # generate layernorm strategies
+    strategies_vector = StrategiesVector(node=nodes[2])
+    ln_handler = LayerNormHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
+    ln_handler.register_strategy()
+    # ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]',
+    #  '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R']
+    strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector]
+
+    assert len(strategy_name_list) == 9
+
+
+if __name__ == '__main__':
+    test_ln_handler()
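
The expected count of 9 can be checked by hand from the enumeration helpers (a sketch of the arithmetic for this test case):

# batch_dimension_length = input.dim() - weight.dim() = 3 - 1 = 2
d = 2
single_mesh_dim = 2 * d              # split_input_batch_single_mesh_dim(0) and (1): d strategies each
both_mesh_dim = d * (d - 1) + d      # ordered pairs + flattened [S01] splits
non_split = 1                        # 'RR = RR x R'
assert single_mesh_dim + both_mesh_dim + non_split == 9
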