mirror of https://github.com/hpcaitech/ColossalAI
645 lines
38 KiB
Python
645 lines
38 KiB
Python
import operator
|
|
from functools import reduce
|
|
import warnings
|
|
import torch
|
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
|
from .operator_handler import OperatorHandler
|
|
from .._utils import generate_sharding_spec
|
|
|
|
__all__ = ['ConvHandler']
|
|
|
|
|
|
class ConvHandler(OperatorHandler):
|
|
"""
|
|
An OperatorHandler which deals with the sharding strategies of Convolution.
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
    '''
    Build the handler and cache the tensors read by the cost model.

    Every positional and keyword argument is forwarded unchanged to the
    parent class initializer.
    '''
    super().__init__(*args, **kwargs)

    # the meta data of the first predecessor node serves as the conv input
    first_predecessor = self.predecessor_node[0]
    self.input_data = first_predecessor._meta_data

    # the parameter registered under the name 'weight'
    named_parameters = self.module_named_parameters
    self.weight = named_parameters['weight']

    # the meta data of the handled node serves as the conv output
    handled_node = self.node
    self.output_data = handled_node._meta_data

    # validate the input only after it has been cached above
    self._sanity_check()
|
|
|
|
def _sanity_check(self):
|
|
'''
|
|
In sanity check, we need make sure the input data having correct dimension size.
|
|
For Conv1d, the dim of input data should be 3([N, C, L]).
|
|
For Conv2d, the dim of input data should be 4([N, C, H, W]).
|
|
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
|
|
'''
|
|
assert self.input_data.dim() in (3, 4,
|
|
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
|
|
|
|
def _generate_compute_cost(self, bs, channel_in, channel_out):
|
|
'''
|
|
Compute the computation cost per device with this specific strategy.
|
|
|
|
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
|
|
|
|
Argument:
|
|
bs(int): Batch size of the input data.
|
|
channel_in(int): The channel dimension of input data.
|
|
channel_out(int): The out channel of the conv weight.
|
|
|
|
Return:
|
|
compute_cost(float): Computation cost per device with this specific strategy
|
|
'''
|
|
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
|
|
# 1D: (L) * N * Cout * Cin * kernel
|
|
# 2D: (H * W) * N * Cout * Cin * kernel
|
|
# 3D: (H * W * D) * N * Cout * Cin * kernel
|
|
output_size = self.output_data.shape[2:]
|
|
output_size_product = reduce(operator.mul, output_size, 1)
|
|
input_size = self.input_data.shape[2:]
|
|
input_size_product = reduce(operator.mul, input_size, 1)
|
|
kernel_size = self.weight.shape[2:]
|
|
kernel_size_product = reduce(operator.mul, kernel_size, 1)
|
|
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
|
|
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
|
|
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
|
|
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
|
|
return compute_cost
|
|
|
|
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
|
'''
|
|
Compute the memory cost per device with this specific strategy.
|
|
|
|
Argument:
|
|
sharding_size_forward(int): The forward activation will be divided
|
|
into sharding_size_forward number partions.
|
|
sharding_size_backward_activation(int): The backward activation will
|
|
be divided into sharding_size_backward_activation number partions.
|
|
sharding_size_weight(int): The backward weight will be divided
|
|
into sharding_size_weight number partions.
|
|
|
|
Return:
|
|
memory_cost(Tuple[float]): Memory cost per device with this
|
|
specific strategy, the first element of this tuple is forward
|
|
memory cost, and the second element of this tuple is backward
|
|
memory cost.
|
|
memory_cost_forward(float): Memory cost of forward activation per
|
|
device with this specific strategy.
|
|
memory_cost_backward_activation(float): Memory cost of backward activation
|
|
per device with this specific strategy.
|
|
'''
|
|
# compute the memory cost of this strategy
|
|
dtype = self.input_data.dtype
|
|
numel_output = self.output_data.numel()
|
|
numel_input = self.input_data.numel()
|
|
numel_weight = self.weight.numel()
|
|
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
|
|
|
# forward memory_cost
|
|
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
|
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
|
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
|
|
|
# backward memory_cost
|
|
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
|
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
|
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
|
|
|
# memory_cost pair
|
|
memory_cost = (memory_cost_forward, memory_cost_backward)
|
|
|
|
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
|
|
|
|
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
    '''
    Register the strategy 'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'.

    Dim 0 of the input is sharded along mesh_dim_0 and dim 1 of the weight
    along mesh_dim_1, so the output is sharded on dim 0 (mesh_dim_0) and on
    dim 1 (mesh_dim_1). The resulting ShardingStrategy is appended to
    self.strategies_vector.

    Argument:
        mesh_dim_0(int): Device mesh axis which shards dim 0 of the input.
        mesh_dim_1(int): Device mesh axis which shards dim 1 of the weight.
    '''
    mesh = self.device_mesh
    strategy_name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

    # sharding specs of the two operands and of the result
    input_spec = generate_sharding_spec(self.input_data, mesh, {0: [mesh_dim_0]})
    weight_spec = generate_sharding_spec(self.weight, mesh, {1: [mesh_dim_1]})
    output_spec = generate_sharding_spec(self.output_data, mesh, {0: [mesh_dim_0], 1: [mesh_dim_1]})

    # resharding cost of this strategy
    resharding_costs = self._generate_resharding_costs([input_spec, weight_spec])

    # computation cost: each device gets a slice of the batch and a slice of weight dim 1
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    batch_per_device = self.input_data.shape[0] // mesh.shape[mesh_dim_0]
    in_channels = self.input_data.shape[1]
    out_channels_per_device = self.weight.shape[1] // mesh.shape[mesh_dim_1]
    compute_cost = self._generate_compute_cost(batch_per_device, in_channels, out_channels_per_device)

    # memory cost: partition counts of the forward activation, backward activation and weight
    size_0, size_1 = mesh.shape[mesh_dim_0], mesh.shape[mesh_dim_1]
    memory_cost, _, bwd_activation_mem, bwd_weight_mem = self._generate_memory_cost(
        size_0 * size_1, size_0, size_1)

    # communication cost: no all_reduce in forward; in backward the input activation grad is
    # all reduced along mesh_dim_1 and the weight grad along mesh_dim_0 (data parallel)
    forward_comm = 0
    bwd_activation_comm = mesh.all_reduce_cost(bwd_activation_mem, mesh_dim_1)
    bwd_weight_comm = mesh.all_reduce_cost(bwd_weight_mem, mesh_dim_0)
    communication_cost = forward_comm + bwd_activation_comm + bwd_weight_comm

    self.strategies_vector.append(
        ShardingStrategy(strategy_name,
                         output_sharding_spec=output_spec,
                         compute_cost=compute_cost,
                         communication_cost=communication_cost,
                         memory_cost=memory_cost,
                         resharding_costs=resharding_costs,
                         input_shardings=(input_spec, weight_spec)))
|
|
|
|
def split_input_batch(self, mesh_dim_0):
    '''
    Register the strategy 'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'.

    Only dim 0 of the input is sharded (along mesh_dim_0); the weight is not
    sharded and the output is sharded on dim 0. The resulting
    ShardingStrategy is appended to self.strategies_vector.

    Argument:
        mesh_dim_0(int): Device mesh axis which shards dim 0 of the input.
    '''
    mesh = self.device_mesh
    strategy_name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

    # sharding specs of the two operands and of the result
    input_spec = generate_sharding_spec(self.input_data, mesh, {0: [mesh_dim_0]})
    weight_spec = generate_sharding_spec(self.weight, mesh, {})
    output_spec = generate_sharding_spec(self.output_data, mesh, {0: [mesh_dim_0]})

    # resharding cost of this strategy
    resharding_costs = self._generate_resharding_costs([input_spec, weight_spec])

    # computation cost: only the batch is divided among the devices
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    batch_per_device = self.input_data.shape[0] // mesh.shape[mesh_dim_0]
    compute_cost = self._generate_compute_cost(batch_per_device, self.input_data.shape[1], self.weight.shape[1])

    # memory cost: both activations are split along mesh_dim_0, the weight is kept whole
    batch_partitions = mesh.shape[mesh_dim_0]
    memory_cost, _, _, bwd_weight_mem = self._generate_memory_cost(batch_partitions, batch_partitions, 1)

    # communication cost: no all_reduce in forward; in backward the weight grad is
    # all reduced along mesh_dim_0 because of data parallel
    forward_comm = 0
    bwd_weight_comm = mesh.all_reduce_cost(bwd_weight_mem, mesh_dim_0)
    communication_cost = forward_comm + bwd_weight_comm

    self.strategies_vector.append(
        ShardingStrategy(strategy_name,
                         output_sharding_spec=output_spec,
                         compute_cost=compute_cost,
                         communication_cost=communication_cost,
                         memory_cost=memory_cost,
                         resharding_costs=resharding_costs,
                         input_shardings=(input_spec, weight_spec)))
|
|
|
|
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
    '''
    Register the strategy 'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'.

    The input is sharded on dim 0 along mesh_dim_0 and on dim 1 along
    mesh_dim_1; dim 0 of the weight is sharded along mesh_dim_1 so that it
    lines up with dim 1 of the input. The output is sharded on dim 0 along
    mesh_dim_0. The resulting ShardingStrategy is appended to
    self.strategies_vector.

    Argument:
        mesh_dim_0(int): Device mesh axis which shards dim 0 of the input.
        mesh_dim_1(int): Device mesh axis which shards dim 1 of the input
            and dim 0 of the weight.
    '''
    name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

    dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
    sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
                                                     dim_partition_dict_for_input)

    # Dim 0 of the weight must follow mesh_dim_1, exactly as the strategy name
    # (S{mesh_dim_1}R), channel_in and sharding_size_weight below assume. It used
    # to be sharded along mesh_dim_0, which contradicted all three.
    dim_partition_dict_for_weight = {0: [mesh_dim_1]}
    sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

    dim_partition_dict_for_output = {0: [mesh_dim_0]}
    sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
                                                      dim_partition_dict_for_output)

    # generate resharding cost for this strategy
    resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

    # compute the computation cost of this strategy
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
    channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
    channel_out = self.weight.shape[1]
    compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

    # compute the memory cost of this strategy
    sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
    sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
    sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
    memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
        sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

    # the forward activation is all reduced along mesh_dim_1
    communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
    # This strategy does not need an all_reduce operation to compute the input activation grad
    communication_cost_backward_activation = 0
    # compute the backward communication cost to all reduce the weight due to data parallel
    communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
    # compute total cost
    communication_cost = (communication_cost_forward + communication_cost_backward_activation +
                          communication_cost_backward_weight)

    sharding_strategies = ShardingStrategy(name,
                                           output_sharding_spec=sharding_spec_for_output,
                                           compute_cost=compute_cost,
                                           communication_cost=communication_cost,
                                           memory_cost=memory_cost,
                                           resharding_costs=resharding_costs,
                                           input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
    self.strategies_vector.append(sharding_strategies)
|
|
|
|
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
    '''
    Register the strategy 'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'.

    Dim 1 of the input is sharded along mesh_dim_0; the weight is sharded on
    dim 0 along mesh_dim_0 and on dim 1 along mesh_dim_1; the output is
    sharded on dim 1 along mesh_dim_1. The resulting ShardingStrategy is
    appended to self.strategies_vector.

    Argument:
        mesh_dim_0(int): Device mesh axis which shards dim 1 of the input
            and dim 0 of the weight.
        mesh_dim_1(int): Device mesh axis which shards dim 1 of the weight.
    '''
    mesh = self.device_mesh
    strategy_name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

    # sharding specs of the two operands and of the result
    input_spec = generate_sharding_spec(self.input_data, mesh, {1: [mesh_dim_0]})
    weight_spec = generate_sharding_spec(self.weight, mesh, {0: [mesh_dim_0], 1: [mesh_dim_1]})
    output_spec = generate_sharding_spec(self.output_data, mesh, {1: [mesh_dim_1]})

    # resharding cost of this strategy
    resharding_costs = self._generate_resharding_costs([input_spec, weight_spec])

    # computation cost: the whole batch, a slice of the in channels and a slice of weight dim 1
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    full_batch = self.input_data.shape[0]
    in_channels_per_device = self.input_data.shape[1] // mesh.shape[mesh_dim_0]
    out_channels_per_device = self.weight.shape[1] // mesh.shape[mesh_dim_1]
    compute_cost = self._generate_compute_cost(full_batch, in_channels_per_device, out_channels_per_device)

    # memory cost: partition counts of the forward activation, backward activation and weight
    size_0, size_1 = mesh.shape[mesh_dim_0], mesh.shape[mesh_dim_1]
    memory_cost, fwd_activation_mem, bwd_activation_mem, _ = self._generate_memory_cost(
        size_1, size_0, size_0 * size_1)

    # communication cost: the forward activation is all reduced along mesh_dim_0,
    # the backward activation along mesh_dim_1
    forward_comm = mesh.all_reduce_cost(fwd_activation_mem, mesh_dim_0)
    backward_comm = mesh.all_reduce_cost(bwd_activation_mem, mesh_dim_1)
    communication_cost = forward_comm + backward_comm

    self.strategies_vector.append(
        ShardingStrategy(strategy_name,
                         output_sharding_spec=output_spec,
                         compute_cost=compute_cost,
                         communication_cost=communication_cost,
                         memory_cost=memory_cost,
                         resharding_costs=resharding_costs,
                         input_shardings=(input_spec, weight_spec)))
|
|
|
|
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
    '''
    Register the strategy 'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'.

    Dim 1 of the input and dim 0 of the weight are both sharded along
    mesh_dim_0; the output is not sharded. The resulting ShardingStrategy is
    appended to self.strategies_vector.

    Argument:
        mesh_dim_0(int): Device mesh axis which shards dim 1 of the input
            and dim 0 of the weight.
    '''
    mesh = self.device_mesh
    strategy_name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

    # sharding specs of the two operands and of the result
    input_spec = generate_sharding_spec(self.input_data, mesh, {1: [mesh_dim_0]})
    weight_spec = generate_sharding_spec(self.weight, mesh, {0: [mesh_dim_0]})
    output_spec = generate_sharding_spec(self.output_data, mesh, {})

    # resharding cost of this strategy
    resharding_costs = self._generate_resharding_costs([input_spec, weight_spec])

    # computation cost: only the in channels are divided among the devices
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    in_channels_per_device = self.input_data.shape[1] // mesh.shape[mesh_dim_0]
    compute_cost = self._generate_compute_cost(self.input_data.shape[0], in_channels_per_device,
                                               self.weight.shape[1])

    # memory cost: the forward activation is kept whole, the backward activation and
    # the weight are split along mesh_dim_0
    channel_partitions = mesh.shape[mesh_dim_0]
    memory_cost, fwd_activation_mem, _, _ = self._generate_memory_cost(1, channel_partitions, channel_partitions)

    # communication cost: the forward activation is all reduced along mesh_dim_0,
    # no all_reduce cost is counted for the backward phase
    forward_comm = mesh.all_reduce_cost(fwd_activation_mem, mesh_dim_0)
    backward_comm = 0
    communication_cost = forward_comm + backward_comm

    self.strategies_vector.append(
        ShardingStrategy(strategy_name,
                         output_sharding_spec=output_spec,
                         compute_cost=compute_cost,
                         communication_cost=communication_cost,
                         memory_cost=memory_cost,
                         resharding_costs=resharding_costs,
                         input_shardings=(input_spec, weight_spec)))
|
|
|
|
def split_weight_out_channel(self, mesh_dim_0):
    '''
    Register the strategy 'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'.

    The input is not sharded; dim 1 of the weight and dim 1 of the output are
    sharded along mesh_dim_0. The resulting ShardingStrategy is appended to
    self.strategies_vector.

    Argument:
        mesh_dim_0(int): Device mesh axis which shards dim 1 of the weight.
    '''
    mesh = self.device_mesh
    strategy_name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

    # sharding specs of the two operands and of the result
    input_spec = generate_sharding_spec(self.input_data, mesh, {})
    weight_spec = generate_sharding_spec(self.weight, mesh, {1: [mesh_dim_0]})
    output_spec = generate_sharding_spec(self.output_data, mesh, {1: [mesh_dim_0]})

    # resharding cost of this strategy
    resharding_costs = self._generate_resharding_costs([input_spec, weight_spec])

    # computation cost: only weight dim 1 is divided among the devices
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    out_channels_per_device = self.weight.shape[1] // mesh.shape[mesh_dim_0]
    compute_cost = self._generate_compute_cost(self.input_data.shape[0], self.input_data.shape[1],
                                               out_channels_per_device)

    # memory cost: the forward activation and the weight are split along mesh_dim_0,
    # the backward activation is kept whole
    channel_partitions = mesh.shape[mesh_dim_0]
    memory_cost, _, bwd_activation_mem, _ = self._generate_memory_cost(channel_partitions, 1, channel_partitions)

    # communication cost: no all_reduce in forward; the backward activation is
    # all reduced along mesh_dim_0
    forward_comm = 0
    backward_comm = mesh.all_reduce_cost(bwd_activation_mem, mesh_dim_0)
    communication_cost = forward_comm + backward_comm

    self.strategies_vector.append(
        ShardingStrategy(strategy_name,
                         output_sharding_spec=output_spec,
                         compute_cost=compute_cost,
                         communication_cost=communication_cost,
                         memory_cost=memory_cost,
                         resharding_costs=resharding_costs,
                         input_shardings=(input_spec, weight_spec)))
|
|
|
|
def non_split(self):
    '''
    Register the strategy 'RR = RR x RR'.

    Neither the input, the weight nor the output is sharded. The resulting
    ShardingStrategy is appended to self.strategies_vector.
    '''
    mesh = self.device_mesh
    strategy_name = 'RR = RR x RR'

    # sharding specs of the two operands and of the result: no dim is partitioned
    input_spec = generate_sharding_spec(self.input_data, mesh, {})
    weight_spec = generate_sharding_spec(self.weight, mesh, {})
    output_spec = generate_sharding_spec(self.output_data, mesh, {})

    # resharding cost of this strategy
    resharding_costs = self._generate_resharding_costs([input_spec, weight_spec])

    # computation cost: every device handles the full problem
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    compute_cost = self._generate_compute_cost(self.input_data.shape[0], self.input_data.shape[1],
                                               self.weight.shape[1])

    # memory cost: nothing is partitioned
    memory_cost, _, _, _ = self._generate_memory_cost(1, 1, 1)

    # communication cost: no all_reduce in the forward nor in the backward phase
    communication_cost = 0

    self.strategies_vector.append(
        ShardingStrategy(strategy_name,
                         output_sharding_spec=output_spec,
                         compute_cost=compute_cost,
                         communication_cost=communication_cost,
                         memory_cost=memory_cost,
                         resharding_costs=resharding_costs,
                         input_shardings=(input_spec, weight_spec)))
|
|
|
|
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
    '''
    Register the strategy 'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'.

    Dim 0 of the input and dim 0 of the output are sharded along both mesh
    axes at once; the weight is not sharded. The resulting ShardingStrategy
    is appended to self.strategies_vector.

    Argument:
        mesh_dim_0(int): First device mesh axis which shards dim 0 of the input.
        mesh_dim_1(int): Second device mesh axis which shards dim 0 of the input.
    '''
    mesh = self.device_mesh
    strategy_name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

    # sharding specs of the two operands and of the result
    input_spec = generate_sharding_spec(self.input_data, mesh, {0: [mesh_dim_0, mesh_dim_1]})
    weight_spec = generate_sharding_spec(self.weight, mesh, {})
    output_spec = generate_sharding_spec(self.output_data, mesh, {0: [mesh_dim_0, mesh_dim_1]})

    # resharding cost of this strategy
    resharding_costs = self._generate_resharding_costs([input_spec, weight_spec])

    # computation cost: the batch is divided among the devices of both mesh axes
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    batch_per_device = self.input_data.shape[0] // (mesh.shape[mesh_dim_0] * mesh.shape[mesh_dim_1])
    compute_cost = self._generate_compute_cost(batch_per_device, self.input_data.shape[1], self.weight.shape[1])

    # memory cost: both activations are split over both mesh axes, the weight is kept whole
    # NOTE(review): the sizes are read from `mesh_shape` here but from `shape` above;
    # presumably the two attributes are aliases -- confirm
    batch_partitions = mesh.mesh_shape[mesh_dim_0] * mesh.mesh_shape[mesh_dim_1]
    memory_cost, _, _, bwd_weight_mem = self._generate_memory_cost(batch_partitions, batch_partitions, 1)

    # communication cost: no all_reduce in forward; in backward the weight grad is all reduced
    # along axis 0 of the flattened device mesh because of data parallel
    forward_comm = 0
    bwd_weight_comm = mesh.flatten_device_mesh.all_reduce_cost(bwd_weight_mem, 0)
    communication_cost = bwd_weight_comm + forward_comm

    self.strategies_vector.append(
        ShardingStrategy(strategy_name,
                         output_sharding_spec=output_spec,
                         compute_cost=compute_cost,
                         communication_cost=communication_cost,
                         memory_cost=memory_cost,
                         resharding_costs=resharding_costs,
                         input_shardings=(input_spec, weight_spec)))
|
|
|
|
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
    '''
    Register the strategy 'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'.

    Dim 1 of the input and dim 0 of the weight are sharded along both mesh
    axes at once; the output is not sharded. The resulting ShardingStrategy
    is appended to self.strategies_vector.

    Argument:
        mesh_dim_0(int): First device mesh axis which shards dim 1 of the
            input and dim 0 of the weight.
        mesh_dim_1(int): Second device mesh axis which shards dim 1 of the
            input and dim 0 of the weight.
    '''
    mesh = self.device_mesh
    strategy_name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

    # sharding specs of the two operands and of the result
    input_spec = generate_sharding_spec(self.input_data, mesh, {1: [mesh_dim_0, mesh_dim_1]})
    weight_spec = generate_sharding_spec(self.weight, mesh, {0: [mesh_dim_0, mesh_dim_1]})
    output_spec = generate_sharding_spec(self.output_data, mesh, {})

    # resharding cost of this strategy
    resharding_costs = self._generate_resharding_costs([input_spec, weight_spec])

    # computation cost: the in channels are divided among the devices of both mesh axes
    # NOTE(review): weight.shape[1] is used as the out-channel count -- confirm the weight layout
    in_channels_per_device = self.input_data.shape[1] // (mesh.shape[mesh_dim_0] * mesh.shape[mesh_dim_1])
    compute_cost = self._generate_compute_cost(self.input_data.shape[0], in_channels_per_device,
                                               self.weight.shape[1])

    # memory cost: the forward activation is kept whole, the backward activation and
    # the weight are split over both mesh axes
    # NOTE(review): the sizes are read from `mesh_shape` here but from `shape` above;
    # presumably the two attributes are aliases -- confirm
    channel_partitions = mesh.mesh_shape[mesh_dim_0] * mesh.mesh_shape[mesh_dim_1]
    memory_cost, fwd_activation_mem, _, _ = self._generate_memory_cost(1, channel_partitions, channel_partitions)

    # communication cost: the forward activation is all reduced along axis 0 of the
    # flattened device mesh, no all_reduce cost is counted for the backward phase
    forward_comm = mesh.flatten_device_mesh.all_reduce_cost(fwd_activation_mem, 0)
    backward_comm = 0
    communication_cost = forward_comm + backward_comm

    self.strategies_vector.append(
        ShardingStrategy(strategy_name,
                         output_sharding_spec=output_spec,
                         compute_cost=compute_cost,
                         communication_cost=communication_cost,
                         memory_cost=memory_cost,
                         resharding_costs=resharding_costs,
                         input_shardings=(input_spec, weight_spec)))
|
|
|
|
def register_strategy(self) -> StrategiesVector:
    '''
    Generate every possible strategy for a Conv node, and record all strategies into the strategies_vector.

    The strategies are attempted in this order (names shown for mesh axes 0 and 1):

        S0S1 = S0R x RS1,  S1S0 = S1R x RS0     (guarded)
        S0R = S0R x RR,    S1R = S1R x RR
        S0R = S0S1 x S1R,  S1R = S1S0 x S0R     (guarded)
        RS1 = RS0 x S0S1,  RS0 = RS1 x S1S0     (guarded)
        RR = RS0 x S0R,    RR = RS1 x S1R       (guarded)
        RS0 = RR x RS0,    RS1 = RR x RS1
        RR = RR x RR
        S01R = S01R x RR
        RR = RS01 x S01R                        (guarded)

    A guarded strategy which raises is skipped and reported with warnings.warn; the
    warning names the strategy method and the mesh axes it was called with. An
    exception raised by any other strategy propagates to the caller.

    Example:
        # conv_handler is a ConvHandler built for a conv node
        for strategy in conv_handler.register_strategy():
            print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, '
                  f'communication_cost is {strategy.communication_cost}, '
                  f'memory_cost is {strategy.memory_cost}, '
                  f'resharding_costs is {strategy.resharding_costs}')

    Return:
        strategies_vector(StrategiesVector): self.strategies_vector, holding every
            strategy which was registered successfully.
    '''
    # (strategy method name, mesh axes passed to it, whether a failure is only warned about)
    strategy_plan = (
        # SS = SR x RS
        ('split_input_batch_weight_out_channel', (0, 1), True),
        ('split_input_batch_weight_out_channel', (1, 0), True),
        # SR = SR x RR
        ('split_input_batch', (0,), False),
        ('split_input_batch', (1,), False),
        # SR = SS x SR
        ('split_input_both_dim_weight_in_channel', (0, 1), True),
        ('split_input_both_dim_weight_in_channel', (1, 0), True),
        # RS = RS x SS
        ('split_input_in_channel_weight_both_channel', (0, 1), True),
        ('split_input_in_channel_weight_both_channel', (1, 0), True),
        # RR = RS x SR
        ('split_input_in_channel_weight_in_channel', (0,), True),
        ('split_input_in_channel_weight_in_channel', (1,), True),
        # RS = RR x RS
        ('split_weight_out_channel', (0,), False),
        ('split_weight_out_channel', (1,), False),
        # RR = RR x RR
        ('non_split', (), False),
        # S01R = S01R x RR
        ('split_1d_parallel_on_input_batch', (0, 1), False),
        # RR = RS01 x S01R
        ('split_1d_parallel_on_in_channel', (0, 1), True),
    )

    for method_name, mesh_dims, guarded in strategy_plan:
        if not guarded:
            # failures of these strategies are not swallowed
            getattr(self, method_name)(*mesh_dims)
            continue
        try:
            getattr(self, method_name)(*mesh_dims)
        except Exception as e:
            # best effort: skip the strategy, but say which one could not be generated
            dims = ', '.join(str(mesh_dim) for mesh_dim in mesh_dims)
            warnings.warn(f'{method_name}({dims}) is skipped: {e}')

    return self.strategies_vector
|
|
|
|
|
|
# Names of the conv sharding strategies, written as 'output = input x weight'.
CONV_STRATEGIES_LIST = [
    # SS = SR x RS
    'S0S1 = S0R x RS1',
    'S1S0 = S1R x RS0',
    # SR = SR x RR
    'S0R = S0R x RR',
    'S1R = S1R x RR',
    # SR = SS x SR
    'S0R = S0S1 x S1R',
    'S1R = S1S0 x S0R',
    # RS = RS x SS
    'RS1 = RS0 x S0S1',
    'RS0 = RS1 x S1S0',
    # RR = RS x SR
    'RR = RS0 x S0R',
    'RR = RS1 x S1R',
    # RS = RR x RS
    'RS0 = RR x RS0',
    'RS1 = RR x RS1',
    # RR = RR x RR
    'RR = RR x RR',
    # S01R = S01R x RR
    'S01R = S01R x RR',
    # RR = RS01 x S01R
    'RR = RS01 x S01R',
]
|