[autoparallel] adapt solver with resnet (#1583)

* [autoparallel]adapt solver with resnet

* polish code

* polish code
pull/1588/head
YuliangLiu0306 2022-09-13 12:07:09 +08:00 committed by GitHub
parent f3403ff98e
commit 82d4376c23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1515 additions and 161 deletions

View File

@ -1,7 +1,8 @@
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .graph_analysis import GraphAnalyser
from .solver import Solver
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser']
__all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']

View File

@ -3,13 +3,14 @@ import operator
__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
'LINEAR_FUNC_OP'
'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP'
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
# TODO: flatten should not be added into this group
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
@ -20,3 +21,7 @@ CONV_FUNC_OP = [
]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
INFINITY_COST = 1e13

View File

@ -1,4 +1,3 @@
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from typing import List
import math
from torch.fx.node import Node

View File

@ -15,7 +15,7 @@ class LiveVariable:
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
meta: Union[Any, List[Any]]
node: Node
is_inplace: bool
@ -80,13 +80,13 @@ class GraphAnalyser:
"""
return self._graph
def liveness_analysis(self) -> OrderedDict[int, LiveStage]:
def liveness_analysis(self) -> List[LiveStage]:
"""
Analyse the graph to obtain the variable liveness information. This function returns
an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
"""
compute_nodes = self.graph.nodes
liveness_dict = ODict()
liveness_list = []
# checked: record all variables created since the first stage
# all: record the live variables only exist until the current stage.
@ -97,25 +97,6 @@ class GraphAnalyser:
all_live_variables = LiveVariableVector()
unique_live_vars = LiveVariableVector()
def _add_param_or_buf(node, tensor_type):
module = get_node_module(node)
if tensor_type == 'param':
iterator = module.named_parameters()
elif tensor_type == 'buffer':
iterator = module.named_buffers()
else:
raise ValueError(f"Expected tensor_type to be param or buffer, but got {tensor_type}")
for name, tensor in iterator:
tensor_name = f'{node.name}.{name}'
if not checked_variables.exists(tensor_name):
live_tensor = LiveVariable(name=tensor_name, meta=tensor.to('meta'), is_inplace=False)
unique_live_vars.append(live_tensor)
checked_variables.append(live_tensor)
all_live_variables.append(live_tensor)
for idx, node in enumerate(compute_nodes):
#############################
# find new living variables #
@ -135,26 +116,19 @@ class GraphAnalyser:
# add the output var
meta = getattr(node, '_meta_data', None)
live_var = LiveVariable(name=node.name, meta=meta, is_inplace=is_inplace)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
checked_variables.append(live_var)
all_live_variables.append(live_var)
# add the model parameters
if node.op == 'call_module':
_add_param_or_buf(node, tensor_type='param')
_add_param_or_buf(node, tensor_type='buffer')
# add this output variable to the checked list
checked_variables.append(live_var)
# check if any input is not checked yet
for arg in node.args:
arg_name = str(arg)
if not isinstance(arg, Node):
continue
arg_name = arg.name
if not checked_variables.exists(arg_name):
meta = getattr(node, '_meta_data', None)
live_var_from_arg = LiveVariable(name=arg_name, meta=meta, is_inplace=False)
live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
all_live_variables.append(live_var_from_arg)
checked_variables.append(live_var_from_arg)
unique_live_vars.append(live_var_from_arg)
@ -167,8 +141,23 @@ class GraphAnalyser:
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
liveness_dict[idx] = stage
return liveness_dict
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
all_covered = True
for ele in prev_stage.unique_live_vars:
if ele not in stage.unique_live_vars:
all_covered = False
break
if all_covered:
replace = True
break
if replace:
liveness_list[index] = stage
else:
liveness_list.append(stage)
return liveness_list
def get_alias_set(self):
pass

View File

@ -0,0 +1,6 @@
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .batch_norm_handler import BatchNormHandler
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler']

View File

@ -0,0 +1,483 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
__all__ = ['BatchNormHandler']
class BatchNormHandler(OperatorHandler):
"""
A OperatorHandler which deals with the sharding strategies of normalization.
To keep the math consistency, there are two way to do BatchNorm if the input
shards on batch dimension:
1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help
us to keep the computing correctness.
In this handler, both methods will be considered.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.bias = self.module_named_parameters['bias']
self.output_data = self.node._meta_data
self._sanity_check()
def _sanity_check(self):
'''
In sanity check, we need make sure the input data having correct dimension size.
For BatchNorm1d, the dim of input data should be 3([N, C, L]).
For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
'''
assert self.input_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def _generate_compute_cost(self, bs, channel_in):
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
Argument:
bs(int): Batch size of the input data.
channel_in(int): The channel dimension of input data.
Return:
compute_cost(float): Computation cost per device with this specific strategy
'''
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
# TODO: a constant coefficient need to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
input_size = self.input_data.shape[2:]
input_size_product = reduce(operator.mul, input_size, 1)
forward_compute_cost = input_size_product * bs * channel_in
backward_activation_compute_cost = input_size_product * bs * channel_in
backward_weight_compute_cost = input_size_product * bs * channel_in
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
compute_cost = forward_compute_cost + backward_compute_cost
return compute_cost
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number partions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = numel_output
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
# shard the output batch dimension to get all possible sharding strategy from this basic strategy
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
# the memory cost need to be recomputed
# compute the memroy cost of new strategy
new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# the communication cost need to count the sharding cost into this strategy
# compute the communication cost of new strategy
origin_communication_cost = communication_cost
tiny_shard_cost = 10
new_forward_communication_cost = tiny_shard_cost
# we need to all gather the batch dimension for the basic strategy
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1)
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
sharding_strategies = ShardingStrategy(new_name,
output_sharding_spec=new_sharding_spec_for_output,
compute_cost=new_compute_cost,
communication_cost=new_communication_cost,
memory_cost=new_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
self.device_mesh.shape[mesh_dim_1])
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def non_split(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RR x R'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
dim_partition_dict_for_output = {0: mesh_dim_list}
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
# the memory cost need to be recomputed
new_sharding_size_input = 1
for mesh_dim in mesh_dim_list:
new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim]
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight)
# the communication cost need to count the sharding cost into this strategy
origin_communication_cost = communication_cost
tiny_shard_cost = 10
new_forward_communication_cost = tiny_shard_cost
if len(mesh_dim_list) == 1:
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation,
mesh_dim_list[0])
else:
new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost(
memory_cost_backward_activation, 0)
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
new_sharding_strategy = ShardingStrategy(new_name,
output_sharding_spec=new_sharding_spec_for_output,
compute_cost=new_compute_cost,
communication_cost=new_communication_cost,
memory_cost=new_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input,
sharding_spec_for_weight))
return new_sharding_strategy
# shard the output batch dimension to get all possible sharding strategy from this basic strategy
# shard on mesh_dim_0
new_name = f'S{mesh_dim_0}R = RR x R'
mesh_dim_list = [mesh_dim_0]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
# shard on mesh_dim_1
new_name = f'S{mesh_dim_1}R = RR x R'
mesh_dim_list = [mesh_dim_1]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
# shard on mesh_dim_0, mesh_dim_1
new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R'
mesh_dim_list = [mesh_dim_0, mesh_dim_1]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all reduce communication will happen during the sync bn computing.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all reduce communication will happen during the sync bn computing.
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all reduce communication will happen during the sync bn computing.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
Example:
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector:
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
Output:
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
'''
# RS = RS x S and strategies based on it, such as
# SS = RS x S
self.split_input_channel(0, 1)
self.split_input_channel(1, 0)
# RR = RR x R and strategies based on it, such as
# SR = SR x R
self.non_split(0, 1)
# RS01 = RS01 x S01
self.split_input_channel_1d(0, 1)
# SR = SR x R WITH SYNC_BN
self.split_input_batch(0)
self.split_input_batch(1)
# SS = SS x S WITH SYNC_BN
self.split_input_both_dim(0, 1)
self.split_input_both_dim(1, 0)
# S01R = S01R x R WITH SYNC_BN
self.split_input_batch_1d(0, 1)
return self.strategies_vector

View File

@ -1,5 +1,6 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
@ -9,7 +10,7 @@ __all__ = ['ConvHandler']
class ConvHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
A OperatorHandler which deals with the sharding strategies of Convolution.
"""
def __init__(self, *args, **kwargs):
@ -59,8 +60,7 @@ class ConvHandler(OperatorHandler):
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
return compute_cost
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight):
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
@ -69,8 +69,8 @@ class ConvHandler(OperatorHandler):
into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions.
sharding_size_backward_weight(int): The backward weight will be divided
into sharding_size_backward_weight number partions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number partions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
@ -90,17 +90,19 @@ class ConvHandler(OperatorHandler):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward, memory_cost_backward_activation
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@ -112,7 +114,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -126,9 +128,9 @@ class ConvHandler(OperatorHandler):
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# This strategy do not need to do all_reduce operation during forward
communication_cost_forward = 0
@ -138,7 +140,7 @@ class ConvHandler(OperatorHandler):
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
@ -156,7 +158,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -170,19 +172,20 @@ class ConvHandler(OperatorHandler):
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = 1
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
sharding_size_weight)
# This strategy do not need to do all_reduce operation in both forward and backward phase.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
@ -195,7 +198,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -209,18 +212,18 @@ class ConvHandler(OperatorHandler):
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1)
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
# This strategy do not need to do all_reduce operation during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
@ -238,7 +241,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -252,17 +255,17 @@ class ConvHandler(OperatorHandler):
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
@ -280,7 +283,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -294,18 +297,18 @@ class ConvHandler(OperatorHandler):
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# This strategy do NOT need all_reduce during forward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
@ -323,7 +326,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -337,9 +340,9 @@ class ConvHandler(OperatorHandler):
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = 1
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# This strategy do not need to do all_reduce during forward phase
communication_cost_forward = 0
@ -347,7 +350,7 @@ class ConvHandler(OperatorHandler):
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
@ -365,7 +368,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -379,15 +382,15 @@ class ConvHandler(OperatorHandler):
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_backward_weight = 1
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
sharding_size_weight)
# This strategy do not need to do all_reduce in both forward and backward phase
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
@ -405,7 +408,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -420,15 +423,15 @@ class ConvHandler(OperatorHandler):
sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_backward_weight = 1
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
sharding_size_weight)
# This strategy do not need to do all_reduce in both forward and backward phase
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
@ -446,7 +449,7 @@ class ConvHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -462,19 +465,20 @@ class ConvHandler(OperatorHandler):
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# compute communication cost during forward phase
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0)
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_forward_activation, 0)
# This strategy do NOT need do all_reduce during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
@ -536,24 +540,48 @@ class ConvHandler(OperatorHandler):
RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
'''
# SS = SR x RS
self.split_input_batch_weight_out_channel(0, 1)
self.split_input_batch_weight_out_channel(1, 0)
try:
self.split_input_batch_weight_out_channel(0, 1)
except Exception as e:
warnings.warn(f'{e}')
try:
self.split_input_batch_weight_out_channel(1, 0)
except Exception as e:
warnings.warn(f'{e}')
# SR = SR x RR
self.split_input_batch(0)
self.split_input_batch(1)
# SR = SS x SR
self.split_input_both_dim_weight_in_channel(0, 1)
self.split_input_both_dim_weight_in_channel(1, 0)
try:
self.split_input_both_dim_weight_in_channel(0, 1)
except Exception as e:
warnings.warn(f'{e}')
try:
self.split_input_both_dim_weight_in_channel(1, 0)
except Exception as e:
warnings.warn(f'{e}')
# RS = RS x SS
self.split_input_in_channel_weight_both_channel(0, 1)
self.split_input_in_channel_weight_both_channel(1, 0)
try:
self.split_input_in_channel_weight_both_channel(0, 1)
except Exception as e:
warnings.warn(f'{e}')
try:
self.split_input_in_channel_weight_both_channel(1, 0)
except Exception as e:
warnings.warn(f'{e}')
# RR = RS x SR
self.split_input_in_channel_weight_in_channel(0)
self.split_input_in_channel_weight_in_channel(1)
try:
self.split_input_in_channel_weight_in_channel(0)
except Exception as e:
warnings.warn(f'{e}')
try:
self.split_input_in_channel_weight_in_channel(1)
except Exception as e:
warnings.warn(f'{e}')
# RS = RR x RS
self.split_weight_out_channel(0)
@ -566,7 +594,12 @@ class ConvHandler(OperatorHandler):
self.split_1d_parallel_on_input_batch(0, 1)
# RR = RS01 x S01R
self.split_1d_parallel_on_in_channel(0, 1)
try:
self.split_1d_parallel_on_in_channel(0, 1)
except Exception as e:
warnings.warn(f'{e}')
# print(f'strategies num is :{len(self.strategies_vector)}')
return self.strategies_vector

View File

@ -44,11 +44,8 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost = numel * size_per_elem_bytes / sharding_size
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost
# no all-reduce required for this case
@ -59,7 +56,7 @@ class DotHandler(OperatorHandler):
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ -86,19 +83,16 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ -122,19 +116,16 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ -158,18 +149,16 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ -193,19 +182,115 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim]
memory_cost = numel * size_per_elem_bytes / sharding_size
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ -236,4 +321,14 @@ class DotHandler(OperatorHandler):
# RS = RR x RS
self.split_rhs_space_only(0)
self.split_rhs_space_only(1)
# S01R = S01R x RR
self.split_lhs_1st_dim_1d(0, 1)
# RR = RS01 x S01R
self.split_lhs_2nd_dim_1d(0, 1)
# RS01 = RR x RS01
self.split_rhs_2nd_dim_1d(0, 1)
return self.strategies_vector

View File

@ -8,7 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import StrategiesVector
from ..sharding_strategy import StrategiesVector
__all__ = ['OperatorHandler']
@ -70,6 +70,48 @@ class OperatorHandler(ABC):
dim_partition_dict=dim_partition_dict)
return sharding_spec
def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
dim_partition_dict_for_output(List[int]): The key is the dimension of output to be sharded,
and the value of the key decribe which logical axis will be sharded in that dimension.
dim_partition_dict_for_weight(List[int]): The key is the dimension of weight to be sharded,
and the value of the key decribe which logical axis will be sharded in that dimension.
Return:
total_memory_cost(float): total memory cost per device with this specific strategy
activation_cost(float): the memory cost of activation per device with this specific strategy
weight_memory_cost(float): the memory cost of weight per device with this specific strategy
'''
# compute the size of one element with specific dtype
dtype = self.input_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# compute the memory cost of activation
activation_numel = self.output_data.numel()
output_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_output.items():
output_mesh_dims.extend(mesh_dims)
activation_sharding_size = 1
for mesh_dim in output_mesh_dims:
activation_sharding_size *= self.device_mesh.shape[mesh_dim]
activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes
# compute the memory cost of weight
weight_numel = self.weight.numel()
weight_sharding_size = 1
weight_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items():
weight_mesh_dims.extend(mesh_dims)
for mesh_dim in weight_mesh_dims:
weight_sharding_size *= self.device_mesh.shape[mesh_dim]
weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
total_memory_cost = activation_memory_cost + weight_memory_cost
return total_memory_cost, activation_memory_cost, weight_memory_cost
def _generate_resharding_costs(self, sharding_spec_for_input):
'''
Compute the resharding costs with this specific strategy.
@ -85,17 +127,20 @@ class OperatorHandler(ABC):
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input):
dtype = self.node._meta_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
# compute the resharding cost during forward phase
_, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, target_spec)
input_sharding_spec, input_spec)
# In backward phase, we should convert grad with target_spec into input_sharding_spec
_, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
target_spec, input_sharding_spec)
resharding_cost = resharding_cost_forward + resharding_cost_backward
input_spec, input_sharding_spec)
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs

View File

@ -0,0 +1,444 @@
import warnings
import time
import numpy as np
import multiprocessing
from torch.fx.node import Node
from torch.fx.graph import Graph
from . import GraphAnalyser
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from typing import Dict
from .constants import INFINITY_COST
try:
import pulp
from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
except:
warnings.warn(f'please install the pulp')
__all___ = ['Solver']
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
memory_increasing_coefficient: float = 1.3):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.nodes = list(self.graph.nodes)
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = None
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_costs.append(strategy.compute_cost)
# node in extra_node_costs means it has some extra communication
# cost from node merging, so we need to add those extra communication
# cost into
if node in extra_node_costs:
origin_communication_cost = strategy.communication_cost
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(strategy.communication_cost)
# temporarily we just consider the forward memory cost
memory_cost = strategy.memory_cost
if isinstance(memory_cost, tuple):
memory_costs.append(memory_cost[0])
else:
memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flatten numpy arrays
s_follow = following_nodes
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
# 2. Set initial value
######################################
# set a initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost(computing cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node only choose one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
verbose = True
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = s_val
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return s_val, e_val, objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle python errors. Additionally,
we could give a serious of solutions with different memory budget.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list

View File

@ -1,7 +1,7 @@
from torch.fx import Graph, Node
from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .conv_handler import ConvHandler
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
from .constants import *
from copy import deepcopy
import math
@ -175,6 +175,58 @@ class StrategiesConstructor:
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# BatchNormNd module
elif submod_type in BATCHNORM_MODULE_OP:
# bn1 call_module bn1 (conv1,)
# print(node, node.op, node.target, node.args)
# create sharding strategy for element-wise module
# input_node = strategies_vector.predecessor_nodes[0]
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
# for strategy in norm_handler.strategies_vector:
# print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
# assert False
# MaxPool module
elif submod_type in POOL_MODULE_OP:
# create sharding strategy for element-wise module
assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporally, we just support single input element-wise op.'
input_node = strategies_vector.predecessor_nodes[0]
# For element-wise module, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise module.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# other module
else:
raise RuntimeError(f'{submod_type} module is NOT supported now.')
@ -203,7 +255,7 @@ class StrategiesConstructor:
# TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function
assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporally, we just support single input element-wise op.'
) == 1, f'Temporally, we just support single input element-wise op, node name is {node}.'
input_node = strategies_vector.predecessor_nodes[0]
# For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
@ -349,6 +401,13 @@ class StrategiesConstructor:
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
# clear the resharding cost for the output node
# TODO: we may remove this in final version
if True:
for prev_node, resharding_cost_list in resharding_costs.items():
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost,

View File

@ -0,0 +1,119 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
class BNModel(nn.Module):
def __init__(self, c):
super().__init__()
self.bn = nn.BatchNorm2d(c)
def forward(self, x):
x = x * 2
x = self.bn(x)
return x
def test_bn_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = BNModel(16)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {})
# return bn
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, bn, output]
nodes = [node for node in gm.graph.nodes]
# find the sharding strategies for the input node of the bn node
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(nodes[1])
sharding_option = (None, 0, 1)
for first_sharding_index in sharding_option:
for second_sharding_index in sharding_option:
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
continue
if first_sharding_index is None:
first_dim_spec = _DimSpec([])
else:
first_dim_spec = _DimSpec([first_sharding_index])
if second_sharding_index is None:
second_dim_spec = _DimSpec([])
else:
second_dim_spec = _DimSpec([second_sharding_index])
replica_dim_spec = _DimSpec([])
sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=entire_shape,
sharding_sequence=sharding_sequence)
strategy_name = str(sharding_spec.sharding_sequence)
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
strategies_vector_for_input.append(sharding_strategy)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
# generate bn strategy
strategies_vector = StrategiesVector(node=nodes[2])
bn_handler = BatchNormHandler(node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
bn_handler.register_strategy()
# ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
# 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']
strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector]
# RS = RS x S and strategies based on it, such as
# SS = RS x S
assert 'RS0 = RS0 x S0' in strategy_name_list
assert 'S1S0 = RS0 x S0' in strategy_name_list
assert 'RS1 = RS1 x S1' in strategy_name_list
assert 'S0S1 = RS1 x S1' in strategy_name_list
# RR = RR x R and strategies based on it, such as
# SR = SR x R
assert 'RR = RR x R' in strategy_name_list
assert 'S0R = RR x R' in strategy_name_list
assert 'S1R = RR x R' in strategy_name_list
assert 'S01R = RR x R' in strategy_name_list
# RS01 = RS01 x S01
assert 'RS01 = RS01 x S01' in strategy_name_list
# SR = SR x R WITH SYNC_BN
assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
# SS = SS x S WITH SYNC_BN
assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
# S01R = S01R x R WITH SYNC_BN
assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
if __name__ == '__main__':
test_bn_handler()

View File

@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh

View File

@ -6,8 +6,6 @@ import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor

View File

@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.dot_handler import DotHandler
from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh

View File

@ -32,21 +32,21 @@ def test_liveness_analysis():
gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__)
graph_analyser = GraphAnalyser(gm)
liveness_dict = graph_analyser.liveness_analysis()
stage_count = len(liveness_dict)
liveness_list = graph_analyser.liveness_analysis()
stage_count = len(liveness_list)
# 8 stages including input and output
assert stage_count == 8
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
assert stage_count == 1
# a variable named `relu` must exist
# and this live var must have inplace = True
assert liveness_dict[5].all_live_vars.exists('relu')
relu_var = liveness_dict[5].all_live_vars.get('relu')
assert liveness_list[0].all_live_vars.exists('relu')
relu_var = liveness_list[0].all_live_vars.get('relu')
assert relu_var.is_inplace
# the unique vars must be fewer than the all vars since in-place ops exist
all_live_vars = liveness_dict[7].all_live_vars
unique_live_vars = liveness_dict[7].unique_live_vars
all_live_vars = liveness_list[0].all_live_vars
unique_live_vars = liveness_list[0].unique_live_vars
assert len(unique_live_vars) + 1 == len(all_live_vars)

View File

@ -0,0 +1,78 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from copy import deepcopy
from colossalai.auto_parallel.solver import Solver
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3)
self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3)
self.relu = nn.ReLU()
def forward(self, x):
x = x * 2
x = self.conv1(x)
x = self.conv2(x)
x = x / 2
x = self.conv3(x)
x = self.relu(x)
return x
@pytest.mark.skip("for higher testing speed")
def test_solver():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
# %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {})
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {})
# %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {})
# return relu
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = {'fast_mode': True}
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
# [ 0 0 13 13 13 13 13 0]
strategies_combination_list = ret[0]
assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR'
if __name__ == '__main__':
test_solver()

View File

@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh