mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add embedding handler (#1620)
parent
69448f64c4
commit
3a46215135
|
@ -0,0 +1,176 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import warnings
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
from colossalai.auto_parallel.solver._utils import exception_handler
|
||||
|
||||
__all__ = ['EmbeddingHandler']
|
||||
|
||||
|
||||
class EmbeddingHandler(OperatorHandler):
|
||||
"""
|
||||
An OperatorHandler which deals with the sharding strategies of Embedding operators(such as nn.embedding).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_data = self.predecessor_node[0]._meta_data
|
||||
self.weight = self.module_named_parameters['weight']
|
||||
self.output_data = self.node._meta_data
|
||||
|
||||
def _generate_compute_cost(self, total_sharding_size):
|
||||
input_shape = self.input_data.shape
|
||||
weight_shape = self.weight.shape
|
||||
input_shape_product = reduce(operator.mul, input_shape, 1)
|
||||
weight_shape_product = reduce(operator.mul, weight_shape, 1)
|
||||
compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
|
||||
return compute_cost
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
Argument:
|
||||
sharding_size_forward(int): The forward activation will be divided
|
||||
into sharding_size_forward number partions.
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
be divided into sharding_size_backward_activation number partions.
|
||||
sharding_size_weight(int): The backward weight will be divided
|
||||
into sharding_size_weight number partions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel_output = self.output_data.numel()
|
||||
numel_input = self.input_data.numel()
|
||||
numel_weight = self.weight.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
|
||||
|
||||
@exception_handler
|
||||
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {2: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = 1
|
||||
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# This strategy do not need to do all_reduce during forward phase
|
||||
communication_cost_forward = 0
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward_activation = 0
|
||||
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
memory_cost_backward_weight, 0)
|
||||
communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
|
||||
'''
|
||||
# RRS = RR x SS
|
||||
self.split_weight_both_dim(0, 1)
|
||||
self.split_weight_both_dim(1, 0)
|
||||
|
||||
# SSR = SS x RR
|
||||
self.split_input_both_dim(0, 1)
|
||||
self.split_input_both_dim(1, 0)
|
||||
|
||||
return self.strategies_vector
|
Loading…
Reference in New Issue