mirror of https://github.com/hpcaitech/ColossalAI

[autoparallel] adapt solver with resnet (#1583)

* [autoparallel] adapt solver with resnet
* polish code
* polish code

pull/1588/head
parent f3403ff98e
commit 82d4376c23
@ -1,7 +1,8 @@
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .graph_analysis import GraphAnalyser
from .solver import Solver
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *

__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser']
__all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
@ -3,13 +3,14 @@ import operator

__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
'LINEAR_FUNC_OP'
'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP'
]

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
# TODO: flatten should not be added into this group
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,

@ -20,3 +21,7 @@ CONV_FUNC_OP = [
]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]

INFINITY_COST = 1e13
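These op groups are presumably what the strategies constructor uses to decide which handler processes a traced node. A hedged sketch of that dispatch (the helper function is mine, not part of the commit; only the constants import path is taken from the diff above):

import torch
from colossalai.auto_parallel.solver.constants import BATCHNORM_MODULE_OP, CONV_MODULE_OP

def pick_handler_name(module: torch.nn.Module) -> str:
    # dispatch on the concrete module class using the op groups defined above
    if type(module) in CONV_MODULE_OP:
        return 'ConvHandler'
    if type(module) in BATCHNORM_MODULE_OP:
        return 'BatchNormHandler'
    return 'OperatorHandler'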
@ -1,4 +1,3 @@
|
|||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from typing import List
|
||||
import math
|
||||
from torch.fx.node import Node
|
||||
|
|
|
@ -15,7 +15,7 @@ class LiveVariable:
|
|||
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
|
||||
"""
|
||||
name: str
|
||||
meta: Union[Any, List[Any]]
|
||||
node: Node
|
||||
is_inplace: bool
|
||||
|
||||
|
||||
|
@ -80,13 +80,13 @@ class GraphAnalyser:
|
|||
"""
|
||||
return self._graph
|
||||
|
||||
def liveness_analysis(self) -> OrderedDict[int, LiveStage]:
|
||||
def liveness_analysis(self) -> List[LiveStage]:
|
||||
"""
|
||||
Analyse the graph to obtain the variable liveness information. This function returns
a list of LiveStage objects, ordered by compute stage.
|
||||
"""
|
||||
compute_nodes = self.graph.nodes
|
||||
liveness_dict = ODict()
|
||||
liveness_list = []
|
||||
|
||||
# checked: record all variables created since the first stage
# all: record the live variables that only exist until the current stage.
|
||||
|
@ -97,25 +97,6 @@ class GraphAnalyser:
|
|||
all_live_variables = LiveVariableVector()
|
||||
unique_live_vars = LiveVariableVector()
|
||||
|
||||
def _add_param_or_buf(node, tensor_type):
|
||||
module = get_node_module(node)
|
||||
|
||||
if tensor_type == 'param':
|
||||
iterator = module.named_parameters()
|
||||
elif tensor_type == 'buffer':
|
||||
iterator = module.named_buffers()
|
||||
else:
|
||||
raise ValueError(f"Expected tensor_type to be param or buffer, but got {tensor_type}")
|
||||
|
||||
for name, tensor in iterator:
|
||||
tensor_name = f'{node.name}.{name}'
|
||||
|
||||
if not checked_variables.exists(tensor_name):
|
||||
live_tensor = LiveVariable(name=tensor_name, meta=tensor.to('meta'), is_inplace=False)
|
||||
unique_live_vars.append(live_tensor)
|
||||
checked_variables.append(live_tensor)
|
||||
all_live_variables.append(live_tensor)
|
||||
|
||||
for idx, node in enumerate(compute_nodes):
|
||||
#############################
|
||||
# find new living variables #
|
||||
|
@ -135,26 +116,19 @@ class GraphAnalyser:
|
|||
|
||||
# add the output var
|
||||
meta = getattr(node, '_meta_data', None)
|
||||
live_var = LiveVariable(name=node.name, meta=meta, is_inplace=is_inplace)
|
||||
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
|
||||
if not is_inplace:
|
||||
unique_live_vars.append(live_var)
|
||||
checked_variables.append(live_var)
|
||||
all_live_variables.append(live_var)
|
||||
|
||||
# add the model parameters
|
||||
if node.op == 'call_module':
|
||||
_add_param_or_buf(node, tensor_type='param')
|
||||
_add_param_or_buf(node, tensor_type='buffer')
|
||||
|
||||
# add this output variable to the checked list
|
||||
checked_variables.append(live_var)
|
||||
|
||||
# check if any input is not checked yet
|
||||
for arg in node.args:
|
||||
arg_name = str(arg)
|
||||
if not isinstance(arg, Node):
|
||||
continue
|
||||
arg_name = arg.name
|
||||
if not checked_variables.exists(arg_name):
|
||||
meta = getattr(node, '_meta_data', None)
|
||||
live_var_from_arg = LiveVariable(name=arg_name, meta=meta, is_inplace=False)
|
||||
live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
|
||||
all_live_variables.append(live_var_from_arg)
|
||||
checked_variables.append(live_var_from_arg)
|
||||
unique_live_vars.append(live_var_from_arg)
|
||||
|
@ -167,8 +141,23 @@ class GraphAnalyser:
|
|||
node=node,
|
||||
all_live_vars=all_live_variables.copy(),
|
||||
unique_live_vars=unique_live_vars.copy())
|
||||
liveness_dict[idx] = stage
|
||||
return liveness_dict
|
||||
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
|
||||
replace = False
|
||||
for index, prev_stage in enumerate(liveness_list):
|
||||
all_covered = True
|
||||
for ele in prev_stage.unique_live_vars:
|
||||
if ele not in stage.unique_live_vars:
|
||||
all_covered = False
|
||||
break
|
||||
if all_covered:
|
||||
replace = True
|
||||
break
|
||||
if replace:
|
||||
liveness_list[index] = stage
|
||||
else:
|
||||
liveness_list.append(stage)
|
||||
|
||||
return liveness_list
|
||||
|
||||
def get_alias_set(self):
|
||||
pass
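For reference, a minimal usage sketch of the updated liveness_analysis API (illustrative only: it assumes GraphAnalyser can be constructed directly from a torch.fx GraphModule and that the traced nodes carry enough metadata; the toy model below is not from the source):

import torch
from torch.fx import symbolic_trace
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
gm = symbolic_trace(model)

analyser = GraphAnalyser(gm)
liveness_list = analyser.liveness_analysis()
for stage in liveness_list:
    # each LiveStage records the node it ends at and the unique variables still live there
    print(stage.node.name, len(stage.unique_live_vars))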

@ -0,0 +1,6 @@
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .batch_norm_handler import BatchNormHandler

__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler']
@ -0,0 +1,483 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import warnings
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['BatchNormHandler']
|
||||
|
||||
|
||||
class BatchNormHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of normalization.

To keep the math consistent, there are two ways to do BatchNorm if the input
is sharded along the batch dimension:
1. We gather the input partitions along the batch dimension, then do the normal BatchNorm.
2. We do SyncBatchNorm on each input partition separately; the SyncBN op keeps
the computation mathematically correct.
In this handler, both methods are considered.
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_data = self.predecessor_node[0]._meta_data
|
||||
self.weight = self.module_named_parameters['weight']
|
||||
self.bias = self.module_named_parameters['bias']
|
||||
self.output_data = self.node._meta_data
|
||||
self._sanity_check()
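A hedged, self-contained illustration of the two options described in the class docstring, in plain PyTorch rather than handler code (shapes and the local-gather stand-in are made up):

import torch

x_shard = torch.randn(2, 8, 4, 4)          # local batch shard held by one device
bn = torch.nn.BatchNorm2d(8)

# Option 1: gather the batch dimension first, then run ordinary BatchNorm.
# In a real run the gather would be a torch.distributed all_gather across devices.
x_full = torch.cat([x_shard, x_shard])     # stand-in for the gathered global batch
out_gathered = bn(x_full)

# Option 2: keep the shard local and use SyncBatchNorm, which all-reduces the
# per-device statistics so the result matches the unsharded computation.
sync_bn = torch.nn.SyncBatchNorm(8)        # needs an initialized process group to actually run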
|
||||
|
||||
def _sanity_check(self):
|
||||
'''
|
||||
In the sanity check, we need to make sure the input data has the correct number of dimensions.
For BatchNorm1d, the dim of input data should be 3 ([N, C, L]).
For BatchNorm2d, the dim of input data should be 4 ([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5 ([N, C, H, W, D]).
'''
assert self.input_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into the batch norm op should be in the range of [3, 5].'
|
||||
|
||||
def _generate_compute_cost(self, bs, channel_in):
|
||||
'''
Compute the computation cost per device with this specific strategy.

Note: compute_cost needs to be divided by TFLOPS; for now it just reflects the computation size.

Argument:
bs(int): Batch size of the input data.
channel_in(int): The channel dimension of input data.

Return:
compute_cost(float): Computation cost per device with this specific strategy
'''
# TODO: compute_cost needs to be divided by TFLOPS; for now it just reflects the computation size.
# TODO: a constant coefficient needs to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
|
||||
|
||||
input_size = self.input_data.shape[2:]
|
||||
input_size_product = reduce(operator.mul, input_size, 1)
|
||||
forward_compute_cost = input_size_product * bs * channel_in
|
||||
backward_activation_compute_cost = input_size_product * bs * channel_in
|
||||
backward_weight_compute_cost = input_size_product * bs * channel_in
|
||||
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
|
||||
compute_cost = forward_compute_cost + backward_compute_cost
|
||||
return compute_cost
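As a quick sanity check of the formula above, a worked example with a hypothetical, unsharded BatchNorm2d input of shape [4, 64, 32, 32] (none of these numbers come from the source):

bs, channel_in, h, w = 4, 64, 32, 32
forward_compute_cost = h * w * bs * channel_in               # 32 * 32 * 4 * 64 = 262144
backward_compute_cost = 2 * forward_compute_cost             # activation grad + weight grad = 524288
compute_cost = forward_compute_cost + backward_compute_cost  # 786432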
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
||||
'''
Compute the memory cost per device with this specific strategy.

Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number of partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number of partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number of partitions.

Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy; the first element of this tuple is the forward
memory cost, and the second element is the backward memory cost.
memory_cost_forward_activation(float): Memory cost of the forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of the backward activation
per device with this specific strategy.
'''
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel_output = self.output_data.numel()
|
||||
numel_input = numel_output
|
||||
numel_weight = self.weight.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
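To make the bookkeeping concrete, an illustrative calculation with hypothetical numbers (fp32 input of shape [4, 64, 32, 32], a 64-element weight, activations sharded over 4 devices, weight replicated; none of this comes from the source):

size_per_elem_bytes = 4                                                    # fp32
numel_output = 4 * 64 * 32 * 32                                            # 262144
numel_weight = 64
memory_cost_forward_activation = numel_output * size_per_elem_bytes / 4    # 262144.0
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / 1        # 256.0
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight   # 262400.0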
|
||||
|
||||
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
# shard the output batch dimension to get all possible sharding strategies from this basic strategy
|
||||
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
|
||||
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
# the computation cost is all the same
|
||||
new_compute_cost = compute_cost
|
||||
|
||||
# the memory cost needs to be recomputed
# compute the memory cost of the new strategy
|
||||
new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# the communication cost needs to include the sharding cost of this strategy
# compute the communication cost of the new strategy
|
||||
origin_communication_cost = communication_cost
|
||||
tiny_shard_cost = 10
|
||||
new_forward_communication_cost = tiny_shard_cost
|
||||
# we need to all gather the batch dimension for the basic strategy
|
||||
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1)
|
||||
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
|
||||
|
||||
sharding_strategies = ShardingStrategy(new_name,
|
||||
output_sharding_spec=new_sharding_spec_for_output,
|
||||
compute_cost=new_compute_cost,
|
||||
communication_cost=new_communication_cost,
|
||||
memory_cost=new_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
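For readers unfamiliar with the naming convention used above, the strategy name encodes the layouts as output = input x weight, where R means replicated and S<i> means sharded over mesh dimension i. Decoding 'RS0 = RS0 x S0' (i.e. mesh_dim_0 = 0) in terms of the dim-partition dicts built by this method:

# output: dim 0 replicated, dim 1 (channel) sharded over mesh dim 0
# input : dim 0 replicated, dim 1 (channel) sharded over mesh dim 0
# weight: dim 0 (channel) sharded over mesh dim 0
dim_partition_dict_for_input = {1: [0]}
dim_partition_dict_for_weight = {0: [0]}
dim_partition_dict_for_output = {1: [0]}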
|
||||
|
||||
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
|
||||
self.device_mesh.shape[mesh_dim_1])
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def non_split(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RR x R'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
|
||||
dim_partition_dict_for_output = {0: mesh_dim_list}
|
||||
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# the computation cost is all the same
|
||||
new_compute_cost = compute_cost
|
||||
|
||||
# the memory cost needs to be recomputed
|
||||
new_sharding_size_input = 1
|
||||
for mesh_dim in mesh_dim_list:
|
||||
new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim]
|
||||
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# the communication cost needs to include the sharding cost of this strategy
|
||||
origin_communication_cost = communication_cost
|
||||
tiny_shard_cost = 10
|
||||
new_forward_communication_cost = tiny_shard_cost
|
||||
if len(mesh_dim_list) == 1:
|
||||
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation,
|
||||
mesh_dim_list[0])
|
||||
else:
|
||||
new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost(
|
||||
memory_cost_backward_activation, 0)
|
||||
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
|
||||
|
||||
new_sharding_strategy = ShardingStrategy(new_name,
|
||||
output_sharding_spec=new_sharding_spec_for_output,
|
||||
compute_cost=new_compute_cost,
|
||||
communication_cost=new_communication_cost,
|
||||
memory_cost=new_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input,
|
||||
sharding_spec_for_weight))
|
||||
|
||||
return new_sharding_strategy
|
||||
|
||||
# shard the output batch dimension to get all possible sharding strategies from this basic strategy
|
||||
# shard on mesh_dim_0
|
||||
new_name = f'S{mesh_dim_0}R = RR x R'
|
||||
mesh_dim_list = [mesh_dim_0]
|
||||
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
|
||||
self.strategies_vector.append(new_sharding_strategy)
|
||||
|
||||
# shard on mesh_dim_1
|
||||
new_name = f'S{mesh_dim_1}R = RR x R'
|
||||
mesh_dim_list = [mesh_dim_1]
|
||||
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
|
||||
self.strategies_vector.append(new_sharding_strategy)
|
||||
|
||||
# shard on mesh_dim_0, mesh_dim_1
|
||||
new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R'
|
||||
mesh_dim_list = [mesh_dim_0, mesh_dim_1]
|
||||
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
|
||||
self.strategies_vector.append(new_sharding_strategy)
|
||||
|
||||
def split_input_batch(self, mesh_dim_0):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = 1
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# the all-reduce communication happens during the SyncBN computation.
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
|
||||
channel_in = self.input_data.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = 1
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# the all-reduce communication happens during the SyncBN computation.
|
||||
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# the all-reduce communication happens during the SyncBN computation.
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate all possible strategies for a BatchNorm node, and record them in the strategies_vector.
|
||||
|
||||
Example:
|
||||
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
norm_handler.register_strategy()
|
||||
for strategy in norm_handler.strategies_vector:
|
||||
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
|
||||
|
||||
Output:
|
||||
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
|
||||
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
|
||||
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
|
||||
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
|
||||
'''
|
||||
|
||||
# RS = RS x S and strategies based on it, such as
|
||||
# SS = RS x S
|
||||
self.split_input_channel(0, 1)
|
||||
self.split_input_channel(1, 0)
|
||||
|
||||
# RR = RR x R and strategies based on it, such as
|
||||
# SR = SR x R
|
||||
self.non_split(0, 1)
|
||||
|
||||
# RS01 = RS01 x S01
|
||||
self.split_input_channel_1d(0, 1)
|
||||
|
||||
# SR = SR x R WITH SYNC_BN
|
||||
self.split_input_batch(0)
|
||||
self.split_input_batch(1)
|
||||
|
||||
# SS = SS x S WITH SYNC_BN
|
||||
self.split_input_both_dim(0, 1)
|
||||
self.split_input_both_dim(1, 0)
|
||||
|
||||
# S01R = S01R x R WITH SYNC_BN
|
||||
self.split_input_batch_1d(0, 1)
|
||||
|
||||
return self.strategies_vector
|
|
@ -1,5 +1,6 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import warnings
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
|
@ -9,7 +10,7 @@ __all__ = ['ConvHandler']
|
|||
|
||||
class ConvHandler(OperatorHandler):
|
||||
"""
|
||||
An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
|
||||
An OperatorHandler which deals with the sharding strategies of convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -59,8 +60,7 @@ class ConvHandler(OperatorHandler):
|
|||
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
|
||||
return compute_cost
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_backward_weight):
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
|
@ -69,8 +69,8 @@ class ConvHandler(OperatorHandler):
|
|||
into sharding_size_forward number partions.
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
be divided into sharding_size_backward_activation number partions.
|
||||
sharding_size_backward_weight(int): The backward weight will be divided
|
||||
into sharding_size_backward_weight number partions.
|
||||
sharding_size_weight(int): The backward weight will be divided
|
||||
into sharding_size_weight number partions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
|
@ -90,17 +90,19 @@ class ConvHandler(OperatorHandler):
|
|||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward, memory_cost_backward_activation
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
|
||||
|
||||
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
@ -112,7 +114,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -126,9 +128,9 @@ class ConvHandler(OperatorHandler):
|
|||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation during the forward phase
|
||||
communication_cost_forward = 0
|
||||
|
@ -138,7 +140,7 @@ class ConvHandler(OperatorHandler):
|
|||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
|
@ -156,7 +158,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -170,19 +172,20 @@ class ConvHandler(OperatorHandler):
|
|||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_weight = 1
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation in either the forward or the backward phase.
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
|
@ -195,7 +198,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -209,18 +212,18 @@ class ConvHandler(OperatorHandler):
|
|||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1)
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
|
||||
# This strategy does not need an all_reduce operation during the backward phase
|
||||
communication_cost_backward = 0
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
|
@ -238,7 +241,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -252,17 +255,17 @@ class ConvHandler(OperatorHandler):
|
|||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
|
@ -280,7 +283,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -294,18 +297,18 @@ class ConvHandler(OperatorHandler):
|
|||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
# This strategy does NOT need all_reduce during the backward phase
|
||||
communication_cost_backward = 0
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
|
@ -323,7 +326,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -337,9 +340,9 @@ class ConvHandler(OperatorHandler):
|
|||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# This strategy does not need to do all_reduce during the forward phase
|
||||
communication_cost_forward = 0
|
||||
|
@ -347,7 +350,7 @@ class ConvHandler(OperatorHandler):
|
|||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
|
@ -365,7 +368,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -379,15 +382,15 @@ class ConvHandler(OperatorHandler):
|
|||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_backward_weight = 1
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need to do all_reduce in either the forward or the backward phase
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
|
@ -405,7 +408,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -420,15 +423,15 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
|
||||
mesh_dim_1]
|
||||
sharding_size_backward_weight = 1
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need to do all_reduce in either the forward or the backward phase
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
|
@ -446,7 +449,7 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
@ -462,19 +465,20 @@ class ConvHandler(OperatorHandler):
|
|||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
|
||||
mesh_dim_1]
|
||||
sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_backward_weight)
|
||||
sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# compute communication cost during forward phase
|
||||
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0)
|
||||
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
memory_cost_forward_activation, 0)
|
||||
# This strategy does NOT need to do all_reduce during the backward phase
|
||||
communication_cost_backward = 0
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
|
@ -536,24 +540,48 @@ class ConvHandler(OperatorHandler):
|
|||
RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
|
||||
'''
|
||||
# SS = SR x RS
|
||||
self.split_input_batch_weight_out_channel(0, 1)
|
||||
self.split_input_batch_weight_out_channel(1, 0)
|
||||
try:
|
||||
self.split_input_batch_weight_out_channel(0, 1)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
try:
|
||||
self.split_input_batch_weight_out_channel(1, 0)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
|
||||
# SR = SR x RR
|
||||
self.split_input_batch(0)
|
||||
self.split_input_batch(1)
|
||||
|
||||
# SR = SS x SR
|
||||
self.split_input_both_dim_weight_in_channel(0, 1)
|
||||
self.split_input_both_dim_weight_in_channel(1, 0)
|
||||
try:
|
||||
self.split_input_both_dim_weight_in_channel(0, 1)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
try:
|
||||
self.split_input_both_dim_weight_in_channel(1, 0)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
|
||||
# RS = RS x SS
|
||||
self.split_input_in_channel_weight_both_channel(0, 1)
|
||||
self.split_input_in_channel_weight_both_channel(1, 0)
|
||||
try:
|
||||
self.split_input_in_channel_weight_both_channel(0, 1)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
try:
|
||||
self.split_input_in_channel_weight_both_channel(1, 0)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
|
||||
# RR = RS x SR
|
||||
self.split_input_in_channel_weight_in_channel(0)
|
||||
self.split_input_in_channel_weight_in_channel(1)
|
||||
try:
|
||||
self.split_input_in_channel_weight_in_channel(0)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
try:
|
||||
self.split_input_in_channel_weight_in_channel(1)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
|
||||
# RS = RR x RS
|
||||
self.split_weight_out_channel(0)
|
||||
|
@ -566,7 +594,12 @@ class ConvHandler(OperatorHandler):
|
|||
self.split_1d_parallel_on_input_batch(0, 1)
|
||||
|
||||
# RR = RS01 x S01R
|
||||
self.split_1d_parallel_on_in_channel(0, 1)
|
||||
try:
|
||||
self.split_1d_parallel_on_in_channel(0, 1)
|
||||
except Exception as e:
|
||||
warnings.warn(f'{e}')
|
||||
|
||||
# print(f'strategies num is :{len(self.strategies_vector)}')
|
||||
|
||||
return self.strategies_vector
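The try/except blocks introduced above guard each strategy generator, since a particular split can be invalid for a given layer (for instance, a channel dimension that is not divisible by the mesh dimension size can raise while building the sharding spec). A hedged sketch of the pattern factored into a helper (the helper name is mine, not from the commit):

import warnings

def try_register(generate_strategy, *mesh_dims):
    # Keep going when one candidate strategy cannot be built so the
    # remaining strategies are still registered.
    try:
        generate_strategy(*mesh_dims)
    except Exception as e:
        warnings.warn(f'{e}')

# e.g. try_register(self.split_input_batch_weight_out_channel, 0, 1)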
|
||||
|
|
@ -44,11 +44,8 @@ class DotHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight)
|
||||
|
||||
# compute the communication cost
|
||||
# no all-reduce required for this case
|
||||
|
@ -59,7 +56,7 @@ class DotHandler(OperatorHandler):
|
|||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
@ -86,19 +83,16 @@ class DotHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
@ -122,19 +116,16 @@ class DotHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
@@ -158,18 +149,16 @@ class DotHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
memory_cost = numel * size_per_elem_bytes
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
|
||||
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
@@ -193,19 +182,115 @@ class DotHandler(OperatorHandler):
|
|||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel = self.output_data.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
sharding_size = self.device_mesh.shape[mesh_dim]
|
||||
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
|
||||
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_ouput,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
@@ -236,4 +321,14 @@ class DotHandler(OperatorHandler):
# RS = RR x RS
self.split_rhs_space_only(0)
self.split_rhs_space_only(1)

# S01R = S01R x RR
self.split_lhs_1st_dim_1d(0, 1)

# RR = RS01 x S01R
self.split_lhs_2nd_dim_1d(0, 1)

# RS01 = RR x RS01
self.split_rhs_2nd_dim_1d(0, 1)

return self.strategies_vector
@@ -8,7 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .sharding_strategy import StrategiesVector
from ..sharding_strategy import StrategiesVector

__all__ = ['OperatorHandler']
@@ -70,6 +70,48 @@ class OperatorHandler(ABC):
dim_partition_dict=dim_partition_dict)
return sharding_spec

def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
'''
Compute the memory cost per device with this specific strategy.

Argument:
dim_partition_dict_for_output(Dict[int, List[int]]): The key is the dimension of output to be sharded,
and the value of the key describes which logical axes will shard that dimension.
dim_partition_dict_for_weight(Dict[int, List[int]]): The key is the dimension of weight to be sharded,
and the value of the key describes which logical axes will shard that dimension.
Return:
total_memory_cost(float): total memory cost per device with this specific strategy
activation_memory_cost(float): the memory cost of activation per device with this specific strategy
weight_memory_cost(float): the memory cost of weight per device with this specific strategy
'''
# compute the size of one element with specific dtype
dtype = self.input_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

# compute the memory cost of activation
activation_numel = self.output_data.numel()
output_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_output.items():
output_mesh_dims.extend(mesh_dims)
activation_sharding_size = 1
for mesh_dim in output_mesh_dims:
activation_sharding_size *= self.device_mesh.shape[mesh_dim]
activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes

# compute the memory cost of weight
weight_numel = self.weight.numel()
weight_sharding_size = 1
weight_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items():
weight_mesh_dims.extend(mesh_dims)
for mesh_dim in weight_mesh_dims:
weight_sharding_size *= self.device_mesh.shape[mesh_dim]
weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes

total_memory_cost = activation_memory_cost + weight_memory_cost

return total_memory_cost, activation_memory_cost, weight_memory_cost

def _generate_resharding_costs(self, sharding_spec_for_input):
'''
Compute the resharding costs with this specific strategy.

@@ -85,17 +127,20 @@ class OperatorHandler(ABC):
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input):
dtype = self.node._meta_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
# compute the resharding cost during forward phase
_, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, target_spec)
input_sharding_spec, input_spec)
# In backward phase, we should convert grad with target_spec into input_sharding_spec
_, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
target_spec, input_sharding_spec)
resharding_cost = resharding_cost_forward + resharding_cost_backward
input_spec, input_sharding_spec)
# we need to multiply by the size of the element dtype to get the correct communication cost
resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
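_generate_memory_cost above divides the replicated tensor size by the product of the device-mesh dimensions the tensor is sharded over. A minimal, self-contained sketch of that arithmetic, assuming a hypothetical (2, 2) mesh and float32 tensors; it is an illustration, not the OperatorHandler implementation:

import torch

def sharded_memory_cost(shape, dim_partition_dict, mesh_shape, dtype=torch.float32):
    # every mesh axis listed in the partition dict divides the tensor once more
    sharding_size = 1
    for mesh_dims in dim_partition_dict.values():
        for mesh_dim in mesh_dims:
            sharding_size *= mesh_shape[mesh_dim]
    numel = 1
    for dim in shape:
        numel *= dim
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    return numel * size_per_elem_bytes / sharding_size

# a (4, 16, 64, 64) activation sharded along its batch dim over both mesh axes
# keeps 1/4 of the replicated bytes on each device: prints 262144.0
print(sharded_memory_cost((4, 16, 64, 64), {0: [0, 1]}, (2, 2)))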
@@ -0,0 +1,444 @@
import warnings

import time
import numpy as np
import multiprocessing
from torch.fx.node import Node
from torch.fx.graph import Graph
from . import GraphAnalyser
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from typing import Dict
from .constants import INFINITY_COST
try:
import pulp
from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
except:
warnings.warn('please install pulp')

__all__ = ['Solver']


class Solver:

def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
memory_increasing_coefficient: float = 1.3):
'''
Solver class will integrate information provided by the components and use an ILP solver to find a possible optimal strategy combination for the target computing graph.

Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will give a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate a new memory budget.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.nodes = list(self.graph.nodes)
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None

def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict

def _prepare_data_for_solver(self):
|
||||
'''
|
||||
Extract information from components for solver.
|
||||
'''
|
||||
node_nums = len(self.leaf_strategies)
|
||||
memory_budget = self.memory_budget
|
||||
|
||||
# prepare strategies_len
|
||||
strategies_len = []
|
||||
for node in self.nodes:
|
||||
strategies_len.append(self.cost_graph.node_lens[node])
|
||||
strategies_len = np.array(strategies_len)
|
||||
|
||||
# prepare following_nodes
|
||||
following_nodes = self.cost_graph.following_dict
|
||||
index_following_nodes = {}
|
||||
for src, target in following_nodes.items():
|
||||
src_index = self.node_index_dict[src]
|
||||
target_index = self.node_index_dict[target]
|
||||
index_following_nodes[src_index] = target_index
|
||||
following_nodes = index_following_nodes
|
||||
for index in range(node_nums):
|
||||
if index not in following_nodes:
|
||||
following_nodes[index] = -1
|
||||
|
||||
# prepare edge_pairs and resharding costs
|
||||
edge_pairs = []
|
||||
resharding_costs = []
|
||||
for pairs, edge_cost in self.cost_graph.edge_costs.items():
|
||||
src_node = pairs[0]
|
||||
dst_node = pairs[1]
|
||||
src_node_index = self.node_index_dict[src_node]
|
||||
dst_node_index = self.node_index_dict[dst_node]
|
||||
edge_pairs.append(src_node_index)
|
||||
edge_pairs.append(dst_node_index)
|
||||
|
||||
for i in range(strategies_len[src_node_index]):
|
||||
for j in range(strategies_len[dst_node_index]):
|
||||
resharding_costs.append(edge_cost[(i, j)])
|
||||
edge_pairs = np.array(edge_pairs)
|
||||
resharding_costs = np.array(resharding_costs)
|
||||
|
||||
# prepare liveness_set
|
||||
liveness_set = self.liveness_list
|
||||
|
||||
# omit alias_set now
|
||||
alias_set = None
|
||||
alias_convert_costs = None
|
||||
|
||||
# prepare compute_costs, communication_costs and memory_costs
|
||||
compute_costs = []
|
||||
communication_costs = []
|
||||
memory_costs = []
|
||||
extra_node_costs = self.cost_graph.extra_node_costs
|
||||
for strategies_vector in self.leaf_strategies:
|
||||
node = strategies_vector.node
|
||||
for index, strategy in enumerate(strategies_vector):
|
||||
compute_costs.append(strategy.compute_cost)
|
||||
# node in extra_node_costs means it has some extra communication
|
||||
# cost from node merging, so we need to add those extra communication
|
||||
# cost into
|
||||
if node in extra_node_costs:
|
||||
origin_communication_cost = strategy.communication_cost
|
||||
extra_node_cost = extra_node_costs[node][index]
|
||||
communication_cost = origin_communication_cost + extra_node_cost
|
||||
communication_costs.append(communication_cost)
|
||||
else:
|
||||
communication_costs.append(strategy.communication_cost)
|
||||
# temporarily we just consider the forward memory cost
|
||||
memory_cost = strategy.memory_cost
|
||||
if isinstance(memory_cost, tuple):
|
||||
memory_costs.append(memory_cost[0])
|
||||
else:
|
||||
memory_costs.append(memory_cost)
|
||||
compute_costs = np.array(compute_costs)
|
||||
communication_costs = np.array(communication_costs)
|
||||
memory_costs = np.array(memory_costs)
|
||||
|
||||
# omit initial value for nodes
|
||||
s_init_np = None
|
||||
|
||||
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
|
||||
|
||||
def _call_solver_serialized_args(self,
|
||||
node_nums,
|
||||
memory_budget,
|
||||
strategies_len,
|
||||
following_nodes,
|
||||
edge_pairs,
|
||||
alias_set,
|
||||
liveness_set,
|
||||
compute_costs,
|
||||
communication_costs,
|
||||
memory_costs,
|
||||
resharding_costs,
|
||||
alias_convert_costs,
|
||||
s_init_np=None):
|
||||
"""
|
||||
Call the solver with serialized arguments.
|
||||
"""
|
||||
|
||||
tic = time.time()
|
||||
|
||||
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
|
||||
assert isinstance(x, np.ndarray)
|
||||
assert len(strategies_len) == node_nums, "strategies_len"
|
||||
|
||||
def get_non_zero_index(binary_vector):
|
||||
"""
|
||||
Get the index of non-zero item in a vector.
|
||||
"""
|
||||
ct = 0
|
||||
ret = None
|
||||
for i, elem in enumerate(binary_vector):
|
||||
if pulp.value(elem):
|
||||
ret = i
|
||||
ct += 1
|
||||
|
||||
assert ct == 1
|
||||
return ret
|
||||
|
||||
# 0. Unpack flatten numpy arrays
|
||||
s_follow = following_nodes
|
||||
|
||||
E = edge_pairs.reshape((-1, 2)) # noqa
|
||||
r = []
|
||||
pt = 0
|
||||
edge_set = set()
|
||||
for (i, j) in E:
|
||||
prod_length = strategies_len[i] * strategies_len[j]
|
||||
|
||||
if (i, j) in edge_set:
|
||||
raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||
|
||||
edge_set.add((i, j))
|
||||
r.append(resharding_costs[pt:pt + prod_length])
|
||||
pt += prod_length
|
||||
assert pt == len(resharding_costs)
|
||||
|
||||
######################
|
||||
# omit alias set now #
|
||||
######################
|
||||
|
||||
# A = alias_set.reshape((-1, 2)) # noqa
|
||||
# for (i, j) in A:
|
||||
# prod_length = strategies_len[i] * strategies_len[j]
|
||||
# v.append(alias_convert_costs[pt:pt + prod_length])
|
||||
# pt += prod_length
|
||||
# assert pt == len(alias_convert_costs)
|
||||
|
||||
# L = [] # noqa
|
||||
# pt = node_nums
|
||||
# for i in range(node_nums):
|
||||
# length = liveness_set[i]
|
||||
# L.append(liveness_set[pt:pt + length])
|
||||
# pt += length
|
||||
# assert pt == len(liveness_set)
|
||||
v = []
|
||||
pt = 0
|
||||
|
||||
c = []
|
||||
d = []
|
||||
m = []
|
||||
pt = 0
|
||||
for i in range(node_nums):
|
||||
length = strategies_len[i]
|
||||
c.append(compute_costs[pt:pt + length])
|
||||
d.append(communication_costs[pt:pt + length])
|
||||
m.append(memory_costs[pt:pt + length])
|
||||
pt += length
|
||||
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
|
||||
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
|
||||
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
|
||||
|
||||
# 1. Create variables
|
||||
|
||||
#############################
|
||||
# create variables for node #
|
||||
#############################
|
||||
s = []
|
||||
num_nodes = 0
|
||||
reverse_follow_backpatch = []
|
||||
for i in range(node_nums):
|
||||
if s_follow[i] < 0:
|
||||
if strategies_len[i] == 1:
|
||||
s.append([1])
|
||||
else:
|
||||
num_nodes += 1
|
||||
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
|
||||
else:
|
||||
if s_follow[i] < len(s):
|
||||
s.append(s[s_follow[i]])
|
||||
else:
|
||||
s.append(None)
|
||||
reverse_follow_backpatch.append(i)
|
||||
|
||||
for i in reverse_follow_backpatch:
|
||||
s[i] = s[s_follow[i]]
|
||||
|
||||
#############################
|
||||
# create variables for edge #
|
||||
#############################
|
||||
e = []
|
||||
num_edges = 0
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if len(s[i]) == 1:
|
||||
e.append(s[j])
|
||||
elif len(s[j]) == 1:
|
||||
e.append(s[i])
|
||||
else:
|
||||
num_edges += 1
|
||||
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
|
||||
assert len(e[idx]) == len(r[idx])
|
||||
|
||||
# 2. Set initial value
|
||||
######################################
|
||||
# set a initial value for warm start #
|
||||
######################################
|
||||
if s_init_np is not None:
|
||||
s_init = s_init_np.reshape((-1, 3))
|
||||
for (idx, value, fix) in s_init:
|
||||
for i in range(len(s[idx])):
|
||||
s[idx][i].setInitialValue(i == value)
|
||||
if fix:
|
||||
s[idx][i].fixValue()
|
||||
|
||||
# 3. Objective
|
||||
prob = LpProblem("myProblem", LpMinimize)
|
||||
###################################################################
|
||||
# computing the node cost(computing cost and communication cost) #
|
||||
###################################################################
|
||||
obj = 0
|
||||
for i in range(node_nums):
|
||||
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
|
||||
|
||||
#############################################
|
||||
# computing the edge cost(resharding cost) #
|
||||
#############################################
|
||||
for i in range(len(E)):
|
||||
obj += lpDot(e[i], r[i])
|
||||
|
||||
prob += obj
|
||||
|
||||
# 4. Constraints
|
||||
# (a). specified by `cat="Binary"`
|
||||
|
||||
# (b)
|
||||
#################################################
|
||||
# make sure each node only choose one strategy #
|
||||
#################################################
|
||||
for i in range(node_nums):
|
||||
if s_follow[i] < 0:
|
||||
prob += lpSum(s[i]) == 1
|
||||
|
||||
# (c)
|
||||
#################################################
|
||||
# compute memory consumption with liveness set #
|
||||
#################################################
|
||||
if memory_budget > 0:
|
||||
for liveness_stage in liveness_set:
|
||||
mem = 0
|
||||
for live_variable in liveness_stage.unique_live_vars:
|
||||
node_index = self.node_index_dict[live_variable.node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
prob += mem <= memory_budget
|
||||
|
||||
# (d). specified by `cat="Binary"`
|
||||
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if strategies_len[i] == 1 or strategies_len[j] == 1:
|
||||
continue
|
||||
|
||||
# (e)
|
||||
prob += lpSum(e[idx]) == 1
|
||||
|
||||
# (f)
|
||||
for row in range(len(s[i])):
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
|
||||
|
||||
# (g)
|
||||
for col in range(len(s[j])):
|
||||
R = len(s[i]) # noqa
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
|
||||
|
||||
# (h)
|
||||
######################
|
||||
# omit alias set now #
|
||||
######################
|
||||
|
||||
# alias_set = set()
|
||||
# for (idx, (i, j)) in enumerate(A):
|
||||
# R = len(s[i]) # noqa
|
||||
# C = len(s[j]) # noqa
|
||||
# if (i, j) in alias_set:
|
||||
# raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||
|
||||
# alias_set.add((i, j))
|
||||
# alias_set.add((j, i))
|
||||
|
||||
# for row in range(len(s[i])):
|
||||
# for col in range(len(s[j])):
|
||||
# if v[idx][row * C + col] > 0.5:
|
||||
# prob += s[i][row] + s[j][col] <= 1
|
||||
|
||||
verbose = True
|
||||
|
||||
msg = verbose
|
||||
time_limit = 600
|
||||
assert "COIN_CMD" in pulp.listSolvers(
|
||||
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
|
||||
|
||||
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
|
||||
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
|
||||
prob.solve(solver)
|
||||
|
||||
status = prob.status
|
||||
objective = pulp.value(prob.objective)
|
||||
objective = float(objective) if objective is not None else -1.0
|
||||
if verbose:
|
||||
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
|
||||
f"Time: {time.time() - tic}")
|
||||
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
|
||||
|
||||
if prob.status in [pulp.LpStatusInfeasible]:
|
||||
raise RuntimeError("Cannot run the function under the given memory budget. "
|
||||
"Please increase the memory budget.")
|
||||
|
||||
# Get and check results
|
||||
s_val = np.full((node_nums,), -1, dtype=np.int32)
|
||||
for i in range(node_nums):
|
||||
s_val[i] = get_non_zero_index(s[i])
|
||||
|
||||
e_val = np.full((len(E),), -1, dtype=np.int32)
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
e_val[idx] = get_non_zero_index(e[idx])
|
||||
i_spec_index = e_val[idx] // len(s[j])
|
||||
j_spec_index = e_val[idx] % len(s[j])
|
||||
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
|
||||
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
|
||||
if verbose and r[idx][e_val[idx]] > 0:
|
||||
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
||||
|
||||
self.last_s_val = s_val
|
||||
self.last_objective = objective
|
||||
|
||||
if objective > INFINITY_COST:
|
||||
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
|
||||
|
||||
return s_val, e_val, objective, status
|
||||
|
||||
def call_solver_serialized_args(self):
|
||||
"""
|
||||
Call the solver with serialized arguments and handle python errors. Additionally,
we could give a series of solutions with different memory budgets.
|
||||
"""
|
||||
if self.solution_numbers == 1:
|
||||
args = self._prepare_data_for_solver()
|
||||
ret = self._call_solver_serialized_args(*args)
|
||||
|
||||
return ret
|
||||
|
||||
origin_memory_budget = self.memory_budget
|
||||
memory_budget_list = [
|
||||
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
|
||||
]
|
||||
ret_list = []
|
||||
for memory_budget in memory_budget_list:
|
||||
self.memory_budget = memory_budget
|
||||
args = self._prepare_data_for_solver()
|
||||
ret = self._call_solver_serialized_args(*args)
|
||||
ret_list.append(ret)
|
||||
|
||||
return ret_list
|
|
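The core of the ILP above can be reproduced with a few lines of pulp: binary one-hot variables per node, an objective built with lpDot, and constraint (b) forcing each node to pick exactly one strategy. A toy, self-contained sketch with two hypothetical nodes, assuming pulp and its bundled CBC solver are installed:

import pulp
from pulp import LpProblem, LpMinimize, LpVariable, lpSum, lpDot

costs = [[4.0, 1.0], [2.0, 5.0, 3.0]]    # per-strategy cost for each node
prob = LpProblem("toy_auto_sharding", LpMinimize)
s = [LpVariable.matrix(f"s[{i}]", (range(len(c)),), cat="Binary") for i, c in enumerate(costs)]

# objective: total cost of the chosen strategies
prob += lpSum(lpDot(s[i], costs[i]) for i in range(len(costs)))
# constraint (b): every node chooses exactly one strategy
for s_i in s:
    prob += lpSum(s_i) == 1

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print([next(j for j, var in enumerate(s_i) if pulp.value(var) > 0.5) for s_i in s])    # [1, 0]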
@@ -1,7 +1,7 @@
from torch.fx import Graph, Node
from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .conv_handler import ConvHandler
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
from .constants import *
from copy import deepcopy
import math
@@ -175,6 +175,58 @@ class StrategiesConstructor:
|
|||
input_shardings=[input_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# BatchNormNd module
|
||||
elif submod_type in BATCHNORM_MODULE_OP:
|
||||
# bn1 call_module bn1 (conv1,)
|
||||
# print(node, node.op, node.target, node.args)
|
||||
# create sharding strategy for element-wise module
|
||||
# input_node = strategies_vector.predecessor_nodes[0]
|
||||
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
norm_handler.register_strategy()
|
||||
# for strategy in norm_handler.strategies_vector:
|
||||
# print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
|
||||
# assert False
|
||||
|
||||
# MaxPool module
|
||||
elif submod_type in POOL_MODULE_OP:
|
||||
# create sharding strategy for element-wise module
|
||||
assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporarily, we just support single input element-wise op.'
|
||||
input_node = strategies_vector.predecessor_nodes[0]
|
||||
# For element-wise module, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
# output sharding spec will generate same strategy for element-wise module.
|
||||
sharding_spec_checklist = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
# It looks a little bit confusing, the input of the processing node
|
||||
# is the output of the input_node.
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec,
|
||||
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
if input_sharding_spec in sharding_spec_checklist:
|
||||
continue
|
||||
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
|
||||
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# other module
|
||||
else:
|
||||
raise RuntimeError(f'{submod_type} module is NOT supported now.')
|
||||
|
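As the comments in the pool branch above note, pool and element-wise modules copy the input sharding spec to the output, so input strategies that share an output spec collapse into a single strategy. A tiny standalone sketch of that de-duplication, using string stand-ins instead of ShardingSpec objects:

seen_specs = []
for spec in ["[R, S0, R, R]", "[R, S0, R, R]", "[S0, R, R, R]", "[R, R, R, R]"]:
    if spec in seen_specs:    # same output spec -> identical pool strategy, skip it
        continue
    seen_specs.append(spec)
print(seen_specs)    # ['[R, S0, R, R]', '[S0, R, R, R]', '[R, R, R, R]']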
@@ -203,7 +255,7 @@ class StrategiesConstructor:
|
|||
# TODO: integrate element-wise func and module together
|
||||
# create sharding strategy for element-wise function
|
||||
assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporarily, we just support single input element-wise op.'
) == 1, f'Temporarily, we just support single input element-wise op, node name is {node}.'
|
||||
input_node = strategies_vector.predecessor_nodes[0]
|
||||
# For element-wise function, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
|
@@ -349,6 +401,13 @@ class StrategiesConstructor:
|
|||
memory_cost = 0
|
||||
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
input_sharding_specs)
|
||||
|
||||
# clear the resharding cost for the output node
|
||||
# TODO: we may remove this in final version
|
||||
if True:
|
||||
for prev_node, resharding_cost_list in resharding_costs.items():
|
||||
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
|
||||
|
||||
sharding_strategy_attribute = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
memory_cost=memory_cost,
|
||||
|
|
|
@@ -0,0 +1,119 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
class BNModel(nn.Module):
|
||||
|
||||
def __init__(self, c):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm2d(c)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_bn_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = BNModel(16)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {})
|
||||
# return bn
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
# [x, mul, bn, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
|
||||
# find the sharding strategies for the input node of the bn node
|
||||
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
|
||||
strategies_vector_for_input = StrategiesVector(nodes[1])
|
||||
sharding_option = (None, 0, 1)
|
||||
for first_sharding_index in sharding_option:
|
||||
for second_sharding_index in sharding_option:
|
||||
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
|
||||
continue
|
||||
if first_sharding_index is None:
|
||||
first_dim_spec = _DimSpec([])
|
||||
else:
|
||||
first_dim_spec = _DimSpec([first_sharding_index])
|
||||
|
||||
if second_sharding_index is None:
|
||||
second_dim_spec = _DimSpec([])
|
||||
else:
|
||||
second_dim_spec = _DimSpec([second_sharding_index])
|
||||
|
||||
replica_dim_spec = _DimSpec([])
|
||||
sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
|
||||
sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=entire_shape,
|
||||
sharding_sequence=sharding_sequence)
|
||||
strategy_name = str(sharding_spec.sharding_sequence)
|
||||
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
|
||||
strategies_vector_for_input.append(sharding_strategy)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
# generate bn strategy
|
||||
strategies_vector = StrategiesVector(node=nodes[2])
|
||||
bn_handler = BatchNormHandler(node=nodes[2],
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=strategies_vector,
|
||||
shape_consistency_manager=shape_consistency_manager)
|
||||
bn_handler.register_strategy()
|
||||
# ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
|
||||
# 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']
|
||||
strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector]
|
||||
|
||||
# RS = RS x S and strategies based on it, such as
|
||||
# SS = RS x S
|
||||
assert 'RS0 = RS0 x S0' in strategy_name_list
|
||||
assert 'S1S0 = RS0 x S0' in strategy_name_list
|
||||
assert 'RS1 = RS1 x S1' in strategy_name_list
|
||||
assert 'S0S1 = RS1 x S1' in strategy_name_list
|
||||
|
||||
# RR = RR x R and strategies based on it, such as
|
||||
# SR = SR x R
|
||||
assert 'RR = RR x R' in strategy_name_list
|
||||
assert 'S0R = RR x R' in strategy_name_list
|
||||
assert 'S1R = RR x R' in strategy_name_list
|
||||
assert 'S01R = RR x R' in strategy_name_list
|
||||
|
||||
# RS01 = RS01 x S01
|
||||
assert 'RS01 = RS01 x S01' in strategy_name_list
|
||||
|
||||
# SR = SR x R WITH SYNC_BN
|
||||
assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
|
||||
assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
# SS = SS x S WITH SYNC_BN
|
||||
assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
|
||||
assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
# S01R = S01R x R WITH SYNC_BN
|
||||
assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_bn_handler()
|
|
@@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@@ -6,8 +6,6 @@ import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
@@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.dot_handler import DotHandler
from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@@ -32,21 +32,21 @@ def test_liveness_analysis():
gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__)

graph_analyser = GraphAnalyser(gm)
liveness_dict = graph_analyser.liveness_analysis()
stage_count = len(liveness_dict)
liveness_list = graph_analyser.liveness_analysis()
stage_count = len(liveness_list)

# 8 stages including input and output
assert stage_count == 8
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
assert stage_count == 1

# a variable named `relu` must exist
# and this live var must have inplace = True
assert liveness_dict[5].all_live_vars.exists('relu')
relu_var = liveness_dict[5].all_live_vars.get('relu')
assert liveness_list[0].all_live_vars.exists('relu')
relu_var = liveness_list[0].all_live_vars.get('relu')
assert relu_var.is_inplace

# the unique vars must be fewer than all the vars since in-place ops exist
all_live_vars = liveness_dict[7].all_live_vars
unique_live_vars = liveness_dict[7].unique_live_vars
all_live_vars = liveness_list[0].all_live_vars
unique_live_vars = liveness_list[0].unique_live_vars
assert len(unique_live_vars) + 1 == len(all_live_vars)
@@ -0,0 +1,78 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||
from copy import deepcopy
|
||||
from colossalai.auto_parallel.solver import Solver
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3)
|
||||
self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = x / 2
|
||||
x = self.conv3(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.skip("for higher testing speed")
|
||||
def test_solver():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
|
||||
# %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {})
|
||||
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {})
|
||||
# %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {})
|
||||
# %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {})
|
||||
# return relu
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
solver_options = {'fast_mode': True}
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
|
||||
# [ 0 0 13 13 13 13 13 0]
|
||||
strategies_combination_list = ret[0]
|
||||
assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_solver()
|
|
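For reference, ret[0] in the test above is the per-leaf-node strategy index vector, so pairing it with each node's strategies_vector recovers the human-readable sharding choice. A hedged, standalone illustration with hypothetical strategy names:

strategy_names = [['RR = RR x RR', 'S01R = S01R x RR'], ['RR = RR x RR', 'S01R = S01R x RR']]
chosen_indices = [1, 0]    # stand-in for ret[0]
for node_index, strategy_index in enumerate(chosen_indices):
    print(node_index, strategy_names[node_index][strategy_index])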
@@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh