mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] remove deprecated codes (#2664)
parent 7fa6be49d2
commit 0b2a738393
@@ -1,6 +0,0 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
@@ -1,142 +0,0 @@
import functools
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union

import torch
from torch.fx.node import Node

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import INFINITY_COST


def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
                           dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
    """
    Generate the sharding spec of the tensor based on the given dim_partition_dict.

    Args:
        input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is given, its associated meta data will be used.
        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
        dim_partition_dict (Dict[int, List[int]]): a dictionary specifying the sharding spec; the key is the tensor dimension and the value is the list of mesh dimensions to shard it over.
    """

    if isinstance(input_, Node):
        assert hasattr(input_, '_meta_data'), 'The given node has no attribute _meta_data'
        meta_tensor = input_._meta_data
        assert meta_tensor is not None, "The given node's _meta_data attribute is None"
        shape = meta_tensor.shape
    elif isinstance(input_, torch.Tensor):
        shape = input_.shape
    else:
        raise TypeError(
            f'We cannot generate a sharding spec for {type(input_)}, only torch.fx.Node or torch.Tensor is expected.')
    for dim_index, sharding_index_list in dim_partition_dict.items():
        sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
        sharding_size = reduce(operator.mul, sharding_list, 1)
        assert shape[dim_index] % sharding_size == 0, \
            f'we cannot shard the {dim_index} dimension of the tensor into {sharding_size} partitions.'

    sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
    return sharding_spec
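
A minimal usage sketch (not part of the original file), assuming a 2 x 2 DeviceMesh already exists as device_mesh:

    # shard dim 0 of an (8, 16) tensor over mesh axis 0 and dim 1 over mesh axis 1
    dim_partition_dict = {0: [0], 1: [1]}
    spec = generate_sharding_spec(torch.empty(8, 16), device_mesh, dim_partition_dict)
    # each sharded tensor dimension must be divisible by the product of the mesh sizes it is split over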


def generate_resharding_costs(nodes: List[Node],
                              sharding_specs: List[ShardingSpec],
                              count_backward: Optional[bool] = True,
                              dtype: Optional[torch.dtype] = None,
                              index=None):
    '''
    Compute the resharding costs with this specific strategy.

    Argument:
        nodes (List[Node]): a list of nodes
        sharding_specs (List[ShardingSpec]): a list of ShardingSpec for the nodes.
        count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
    '''
    # The resharding_cost of weight is counted due to shared-weight cases.
    resharding_costs = {}
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

    # shape consistency manager is a singleton class
    shape_consistency_manager = ShapeConsistencyManager()

    for input_node, input_spec in zip(nodes, sharding_specs):
        resharding_costs[input_node] = []
        for strategy in input_node.strategies_vector:
            input_sharding_spec = strategy.output_sharding_spec
            if not isinstance(input_sharding_spec, ShardingSpec):
                assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
                input_sharding_spec = input_sharding_spec[index]
            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
            try:
                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)

                # multiply by the element size of the dtype to get the communication cost in bytes
                resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
            except AssertionError as e:
                warnings.warn(f'{e}')
                resharding_cost = INFINITY_COST
            resharding_costs[input_node].append(resharding_cost)
    return resharding_costs


def ignore_sharding_exception(func):
    """
    A function wrapper which catches the AssertionError raised while constructing an invalid
    sharding strategy and converts it into a warning, so the strategy is skipped instead of
    aborting the search.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            rst = func(*args, **kwargs)
            return rst
        except AssertionError as e:
            warnings.warn(f'{e}')

    return wrapper


def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
    dim_partition_list = []
    # enumerate all the 2D sharding cases
    for i in range(dim_size):
        for j in range(i + 1, dim_size):
            dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
            dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
            dim_partition_list.append(dim_partition_dict_0)
            dim_partition_list.append(dim_partition_dict_1)
    for i in range(dim_size):
        dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
        dim_partition_list.append(dim_partition_dict_flatten)

    return dim_partition_list


def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
    dim_partition_list = []
    # enumerate all the 1D sharding cases
    for i in range(dim_size):
        dim_partition_dict_0 = {i: [mesh_dim_0]}
        dim_partition_list.append(dim_partition_dict_0)

    return dim_partition_list


def generate_sharding_size(dim_partition_dict, device_mesh):
    total_sharding_size = 1
    for mesh_dim_list in dim_partition_dict.values():
        mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
        sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
        total_sharding_size *= sharding_size

    return total_sharding_size

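For a concrete sense of the enumeration helpers above (an illustration, not part of the original file), with mesh dims (0, 1) and dim_size=2:

    # enumerate_all_possible_2d_sharding(0, 1, 2) ->
    #   {0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]}
    # enumerate_all_possible_1d_sharding(0, 2) ->
    #   {0: [0]}, {1: [0]}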
@@ -1,83 +0,0 @@
import torch
import operator

__all__ = [
    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
]

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
    torch.abs,
    torch.cos,
    torch.exp,
    operator.neg,
    torch.multiply,
    torch.nn.functional.relu,
    torch.nn.functional.dropout,
    # softmax should not be here
    torch.nn.functional.softmax
]
ELEMENTWISE_METHOD_OP = [
    torch.Tensor.to,
    torch.Tensor.type,
    # TODO: contiguous may need some extra processing.
    torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_METHOD_OP = [
    torch.Tensor.view,
    torch.Tensor.unsqueeze,
    torch.Tensor.split,
    torch.Tensor.permute,
    torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
    torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
    operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
]
CONV_MODULE_OP = [
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d
]
CONV_FUNC_OP = [
    torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
NON_PARAM_FUNC_OP = [
    torch.flatten,
    torch.reshape,
    torch.abs,
    torch.cos,
    torch.exp,
    operator.neg,
    torch.multiply,
    torch.nn.functional.relu,
    torch.nn.functional.dropout,
    torch.flatten,
    torch.where,
    operator.pow,
    torch.pow,
    torch.tanh,
    torch.add,
    torch.sub,
    torch.mul,
    torch.div,
    torch.floor_divide,
    torch.true_divide,
    operator.add,
    operator.sub,
    operator.mul,
    operator.floordiv,
    operator.truediv,
    # softmax should not be here
    torch.nn.functional.softmax
]

INFINITY_COST = 1e13

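These lookup tables are, presumably, how the strategies constructor dispatches each traced node to a handler; a hedged sketch of that pattern (the dispatch logic itself is an assumption, only the constants and handler names come from this diff):

    # if node.target in CONV_FUNC_OP:      -> handled by ConvHandler
    # elif node.target in LINEAR_FUNC_OP:  -> handled by DotHandler
    # elif node.target in BCAST_FUNC_OP:   -> handled by BcastOpHandler
    # unmatched ops fall back to replicated strategies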
@@ -1,172 +0,0 @@
from typing import List
import math
from torch.fx.node import Node
from .constants import INFINITY_COST


class CostGraph:
    '''
    A graph data structure to simplify the edge cost graph. It has two main functions:
    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
       CostGraph, which stores every combination of strategies for a src-dst node pair in a 1D list.
    2. To reduce the search space, we merge computationally-trivial operators, such as
       element-wise operators, transpose, and reduction, into their following nodes. The merging information is
       given by the StrategiesVector depending on the type of the target node and its following nodes.

    Argument:
        leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
        simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
    '''

    def __init__(self, leaf_strategies, simplify=True):
        self.leaf_strategies = leaf_strategies
        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
        # stores the number of strategies in each node
        self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
        # extra_node_costs will store the extra costs introduced by merging nodes
        self.extra_node_costs = {}
        self.following_dict = {}
        self.simplify = simplify
        self._build_cost_graph()

    def _remove_invalid_node(self, node, attr_name):
        remove_list = []
        target_node_list = getattr(node, attr_name, [])
        for target_node in target_node_list:
            if target_node not in self.nodes:
                remove_list.append(target_node)
        for element in remove_list:
            target_node_list.remove(element)

    def _build_cost_graph(self):
        '''
        This method generates edge_cost for every adjacent node pair. Additionally, the 'parents' and 'children'
        attributes are set on each node.
        '''
        self.edge_costs = {}
        if self.simplify:
            self.merge_pair = []
        for strategies_vector in self.leaf_strategies:
            # build edge_cost
            dst_node = strategies_vector.node
            for src_node in strategies_vector.predecessor_nodes:
                if src_node not in self.nodes:
                    continue
                node_pair = (src_node, dst_node)
                # src_index = strategies_vector.predecessor_nodes.index(src_node)
                edge_cost = {}
                for i in range(len(strategies_vector)):
                    for j in range(len(src_node.strategies_vector)):
                        edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
                self.edge_costs[node_pair] = edge_cost
            # add parents and children attribute to node
            setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
            setattr(dst_node, 'children', strategies_vector.successor_nodes)
            self._remove_invalid_node(dst_node, 'parents')
            self._remove_invalid_node(dst_node, 'children')

            if self.simplify and strategies_vector.check_merge():
                for followed_node in strategies_vector.predecessor_nodes:
                    self.merge_pair.append((followed_node, dst_node))

    def get_edge_cost(self, src_node, dst_node):
        return self.edge_costs[(src_node, dst_node)]
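
To make the edge_cost layout concrete (an illustration, not part of the original file): for a producer with 2 strategies and a consumer with 3, the edge dictionary holds one entry per strategy pair,

    # edge_costs[(src_node, dst_node)] == {
    #     (0, 0): c00, (0, 1): c01, (0, 2): c02,
    #     (1, 0): c10, (1, 1): c11, (1, 2): c12,
    # }
    # where the key is (src_strategy_index, dst_strategy_index) and the value comes from
    # dst_strategy.resharding_costs[src_node][src_strategy_index]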

    def merge_node(self, src_node, dst_node):
        '''
        To merge dst_node into src_node, we need to do it in the following steps:

        1. For each strategy in dst_node, we need to pick an appropriate strategy
           of src_node to merge. This is important because the logical resharding costs
           between the parent nodes of src_node and the merged node depend on which
           src_node strategy is picked. For example, for the graph 0->1->2, after merging node 1
           into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)],
           where x is the strategy of node 1 picked when merging into strategy 0 of node 2.

        2. We need to accumulate the extra costs introduced by merging nodes. The extra costs
           contain two parts: one is the resharding cost between the src_node strategy and the dst_node strategy,
           the other is the original extra cost of the src_node strategy.

        3. Build connections between new node pairs, and remove the src_node after all consumer nodes
           are detached from it.

        Argument:
            src_node(Node): The node that will be merged into dst_node.
            dst_node(Node): The node that integrates src_node.
        '''
        src_node_index = dst_node.parents.index(src_node)
        # build merge_map
        merge_map = {}
        for src_index, strategy in enumerate(src_node.strategies_vector):
            min_cost = INFINITY_COST
            lowest_cost_index = -1
            for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
                resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
                if resharding_cost <= min_cost:
                    min_cost = resharding_cost
                    lowest_cost_index = dst_index
            merge_map[src_index] = lowest_cost_index

        # extra_node_cost for src node
        self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
        for src_index, strategy in enumerate(src_node.strategies_vector):
            target_strate_index = merge_map[src_index]
            target_strategy = dst_node.strategies_vector[target_strate_index]
            self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index]
            if dst_node in self.extra_node_costs:
                self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]

        # add new node pair to cost graph
        for child_node in dst_node.children:
            new_node_pair = (src_node, child_node)
            old_node_pair = (dst_node, child_node)
            if new_node_pair in self.edge_costs:
                continue
            edge_cost = {}
            for i in range(self.node_lens[src_node]):
                for j in range(self.node_lens[child_node]):
                    dst_strate_index = merge_map[i]
                    # dst_strategy = dst_node.strategies_vector[dst_strate_index]
                    edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
            if new_node_pair not in self.edge_costs:
                self.edge_costs[new_node_pair] = edge_cost
            else:
                # we should accumulate the resharding costs if the args of the child node contain
                # both the src node and the dst node.
                for index_pair, resharding_cost in self.edge_costs[new_node_pair].items():
                    self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]

        # connect src node and children of dst node
        dst_node.parents.remove(src_node)
        src_node.children.remove(dst_node)
        self.edge_costs.pop((src_node, dst_node))
        for child_node in dst_node.children:
            if child_node not in src_node.children:
                src_node.children.append(child_node)
            if src_node not in child_node.parents:
                child_node.parents.append(src_node)
            # remove dst node from cost graph when dst node has no producer.
            if len(dst_node.parents) == 0:
                child_node.parents.remove(dst_node)
                node_pair = (dst_node, child_node)
                self.edge_costs.pop(node_pair)
        if len(dst_node.parents) == 0:
            self.following_dict[dst_node] = src_node
            dst_node.children = []

    def _reindexing_src(self, src):
        if src not in self.following_dict:
            return src
        return self._reindexing_src(self.following_dict[src])

    def simplify_graph(self):
        if not self.simplify:
            return
        self.merge_pair.reverse()
        for (src_node, dst_node) in self.merge_pair:
            self.merge_node(src_node, dst_node)
        self.merge_pair.reverse()
        reindexing_following_dict = {}
        for dst, src in self.following_dict.items():
            reindexing_following_dict[dst] = self._reindexing_src(src)
        self.following_dict = reindexing_following_dict

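A hedged usage sketch of the class above (not from the original file; leaf_strategies is assumed to come from a StrategiesConstructor as in the rest of this package):

    # cost_graph = CostGraph(leaf_strategies, simplify=True)
    # cost_graph.simplify_graph()
    # edge = cost_graph.get_edge_cost(src_node, dst_node)   # dict keyed by (src_idx, dst_idx)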
@@ -1,163 +0,0 @@
from dataclasses import dataclass
from torch.fx.node import Node
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from collections import OrderedDict as ODict
from typing import List, OrderedDict, Union, Any
from colossalai.fx.passes.utils import get_node_module

__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']


@dataclass
class LiveVariable:
    """
    LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
    """
    name: str
    node: Node
    is_inplace: bool


class LiveVariableVector(list):
    """
    LiveVariableVector is a data structure to store a list of LiveVariable objects.
    """

    def exists(self, name) -> bool:
        """
        Check whether a variable with the given name already exists in the current list.
        """
        for var in self:
            if name == var.name:
                return True
        return False

    def get(self, name) -> LiveVariable:
        for var in self:
            if name == var.name:
                return var
        raise KeyError(f"Variable {name} is not found")

    def copy(self) -> "LiveVariableVector":
        """
        Create a copy of this vector
        """
        vector = LiveVariableVector()
        for var in self:
            vector.append(var)
        return vector


@dataclass
class LiveStage:
    """
    LiveStage is a data structure to record the live variables at the current node.
    """
    name: str
    node: Node
    all_live_vars: LiveVariableVector
    unique_live_vars: LiveVariableVector


class GraphAnalyser:

    def __init__(self, gm: GraphModule):
        self._gm = gm
        self._graph = gm.graph

    @property
    def gm(self) -> GraphModule:
        """
        Return the GraphModule object associated with this analyser.
        """
        return self._gm

    @property
    def graph(self) -> Graph:
        """
        Return the Graph object associated with this analyser.
        """
        return self._graph

    def liveness_analysis(self) -> List[LiveStage]:
        """
        Analyse the graph to obtain the variable liveness information. This function returns
        a list of LiveStage objects, one per compute node.
        """
        compute_nodes = self.graph.nodes
        liveness_list = []

        # checked: records all variables created since the first stage.
        # all: records the live variables that exist up to the current stage.
        #      this can differ from the `checked` list as some variables may be destroyed prior to this stage.
        # unique: records the unique live variables that exist up to the current stage.
        #         this differs from the `all` list as some variables are duplicated.
        checked_variables = LiveVariableVector()
        all_live_variables = LiveVariableVector()
        unique_live_vars = LiveVariableVector()

        for idx, node in enumerate(compute_nodes):
            #############################
            # find new living variables #
            #############################
            # detect whether the current op is an in-place op
            # if it is an in-place op, we treat its output as a duplicate var
            is_inplace = False
            if node.op == 'call_function':
                # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
                if node.kwargs.get('inplace', False):
                    is_inplace = True
            elif node.op == 'call_module':
                # check if this is an inplace op such as torch.nn.ReLU(inplace=True)
                module = get_node_module(node)
                if getattr(module, 'inplace', False):
                    is_inplace = True

            # add the output var
            meta = getattr(node, '_meta_data', None)
            live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
            if not is_inplace:
                unique_live_vars.append(live_var)
            checked_variables.append(live_var)
            all_live_variables.append(live_var)

            # check if any input is not checked yet
            for arg in node.args:
                if not isinstance(arg, Node):
                    continue
                arg_name = arg.name
                if not checked_variables.exists(arg_name):
                    live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
                    all_live_variables.append(live_var_from_arg)
                    checked_variables.append(live_var_from_arg)
                    unique_live_vars.append(live_var_from_arg)

            # TODO: add the logic to remove live variables
            # this should be completed once we are able to trace the backward compute graph

            # add this stage to the liveness list
            stage = LiveStage(name=node.name,
                              node=node,
                              all_live_vars=all_live_variables.copy(),
                              unique_live_vars=unique_live_vars.copy())
            # if a LiveStage is covered by another LiveStage, we just keep the larger one.
            replace = False
            for index, prev_stage in enumerate(liveness_list):
                all_covered = True
                for ele in prev_stage.unique_live_vars:
                    if ele not in stage.unique_live_vars:
                        all_covered = False
                        break
                if all_covered:
                    replace = True
                    break
            if replace:
                liveness_list[index] = stage
            else:
                liveness_list.append(stage)

        return liveness_list

    def get_alias_set(self):
        pass

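A brief usage sketch (not part of the original file), assuming gm is a torch.fx GraphModule traced elsewhere:

    # analyser = GraphAnalyser(gm)
    # stages = analyser.liveness_analysis()
    # peak = max(len(stage.unique_live_vars) for stage in stages)   # rough proxy for peak liveness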
@@ -1,15 +0,0 @@
from .batch_norm_handler import BatchNormHandler
from .bcast_op_handler import BcastOpHandler
from .conv_handler import ConvHandler
from .dot_handler import DotHandler
from .embedding_handler import EmbeddingHandler
from .layer_norm_handler import LayerNormHandler
from .operator_handler import OperatorHandler
from .reshape_handler import ReshapeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler

__all__ = [
    'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
    'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler'
]

@@ -1,492 +0,0 @@
import operator
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['BatchNormHandler']
|
||||
|
||||
|
||||
class BatchNormHandler(OperatorHandler):
|
||||
"""
|
||||
A OperatorHandler which deals with the sharding strategies of normalization.
|
||||
|
||||
To keep the math consistency, there are two way to do BatchNorm if the input
|
||||
shards on batch dimension:
|
||||
1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
|
||||
2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help
|
||||
us to keep the computing correctness.
|
||||
In this handler, both methods will be considered.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_data = self.predecessor_node[0]._meta_data
|
||||
self.weight = self.module_named_parameters['weight']
|
||||
self.bias = self.module_named_parameters['bias']
|
||||
self.output_data = self.node._meta_data
|
||||
self._sanity_check()
|
||||
|
||||
def _sanity_check(self):
|
||||
'''
|
||||
In the sanity check, we need to make sure the input data has the correct number of dimensions.
|
||||
For BatchNorm1d, the dim of input data should be 3([N, C, L]).
|
||||
For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
|
||||
For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
|
||||
'''
|
||||
assert self.input_data.dim() in (
3, 4, 5), 'We expect the dim of the input fed into the batch norm op to be in the range of [3, 5].'
|
||||
|
||||
def _generate_compute_cost(self, bs, channel_in):
|
||||
'''
|
||||
Compute the computation cost per device with this specific strategy.
|
||||
|
||||
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
|
||||
|
||||
Argument:
|
||||
bs(int): Batch size of the input data.
|
||||
channel_in(int): The channel dimension of input data.
|
||||
|
||||
Return:
|
||||
compute_cost(float): Computation cost per device with this specific strategy
|
||||
'''
|
||||
# TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# TODO: a constant coefficient needs to be added.
|
||||
# 1D: (L) * N * Cin
|
||||
# 2D: (H * W) * N * Cin
|
||||
# 3D: (H * W * D) * N * Cin
|
||||
|
||||
input_size = self.input_data.shape[2:]
|
||||
input_size_product = reduce(operator.mul, input_size, 1)
|
||||
forward_compute_cost = input_size_product * bs * channel_in
|
||||
backward_activation_compute_cost = input_size_product * bs * channel_in
|
||||
backward_weight_compute_cost = input_size_product * bs * channel_in
|
||||
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
|
||||
compute_cost = forward_compute_cost + backward_compute_cost
|
||||
return compute_cost
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
Argument:
|
||||
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight partitions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel_output = self.output_data.numel()
|
||||
numel_input = numel_output
|
||||
numel_weight = self.weight.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
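
A quick arithmetic illustration of the formula above (not part of the original file): for a float32 activation with 1,048,576 elements sharded 4 ways and a weight of 64 elements sharded 2 ways,

    # memory_cost_forward_activation = 1_048_576 * 4 / 4 = 1_048_576 bytes
    # memory_cost_forward_weight     = 64 * 4 / 2        = 128 bytes
    # memory_cost_forward            = 1_048_704 bytes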
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
# shard the output batch dimension to get all possible sharding strategy from this basic strategy
|
||||
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
|
||||
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
# the computation cost is all the same
|
||||
new_compute_cost = compute_cost
|
||||
|
||||
# the memory cost needs to be recomputed
# compute the memory cost of the new strategy
|
||||
new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# the communication cost need to count the sharding cost into this strategy
|
||||
# compute the communication cost of new strategy
|
||||
origin_communication_cost = communication_cost
|
||||
tiny_shard_cost = 10
|
||||
new_forward_communication_cost = tiny_shard_cost
|
||||
# we need to all gather the batch dimension for the basic strategy
|
||||
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1)
|
||||
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
|
||||
|
||||
sharding_strategies = ShardingStrategy(new_name,
|
||||
output_sharding_spec=new_sharding_spec_for_output,
|
||||
compute_cost=new_compute_cost,
|
||||
communication_cost=new_communication_cost,
|
||||
memory_cost=new_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
|
||||
self.device_mesh.shape[mesh_dim_1])
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def non_split(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RR x R'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
|
||||
dim_partition_dict_for_output = {0: mesh_dim_list}
|
||||
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# the computation cost is all the same
|
||||
new_compute_cost = compute_cost
|
||||
|
||||
# the memory cost need to be recomputed
|
||||
new_sharding_size_input = 1
|
||||
for mesh_dim in mesh_dim_list:
|
||||
new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim]
|
||||
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
|
||||
new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# the communication cost need to count the sharding cost into this strategy
|
||||
origin_communication_cost = communication_cost
|
||||
tiny_shard_cost = 10
|
||||
new_forward_communication_cost = tiny_shard_cost
|
||||
if len(mesh_dim_list) == 1:
|
||||
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation,
|
||||
mesh_dim_list[0])
|
||||
else:
|
||||
new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost(
|
||||
memory_cost_backward_activation, 0)
|
||||
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
|
||||
|
||||
new_sharding_strategy = ShardingStrategy(new_name,
|
||||
output_sharding_spec=new_sharding_spec_for_output,
|
||||
compute_cost=new_compute_cost,
|
||||
communication_cost=new_communication_cost,
|
||||
memory_cost=new_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input,
|
||||
sharding_spec_for_weight))
|
||||
|
||||
return new_sharding_strategy
|
||||
|
||||
# shard the output batch dimension to get all possible sharding strategy from this basic strategy
|
||||
# shard on mesh_dim_0
|
||||
new_name = f'S{mesh_dim_0}R = RR x R'
|
||||
mesh_dim_list = [mesh_dim_0]
|
||||
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
|
||||
self.strategies_vector.append(new_sharding_strategy)
|
||||
|
||||
# shard on mesh_dim_1
|
||||
new_name = f'S{mesh_dim_1}R = RR x R'
|
||||
mesh_dim_list = [mesh_dim_1]
|
||||
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
|
||||
self.strategies_vector.append(new_sharding_strategy)
|
||||
|
||||
# shard on mesh_dim_0, mesh_dim_1
|
||||
new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R'
|
||||
mesh_dim_list = [mesh_dim_0, mesh_dim_1]
|
||||
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
|
||||
self.strategies_vector.append(new_sharding_strategy)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch(self, mesh_dim_0):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = 1
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# the all reduce communication will happen during the sync bn computing.
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
|
||||
channel_in = self.input_data.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = 1
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# the all reduce communication will happen during the sync bn computing.
|
||||
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# the all reduce communication will happen during the sync bn computing.
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate all possible strategies for a BatchNorm node, and record them in the strategies_vector.
|
||||
|
||||
Example:
|
||||
norm_handler = BatchNormHandler(node, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
norm_handler.register_strategy()
|
||||
for strategy in norm_handler.strategies_vector:
|
||||
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
|
||||
|
||||
Output:
|
||||
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
|
||||
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
|
||||
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
|
||||
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
|
||||
'''
|
||||
|
||||
# RS = RS x S and strategies based on it, such as
|
||||
# SS = RS x S
|
||||
self.split_input_channel(0, 1)
|
||||
self.split_input_channel(1, 0)
|
||||
|
||||
# RR = RR x R and strategies based on it, such as
|
||||
# SR = SR x R
|
||||
self.non_split(0, 1)
|
||||
|
||||
# RS01 = RS01 x S01
|
||||
self.split_input_channel_1d(0, 1)
|
||||
|
||||
# SR = SR x R WITH SYNC_BN
|
||||
self.split_input_batch(0)
|
||||
self.split_input_batch(1)
|
||||
|
||||
# SS = SS x S WITH SYNC_BN
|
||||
self.split_input_both_dim(0, 1)
|
||||
self.split_input_both_dim(1, 0)
|
||||
|
||||
# S01R = S01R x R WITH SYNC_BN
|
||||
self.split_input_batch_1d(0, 1)
|
||||
|
||||
return self.strategies_vector
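Summarizing the calls above (an editorial note, not in the original file): register_strategy enumerates, per 2-D mesh, the channel-sharded variants RS0/RS1/RS01, the fully replicated RR plus the batch-sharded strategies derived from it, and the SyncBN variants S0R/S1R/S01R and S0S1/S1S0, so the solver can trade communication against memory for each BatchNorm node.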

@@ -1,552 +0,0 @@
import operator
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
ignore_sharding_exception)
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['BcastOpHandler']
|
||||
|
||||
|
||||
class BcastOpHandler(OperatorHandler):
|
||||
"""
|
||||
An OperatorHandler which deals with the sharding strategies of broadcast operators (such as operator.add).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert len(self.predecessor_node) == 2
|
||||
self.lhs_data = self.predecessor_node[0]._meta_data
|
||||
self.rhs_data = self.predecessor_node[1]._meta_data
|
||||
self.lhs = self.predecessor_node[0]
|
||||
self.rhs = self.predecessor_node[1]
|
||||
self.output_data = self.node._meta_data
|
||||
|
||||
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
|
||||
shape = list(input_.shape)
|
||||
|
||||
# padding the shape to the same length as output_data
|
||||
while len(shape) < self.output_data.dim():
|
||||
shape.insert(0, 1)
|
||||
shape = torch.Size(shape)
|
||||
|
||||
# if the sharding happens on a size one dimension, we should record it as R.
|
||||
processed_dim_partition_dict = deepcopy(dim_partition_dict)
|
||||
for dim_index, _ in dim_partition_dict.items():
|
||||
if shape[dim_index] == 1:
|
||||
processed_dim_partition_dict.pop(dim_index)
|
||||
for dim_index, sharding_index_list in processed_dim_partition_dict.items():
|
||||
sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
|
||||
sharding_size = reduce(operator.mul, sharding_list, 1)
|
||||
assert shape[
|
||||
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
|
||||
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=shape,
|
||||
dim_partition_dict=processed_dim_partition_dict)
|
||||
|
||||
return sharding_spec
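
To illustrate the padding step above (not part of the original file): for an output of shape (8, 16) and a 1-D rhs of shape (16,),

    # the rhs shape is left-padded with ones to (1, 16) before the spec is built,
    # and any sharding requested on the padded size-1 dim is recorded as R (replicated)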
|
||||
|
||||
def _generate_compute_cost(self, total_sharding_size):
|
||||
lhs_matrix_shape = self.lhs_data.shape[-2:]
|
||||
rhs_matrix_shape = self.rhs_data.shape[-2:]
|
||||
batch_dimensions_shape = self.output_data.shape[:-2]
|
||||
batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
|
||||
compute_cost = reduce(
|
||||
operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
|
||||
return compute_cost
|
||||
|
||||
def _generate_resharding_costs(self, sharding_specs):
|
||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||
dtype = self.node._meta_data.dtype
|
||||
nodes = self.predecessor_node
|
||||
resharding_costs = {}
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# shape consistency manager is a singleton class
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
for input_node, input_spec in zip(nodes, sharding_specs):
|
||||
resharding_costs[input_node] = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
# if the input shape is smaller than the target input, we will fill the input to the same length as target.
|
||||
# Then, use the padded input sharding spec to compute the resharding cost.
|
||||
if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
|
||||
new_entire_shape = list(input_sharding_spec.entire_shape)
|
||||
while len(new_entire_shape) < len(input_spec.entire_shape):
|
||||
new_entire_shape.insert(0, 1)
|
||||
new_entire_shape = torch.Size(new_entire_shape)
|
||||
new_device_mesh = input_sharding_spec.device_mesh
|
||||
new_dim_partition_dict = input_sharding_spec.dim_partition_dict
|
||||
input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
|
||||
entire_shape=new_entire_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
|
||||
# compute the resharding cost
|
||||
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
|
||||
input_sharding_spec, input_spec)
|
||||
|
||||
# we need multiply the size of elem dtype to get correct communication cost
|
||||
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
|
||||
resharding_costs[input_node].append(resharding_cost)
|
||||
|
||||
return resharding_costs
|
||||
|
||||
def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
|
||||
|
||||
sharding_spec_list = []
|
||||
check_duplicated_list = []
|
||||
for output_dim_partition_dict in dim_partition_list:
|
||||
try:
|
||||
output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
|
||||
except AssertionError as e:
|
||||
warnings.warn(f'{e}')
|
||||
break
|
||||
sharding_seq = output_sharding_spec.sharding_sequence
|
||||
if sharding_seq not in check_duplicated_list:
|
||||
check_duplicated_list.append(sharding_seq)
|
||||
sharding_spec_list.append(output_sharding_spec)
|
||||
|
||||
return sharding_spec_list
|
||||
|
||||
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
|
||||
# use mesh_dim_0, mesh_dim_1 instead of the constants 0, 1 here for N-D device mesh scalability.
|
||||
|
||||
output_dim_partition_list = []
|
||||
dim_size = self.output_data.dim()
|
||||
# enumerate all the 2D sharding cases
|
||||
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
|
||||
output_dim_partition_list.extend(sharding_list_2d)
|
||||
|
||||
# enumerate all the 1D sharding cases
|
||||
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
|
||||
output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
|
||||
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
|
||||
output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
|
||||
|
||||
# add empty dict for fully replicated case
|
||||
output_dim_partition_list.append({})
|
||||
output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)
|
||||
|
||||
return output_sharding_spec_list
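
Concretely (an illustration, not in the original file): for a 2-D output on a 2-D mesh, the enumeration above covers both pairings of tensor dims with mesh axes, flattened 2-D shardings such as {0: [0, 1]}, every 1-D sharding on a single mesh axis, and the fully replicated {} case, with duplicates filtered out by sharding sequence.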
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _register_strategy(self, output_sharding_spec):
|
||||
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
|
||||
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input)
|
||||
|
||||
name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
|
||||
dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
sharding_dims = []
|
||||
for mesh_dims in dim_partition_dict_for_output.values():
|
||||
for mesh_dim in mesh_dims:
|
||||
sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
sharding_size = reduce(operator.mul, sharding_dims, 1)
|
||||
memory_cost = self.output_data.numel() / sharding_size
|
||||
compute_cost = memory_cost
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
##############################################
|
||||
#used to generate strategies for torch.matmul#
|
||||
##############################################
|
||||
@ignore_sharding_exception
|
||||
def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
|
||||
# this dim partition dict only describes the batch dimensions, but in this scenario,
|
||||
# matrix dimensions are fully replicated, so they do not need extra processing.
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_batch_dim)
|
||||
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_batch_dim)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_batch_dim)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
batch_sharding_dims = []
|
||||
for mesh_dims in dim_partition_dict_for_batch_dim.values():
|
||||
for mesh_dim in mesh_dims:
|
||||
batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
|
||||
# in this case, total_sharding_size is equal to the batch sharding size
|
||||
memory_cost = self.output_data.numel() / batch_sharding_size
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
compute_cost = self._generate_compute_cost(batch_sharding_size)
|
||||
|
||||
# in this case, no communication takes place.
|
||||
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
|
||||
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
|
||||
# this dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info to it.
|
||||
# In this scenario, matrix dimensions will be sharded on 'i' dimension.
|
||||
|
||||
# in this case, the matrix dimensions of lhs are sharded on the 'i' dimension.
|
||||
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
dim_partition_dict_for_lhs.update({-2: mesh_dim_on_matrix})
|
||||
|
||||
# in this case, the matrix dimensions of rhs are fully replicated.
|
||||
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
|
||||
# in this case, the matrix dimensions of output are sharded on the 'i' dimension.
|
||||
|
||||
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
dim_partition_dict_for_output.update({-2: mesh_dim_on_matrix})
|
||||
|
||||
# generate sharding specs
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
total_sharding_dims = []
|
||||
|
||||
# append batch sharding dims
|
||||
for mesh_dims in dim_partition_dict_for_batch_dim.values():
|
||||
for mesh_dim in mesh_dims:
|
||||
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
|
||||
# append the sharding dims on matrix dimension
|
||||
for mesh_dim in mesh_dim_on_matrix:
|
||||
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
|
||||
|
||||
# in this case, output_data uses all the sharding dims.
|
||||
memory_cost = self.output_data.numel() / total_sharding_size
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
|
||||
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
|
||||
# this dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info to it.
|
||||
# In this scenario, matrix dimensions will be sharded on 'k' dimension.
|
||||
|
||||
# in this case, the matrix dimensions of lhs are sharded on the 'k' dimension.
|
||||
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
dim_partition_dict_for_lhs.update({-1: mesh_dim_on_matrix})
|
||||
|
||||
# in this case, the matrix dimensions of rhs are sharded on the 'k' dimension.
|
||||
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
dim_partition_dict_for_rhs.update({-2: mesh_dim_on_matrix})
|
||||
|
||||
# in this case, the matrix dimensions of output are fully replicated.
|
||||
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
|
||||
# generate sharding specs
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
total_sharding_dims = []
|
||||
batch_sharding_dims = []
|
||||
# append batch sharding dims
|
||||
for mesh_dims in dim_partition_dict_for_batch_dim.values():
|
||||
for mesh_dim in mesh_dims:
|
||||
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
|
||||
# append the sharding dims on matrix dimension
|
||||
for mesh_dim in mesh_dim_on_matrix:
|
||||
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
|
||||
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
|
||||
|
||||
# in this case, output_data is fully replicated on matrix dimensions.
|
||||
memory_cost = self.output_data.numel() / batch_sharding_size
|
||||
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
|
||||
# The communication takes place during forward activation computation.
|
||||
if len(mesh_dim_on_matrix) == 1:
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
|
||||
else:
|
||||
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
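Splitting the contracted 'k' dimension leaves partial sums on every device, which is why the forward pass above pays one all-reduce. The sketch below uses a generic ring all-reduce cost model with made-up latency and bandwidth constants; it is an illustration of the idea, not the DeviceMesh implementation.

# hedged sketch of an alpha-beta ring all-reduce cost model (illustrative only)
def ring_all_reduce_cost_sketch(num_bytes, num_devices, latency=1e-5, bandwidth=1e9):
    if num_devices <= 1:
        return 0.0
    steps = 2 * (num_devices - 1)
    return steps * latency + steps * (num_bytes / num_devices) / bandwidth

# e.g. reducing a 4 MB partial result across 4 devices
print(ring_all_reduce_cost_sketch(num_bytes=4 * 1024 * 1024, num_devices=4))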
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
|
||||
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
|
||||
# this dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info to it.
|
||||
# In this scenario, the matrix dimensions will be sharded on the 'j' dimension.
|
||||
|
||||
# in this case, the matrix dimensions of lhs are fully replicated.
|
||||
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
|
||||
# in this case, the matrix dimensions of rhs are sharded on the 'j' dimension.
|
||||
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
dim_partition_dict_for_rhs.update({-1: mesh_dim_on_matrix})
|
||||
|
||||
# in this case, the matrix dimensions of output are sharded on the 'j' dimension.
|
||||
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
|
||||
dim_partition_dict_for_output.update({-1: mesh_dim_on_matrix})
|
||||
|
||||
# generate sharding specs
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
total_sharding_dims = []
|
||||
|
||||
# append batch sharding dims
|
||||
for mesh_dims in dim_partition_dict_for_batch_dim.values():
|
||||
for mesh_dim in mesh_dims:
|
||||
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
|
||||
# append the sharding dims on matrix dimension
|
||||
for mesh_dim in mesh_dim_on_matrix:
|
||||
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
|
||||
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
|
||||
|
||||
# in this case, output_data uses all the sharding dims.
|
||||
memory_cost = self.output_data.numel() / total_sharding_size
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
|
||||
# The communication takes place during backward activation computation.
|
||||
if len(mesh_dim_on_matrix) == 1:
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
|
||||
else:
|
||||
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def _registry_1d_strategies_for_matmul(self, dim_partition_dict, mesh_dim_list):
|
||||
self._split_dim_i(dim_partition_dict, mesh_dim_list)
|
||||
self._split_dim_k(dim_partition_dict, mesh_dim_list)
|
||||
self._split_dim_j(dim_partition_dict, mesh_dim_list)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
|
||||
dim_partition_dict_for_rhs = {-2: [mesh_dim_1]}
|
||||
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
|
||||
|
||||
dim_partition_dict_for_output = {-2: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
|
||||
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
|
||||
# in this case, output_data uses all the sharding dims.
|
||||
memory_cost = self.output_data.numel() / output_sharding_size
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
|
||||
# The communication takes place during forward activation computation.
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
|
||||
dim_partition_dict_for_rhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
|
||||
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
|
||||
|
||||
dim_partition_dict_for_output = {-1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
|
||||
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
|
||||
# in this case, output_data uses all the sharding dims.
|
||||
memory_cost = self.output_data.numel() / output_sharding_size
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
|
||||
# The communication takes place during forward and backward activation computation.
|
||||
communication_cost_forward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
|
||||
communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||
communication_cost = communication_cost_backward_activation + communication_cost_forward_activation
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
|
||||
dim_partition_dict_for_rhs = {-1: [mesh_dim_1]}
|
||||
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
|
||||
|
||||
dim_partition_dict_for_output = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
|
||||
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
|
||||
# in this case, output_data uses all the sharding dims.
|
||||
memory_cost = self.output_data.numel() / output_sharding_size
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
|
||||
# The communication takes place during backward activation computation.
|
||||
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def _registry_2d_strategies_for_matmul(self):
|
||||
self._split_lhs_space_both_contract(0, 1)
|
||||
self._split_lhs_space_both_contract(1, 0)
|
||||
self._split_rhs_space_both_contract(0, 1)
|
||||
self._split_rhs_space_both_contract(1, 0)
|
||||
self._split_lhs_space_rhs_space(0, 1)
|
||||
self._split_lhs_space_rhs_space(1, 0)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
MESH_DIM_LIST = [0, 1]
|
||||
if self.node.target != torch.matmul:
|
||||
output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
|
||||
for output_sharding_spec in output_sharding_specs:
|
||||
self._register_strategy(output_sharding_spec)
|
||||
else:
|
||||
# we only care about the batch (non-matrix) dimensions,
|
||||
# therefore, we omit the last two dimensions.
|
||||
dim_size = self.output_data.dim() - 2
|
||||
|
||||
# Both device mesh axes are used on batch dimensions
|
||||
dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
|
||||
for dim_partition_dict in dim_partition_dicts_2d:
|
||||
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
|
||||
|
||||
# Only one device mesh axis is used on batch dimensions
|
||||
for mesh_dim_index in [0, 1]:
|
||||
dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
|
||||
for dim_partition_dict in dim_partition_dicts_1d:
|
||||
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
|
||||
self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
|
||||
|
||||
# No device mesh axis is used on batch dimensions
|
||||
dim_partition_dict_on_batch_dim = {}
|
||||
self._registry_no_split_strategies_for_matmul(dim_partition_dict_on_batch_dim)
|
||||
self._registry_1d_strategies_for_matmul(dim_partition_dict_on_batch_dim, MESH_DIM_LIST)
|
||||
self._registry_2d_strategies_for_matmul()
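To make the dispatch above concrete, here is a hypothetical walk-through for a torch.matmul node whose output has shape [b0, b1, i, j]; the literal dicts are illustrative examples of what the enumeration helpers would feed to each branch.

batch_dim_size = 4 - 2                   # two batch dimensions are left after dropping 'i' and 'j'
two_d_batch_example = {0: [0], 1: [1]}   # both mesh axes spent on batch dims -> no matrix split
one_d_batch_example = {0: [0]}           # mesh axis 1 remains free for the 'i'/'k'/'j' splits
no_batch_split_example = {}              # both mesh axes remain free for the matrix dims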
|
|
@ -1,609 +0,0 @@
|
|||
import operator
|
||||
import warnings
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['ConvHandler']
|
||||
|
||||
|
||||
class ConvHandler(OperatorHandler):
|
||||
"""
|
||||
An OperatorHandler which deals with the sharding strategies of Convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_data = self.predecessor_node[0]._meta_data
|
||||
self.weight = self.module_named_parameters['weight']
|
||||
self.output_data = self.node._meta_data
|
||||
self._sanity_check()
|
||||
|
||||
def _sanity_check(self):
|
||||
'''
|
||||
In the sanity check, we need to make sure the input data has the correct number of dimensions.
|
||||
For Conv1d, the dim of input data should be 3([N, C, L]).
|
||||
For Conv2d, the dim of input data should be 4([N, C, H, W]).
|
||||
For Conv3d, the dim of input data should be 5 ([N, C, D, H, W]).
|
||||
'''
|
||||
assert self.input_data.dim() in (3, 4, 5), \
    'We expect the dim of the input fed into the conv op to be in the range of [3, 5].'
|
||||
|
||||
def _generate_compute_cost(self, bs, channel_in, channel_out):
|
||||
'''
|
||||
Compute the computation cost per device with this specific strategy.
|
||||
|
||||
Note: compute_cost needs to be divided by TFLOPS; for now it just reports the computation size.
|
||||
|
||||
Argument:
|
||||
bs(int): Batch size of the input data.
|
||||
channel_in(int): The channel dimension of input data.
|
||||
channel_out(int): The out channel of the conv weight.
|
||||
|
||||
Return:
|
||||
compute_cost(float): Computation cost per device with this specific strategy
|
||||
'''
|
||||
# TODO: compute_cost needs to be divided by TFLOPS; for now it just reports the computation size.
|
||||
# 1D: (L) * N * Cout * Cin * kernel
|
||||
# 2D: (H * W) * N * Cout * Cin * kernel
|
||||
# 3D: (H * W * D) * N * Cout * Cin * kernel
|
||||
output_size = self.output_data.shape[2:]
|
||||
output_size_product = reduce(operator.mul, output_size, 1)
|
||||
input_size = self.input_data.shape[2:]
|
||||
input_size_product = reduce(operator.mul, input_size, 1)
|
||||
kernel_size = self.weight.shape[2:]
|
||||
kernel_size_product = reduce(operator.mul, kernel_size, 1)
|
||||
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
|
||||
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
|
||||
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
|
||||
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
|
||||
return compute_cost
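A worked instance of the cost formula above, under assumed shapes (a Conv2d with batch size 2, 8 input channels, 16 output channels, a 3x3 kernel and 32x32 feature maps); all numbers are illustrative.

bs, channel_in, channel_out = 2, 8, 16
output_size_product = 32 * 32      # H_out * W_out
input_size_product = 32 * 32       # H_in * W_in
kernel_size_product = 3 * 3
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost    # 7077888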
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
Argument:
|
||||
sharding_size_forward(int): The forward activation will be divided
|
||||
into sharding_size_forward partitions.
|
||||
sharding_size_backward_activation(int): The backward activation will
|
||||
be divided into sharding_size_backward_activation partitions.
|
||||
sharding_size_weight(int): The backward weight will be divided
|
||||
into sharding_size_weight partitions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel_output = self.output_data.numel()
|
||||
numel_input = self.input_data.numel()
|
||||
numel_weight = self.weight.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
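A hedged numeric illustration of the memory-cost split above, assuming a float32 conv with a [4, 16, 64, 64] input, a [4, 32, 64, 64] output and a [16, 32, 3, 3] weight, where the output is split 4 ways, the input 2 ways and the weight 2 ways; all shapes and sharding sizes are assumptions made for the example.

size_per_elem_bytes = 4                                   # float32
numel_output = 4 * 32 * 64 * 64
numel_input = 4 * 16 * 64 * 64
numel_weight = 16 * 32 * 3 * 3
sharding_size_forward = 4
sharding_size_backward_activation = 2
sharding_size_weight = 2
memory_cost_forward = (numel_output / sharding_size_forward
                       + numel_weight / sharding_size_weight) * size_per_elem_bytes      # 533504.0
memory_cost_backward = (numel_input / sharding_size_backward_activation
                        + numel_weight / sharding_size_weight) * size_per_elem_bytes     # 533504.0
memory_cost = (memory_cost_forward, memory_cost_backward)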
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation during the forward phase
|
||||
communication_cost_forward = 0
|
||||
# compute the backward communication cost to all reduce the input activation grad
|
||||
communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
|
||||
mesh_dim_1)
|
||||
# compute the backward communication cost to all reduce the weight due to data parallel
|
||||
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
|
||||
# total communication cost
|
||||
communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
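For reference, a sketch of how the three communication terms of this strategy combine for a concrete call split_input_batch_weight_out_channel(0, 1); the per-collective costs are made-up placeholders rather than DeviceMesh estimates.

# S0S1 = S0R x RS1: no forward all-reduce, two backward all-reduces
communication_cost_forward = 0.0                 # forward output needs no reduction
communication_cost_backward_activation = 1.2e-3  # assumed cost: all-reduce of the input grad over mesh dim 1
communication_cost_backward_weight = 0.8e-3      # assumed cost: data-parallel all-reduce of the weight grad over mesh dim 0
communication_cost = (communication_cost_forward
                      + communication_cost_backward_activation
                      + communication_cost_backward_weight)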
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch(self, mesh_dim_0):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
channel_out = self.weight.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation in the forward phase.
|
||||
communication_cost_forward = 0
|
||||
# compute the backward communication cost to all reduce the weight due to data parallel
|
||||
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
|
||||
# compute the total cost
|
||||
communication_cost = communication_cost_forward + communication_cost_backward_weight
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
|
||||
channel_out = self.weight.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
|
||||
# This strategy does not need an all_reduce operation to compute the input activation grad
|
||||
communication_cost_backward_activation = 0
|
||||
# compute the backward communication cost to all reduce the weight due to data parallel
|
||||
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
|
||||
# compute total cost
|
||||
communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
|
||||
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
|
||||
channel_out = self.weight.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
# This strategy does NOT need all_reduce during the backward phase
|
||||
communication_cost_backward = 0
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_weight_out_channel(self, mesh_dim_0):
|
||||
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
|
||||
memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# This strategy does not need to do all_reduce during the forward phase
|
||||
communication_cost_forward = 0
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def non_split(self):
|
||||
name = f'RR = RR x RR'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1]
|
||||
channel_out = self.weight.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need to do all_reduce in either the forward or the backward phase
|
||||
communication_cost = 0
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
|
||||
channel_in = self.input_data.shape[1]
|
||||
channel_out = self.weight.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
|
||||
mesh_dim_1]
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need to do all_reduce in the forward phase
|
||||
communication_cost_forward = 0
|
||||
# compute the backward communication cost to all reduce the weight due to data parallel
|
||||
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
memory_cost_backward_weight, 0)
|
||||
# compute the total communication cost
|
||||
communication_cost = communication_cost_backward_weight + communication_cost_forward
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
bs = self.input_data.shape[0]
|
||||
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
|
||||
self.device_mesh.shape[mesh_dim_1])
|
||||
channel_out = self.weight.shape[1]
|
||||
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
|
||||
mesh_dim_1]
|
||||
sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# compute communication cost during forward phase
|
||||
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
memory_cost_forward_activation, 0)
|
||||
# This strategy does NOT need to do all_reduce during the backward phase
|
||||
communication_cost_backward = 0
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
|
||||
|
||||
Example:
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
|
||||
# return conv
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
# [x, mul, conv, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
|
||||
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
|
||||
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
|
||||
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
|
||||
device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
|
||||
conv_handler.register_strategy()
|
||||
for strategy in conv_handler.strategies_vector:
|
||||
print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
|
||||
|
||||
Output:
|
||||
S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
|
||||
S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
|
||||
S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
|
||||
S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
|
||||
S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
|
||||
S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
|
||||
RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
|
||||
RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
|
||||
RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
|
||||
RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
|
||||
RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
|
||||
RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
|
||||
RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
|
||||
S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]}
|
||||
RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
|
||||
'''
|
||||
# SS = SR x RS
|
||||
self.split_input_batch_weight_out_channel(0, 1)
|
||||
self.split_input_batch_weight_out_channel(1, 0)
|
||||
|
||||
# SR = SR x RR
|
||||
self.split_input_batch(0)
|
||||
self.split_input_batch(1)
|
||||
|
||||
# SR = SS x SR
|
||||
self.split_input_both_dim_weight_in_channel(0, 1)
|
||||
self.split_input_both_dim_weight_in_channel(1, 0)
|
||||
|
||||
# RS = RS x SS
|
||||
self.split_input_in_channel_weight_both_channel(0, 1)
|
||||
self.split_input_in_channel_weight_both_channel(1, 0)
|
||||
|
||||
# RR = RS x SR
|
||||
self.split_input_in_channel_weight_in_channel(0)
|
||||
self.split_input_in_channel_weight_in_channel(1)
|
||||
|
||||
# RS = RR x RS
|
||||
self.split_weight_out_channel(0)
|
||||
self.split_weight_out_channel(1)
|
||||
|
||||
# RR= RR x RR
|
||||
self.non_split()
|
||||
|
||||
# S01R = S01R x RR
|
||||
self.split_1d_parallel_on_input_batch(0, 1)
|
||||
|
||||
# RR = RS01 x S01R
|
||||
self.split_1d_parallel_on_in_channel(0, 1)
|
||||
|
||||
return self.strategies_vector
|
||||
|
||||
|
||||
CONV_STRATEGIES_LIST = [
|
||||
'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
|
||||
'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
|
||||
'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
|
||||
]
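A hedged usage sketch tying the handler to the list above: after register_strategy() the collected strategy names are expected to match CONV_STRATEGIES_LIST when every split is valid for the given shapes; the conv_handler object here is assumed to be constructed as in the docstring example.

registered_names = [strategy.name for strategy in conv_handler.strategies_vector]
assert set(registered_names) == set(CONV_STRATEGIES_LIST)   # assumes no strategy was skipped by ignore_sharding_exception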
|
|
@ -1,756 +0,0 @@
|
|||
import operator
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
|
||||
from .operator_handler import OperatorHandler
|
||||
from .strategy_generator import IntermediateStrategy, StrategyGenerator
|
||||
|
||||
__all__ = ['DotHandler']
|
||||
|
||||
|
||||
class DotProductStrategyGenerator(StrategyGenerator):
|
||||
"""
|
||||
DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation.
|
||||
This is created for torch.matmul where both tensors are 1D tensors. As torch.matmul does not include a bias argument,
we do not consider bias here.
|
||||
"""
|
||||
|
||||
def validate(self, input, other):
|
||||
assert input.dim() == 1 and other.dim() == 1
|
||||
|
||||
def no_split(self):
|
||||
name = f'R = R dot R'
|
||||
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_one_dim(self, mesh_dim):
|
||||
name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
|
||||
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
|
||||
|
||||
def generate(self) -> List[IntermediateStrategy]:
|
||||
strategy_list = []
|
||||
|
||||
# do not split dimensions for dot product
|
||||
# R = R dot R
|
||||
strategy_list.append(self.no_split())
|
||||
|
||||
# split two tensors in the same dimensions
|
||||
# S = S dot S
|
||||
strategy_list.append(self.split_one_dim(0))
|
||||
strategy_list.append(self.split_one_dim(1))
|
||||
|
||||
return strategy_list
|
||||
|
||||
|
||||
class MatVecStrategyGenerator(StrategyGenerator):

    def validate(self, input, other) -> bool:
        assert input.dim() > 1 and other.dim() == 1

    def no_split(self):
        name = "R = R x R"
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_input_batch(self, mesh_dim):
        name = f'S{mesh_dim}R = S{mesh_dim}R x R'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def generate(self) -> List[IntermediateStrategy]:
        strategy_list = []

        # no split
        strategy_list.append(self.no_split())

        # split the batch dim of the first tensor along one mesh dimension
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        return strategy_list

class MatMulStrategyGenerator(StrategyGenerator):
    """
    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
    a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.

    A matmul can be formulated as [n, p] x [p, q] = [n, q]

    Args:
        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
            This will incur an extra transformation of the dim partitioning as the weight is transposed.
    """

    def __init__(self, is_linear: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_linear = is_linear

        # as the weight for the linear module is transposed, we can compute
        # the corresponding dimension index for convenience
        if is_linear:
            self.dim_q = 0
            self.dim_p = 1
        else:
            self.dim_q = 1
            self.dim_p = 0

    def validate(self, input, other, bias) -> bool:
        # make sure the second tensor is a 2D tensor
        assert input.dim() > 0 and other.dim() == 2

        # make sure bias is of the same dimension
        if self.is_linear:
            assert bias is None or bias.shape[-1] == other.shape[0]
        else:
            assert bias is None or bias.shape[-1] == other.shape[1]

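The dim_q / dim_p bookkeeping above is easiest to see with a tiny PyTorch check; this is an illustrative sketch, not part of the removed file. nn.Linear stores its weight as [out_features, in_features], i.e. the transpose of the [p, q] operand in the [n, p] x [p, q] formulation, which is why is_linear=True puts the q (output-feature) dimension at index 0 and the contracted p dimension at index 1.

# Sketch only: check how a [n, p] x [p, q] matmul maps onto nn.Linear's transposed weight.
import torch
import torch.nn as nn

n, p, q = 4, 8, 16
linear = nn.Linear(p, q, bias=False)
assert linear.weight.shape == (q, p)      # weight is stored transposed: [out, in]

x = torch.randn(n, p)
out = linear(x)                           # computes x @ weight.t(), shape [n, q]
assert out.shape == (n, q)

# Hence, for is_linear=True the generator uses dim_q = 0 and dim_p = 1 when sharding
# the weight, while a plain torch.matmul operand of shape [p, q] would use dim_p = 0
# and dim_q = 1.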
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle case SS = SR x RS
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
"other": {
|
||||
self.dim_q: [mesh_dim_1]
|
||||
},
|
||||
"bias": {
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"output": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle the case SR = SS x SR
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
self.dim_p: [mesh_dim_1]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
|
||||
|
||||
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
-1: [mesh_dim_0]
|
||||
},
|
||||
"other": {
|
||||
self.dim_p: [mesh_dim_0],
|
||||
self.dim_q: [mesh_dim_1]
|
||||
},
|
||||
"bias": {
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"output": {
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def recompute_split_both_contract(self, mesh_dim):
|
||||
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
-1: [mesh_dim]
|
||||
},
|
||||
"other": {
|
||||
self.dim_p: [mesh_dim]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {},
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
|
||||
|
||||
def split_rhs_space_only(self, mesh_dim):
|
||||
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
|
||||
dim_partition_dict = {
|
||||
"input": {},
|
||||
"other": {
|
||||
self.dim_q: [mesh_dim]
|
||||
},
|
||||
"bias": {
|
||||
-1: [mesh_dim]
|
||||
},
|
||||
"output": {
|
||||
-1: [mesh_dim]
|
||||
},
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
|
||||
|
||||
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"other": {},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
-1: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
self.dim_p: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {},
|
||||
}
|
||||
return IntermediateStrategy(name=name,
|
||||
dim_partition_dict=dim_partition_dict,
|
||||
all_reduce_axis=[mesh_dim_0, mesh_dim_1])
|
||||
|
||||
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict = {
|
||||
"input": {},
|
||||
"other": {
|
||||
self.dim_q: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"bias": {
|
||||
-1: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"output": {
|
||||
-1: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
|
||||
class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
||||
"""
|
||||
Generate sharding strategies for the batched matrix multiplication.
|
||||
|
||||
A batched matrix multiplication can be viewed as
|
||||
[b, i, k] x [b, k, j] -> [b, i, j]
|
||||
"""
|
||||
|
||||
def __init__(self, is_torch_bmm: bool, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_torch_bmm = is_torch_bmm
|
||||
|
||||
def validate(self, input, other, bias) -> bool:
|
||||
if self.is_torch_bmm:
|
||||
assert input.shape == other.shape
|
||||
assert input.dim() > 2
|
||||
assert other.shape[-1] == bias.shape[0]
|
||||
else:
|
||||
# TODO: validate these inputs are broadcastable
|
||||
pass
|
||||
|
||||
def split_one_batch_dim(self):
|
||||
if 1 in self.device_mesh.mesh_shape:
|
||||
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
else:
|
||||
return None
|
||||
|
||||
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
|
||||
|
||||
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0],
|
||||
-2: [mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim_0],
|
||||
-2: [mesh_dim_1]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"bias": {
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"output": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0],
|
||||
-2: [mesh_dim_1]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim_0],
|
||||
-2: [mesh_dim_1]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
|
||||
|
||||
def generate(self) -> List[IntermediateStrategy]:
|
||||
strategy_list = []
|
||||
|
||||
# split only the batch dimension
|
||||
# Sb = Sb x Sb
|
||||
# can be None as it is only for 1D device mesh
|
||||
strategy = self.split_one_batch_dim()
|
||||
if strategy:
|
||||
strategy_list.append(strategy)
|
||||
|
||||
# split batch dim of two inputs and the i dim of the first tensor
|
||||
# SbSi = SbSi x Sb
|
||||
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
|
||||
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
|
||||
|
||||
# split batch dim of two inputs and the j of the second tensor
|
||||
# SbSj = Sb x SbSj
|
||||
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
|
||||
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
|
||||
|
||||
# split batch dim of two inputs and the k dim of two inputs
|
||||
# Sb = SbSk x SbSk, need to all-reduce by k dim
|
||||
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
|
||||
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
|
||||
|
||||
# split two batch dim
|
||||
strategy_list.append(self.split_two_batch_dim(0, 1))
|
||||
strategy_list.append(self.split_two_batch_dim(1, 0))
|
||||
|
||||
return strategy_list
|
||||
|
||||
|
||||
class DotHandler(OperatorHandler):
|
||||
"""
|
||||
An OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_data = self.predecessor_node[0]._meta_data
|
||||
self.weight = self.module_named_parameters['weight']
|
||||
self.output_data = self.node._meta_data
|
||||
|
||||
def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
|
||||
# TODO: consider bias addition
|
||||
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
|
||||
return compute_cost
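A worked example of the cost formula above may help; the shapes below are made up for illustration and do not come from the original file. reduce(operator.mul, input_shape) * weight_shape[0] * 2 counts roughly 2 * n * p * q multiply-accumulates for an [n, p] input against a transposed [q, p] linear weight, spread evenly over total_sharding_size devices.

# Worked example of the formula above (hypothetical shapes):
from functools import reduce
import operator

input_shape = (64, 1024)       # [n, p] activation
weight_shape = (4096, 1024)    # nn.Linear weight, stored as [q, p]
total_sharding_size = 4        # e.g. a 2 x 2 device mesh

compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
assert compute_cost == 134217728   # ~2*n*p*q multiply-adds shared across 4 devices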
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle case SS = SR x RS
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
# linear layer weight is transposed during init
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute computation cost
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
|
||||
|
||||
# compute the communication cost
|
||||
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
|
||||
communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
|
||||
communication_cost = communication_cost_activation_backward + communication_cost_weight_backward
|
||||
|
||||
# create and register strategy
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=total_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle the case SR = SS x SR
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
# since weight of the linear layer is transposed
|
||||
# the actual dim to be sharded is 1
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
|
||||
communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
|
||||
communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
total_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
|
||||
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
|
||||
communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=total_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def recompute_split_both_contract(self, mesh_dim):
|
||||
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim]
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_space_only(self, mesh_dim):
|
||||
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim]
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
|
||||
communication_cost = communication_cost_activation_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
|
||||
communication_cost = communication_cost_weight_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
|
||||
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
|
||||
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
activation_memory_cost, 0)
|
||||
communication_cost = communication_cost_forward_activation
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
|
||||
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
|
||||
# compute the communication cost of this strategy
|
||||
communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
input_grad_memory_cost, 0)
|
||||
communication_cost = communication_cost_activation_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=toatl_memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate every possible strategy for a linear node and record all strategies into the strategies_vector.
|
||||
|
||||
Output:
|
||||
|
||||
'''
|
||||
# SS = SR x RS
|
||||
self.split_lhs_space_rhs_space(0, 1)
|
||||
self.split_lhs_space_rhs_space(1, 0)
|
||||
|
||||
# SR = SS x SR
|
||||
self.split_lhs_space_both_contract(0, 1)
|
||||
self.split_lhs_space_both_contract(1, 0)
|
||||
|
||||
# RS = RS x SS
|
||||
self.split_rhs_space_both_contract(0, 1)
|
||||
self.split_rhs_space_both_contract(1, 0)
|
||||
|
||||
# RR = RS x SR
|
||||
self.recompute_split_both_contract(0)
|
||||
self.recompute_split_both_contract(1)
|
||||
|
||||
# RS = RR x RS
|
||||
self.split_rhs_space_only(0)
|
||||
self.split_rhs_space_only(1)
|
||||
|
||||
# S01R = S01R x RR
|
||||
self.split_lhs_1st_dim_1d(0, 1)
|
||||
|
||||
# RR = RS01 x S01R
|
||||
self.split_lhs_2nd_dim_1d(0, 1)
|
||||
|
||||
# RS01 = RR x RS01
|
||||
self.split_rhs_2nd_dim_1d(0, 1)
|
||||
|
||||
return self.strategies_vector
|
|
@@ -1,179 +0,0 @@
|
|||
import operator
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['EmbeddingHandler']
|
||||
|
||||
|
||||
class EmbeddingHandler(OperatorHandler):
|
||||
"""
|
||||
An OperatorHandler which deals with the sharding strategies of embedding operators (such as nn.Embedding).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_data = self.predecessor_node[0]._meta_data
|
||||
self.weight = self.module_named_parameters['weight']
|
||||
self.output_data = self.node._meta_data
|
||||
|
||||
def _generate_compute_cost(self, total_sharding_size):
|
||||
input_shape = self.input_data.shape
|
||||
weight_shape = self.weight.shape
|
||||
input_shape_product = reduce(operator.mul, input_shape, 1)
|
||||
weight_shape_product = reduce(operator.mul, weight_shape, 1)
|
||||
compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
|
||||
return compute_cost
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight partitions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel_output = self.output_data.numel()
|
||||
numel_input = self.input_data.numel()
|
||||
numel_weight = self.weight.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
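To make the returned pair concrete, here is an illustrative evaluation with made-up sizes (fp32 elements, a 32 x 128 batch of indices and a 50000 x 768 table); none of these numbers come from the original file.

# Illustrative numbers (hypothetical) for the memory model above, using fp32:
import torch

size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()   # 4
numel_output, numel_input, numel_weight = 32 * 128 * 768, 32 * 128, 50000 * 768
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight = 2, 1, 4

memory_cost_forward = (numel_output * size_per_elem_bytes / sharding_size_forward +
                       numel_weight * size_per_elem_bytes / sharding_size_weight)
memory_cost_backward = (numel_input * size_per_elem_bytes / sharding_size_backward_activation +
                        numel_weight * size_per_elem_bytes / sharding_size_weight)
# (memory_cost_forward, memory_cost_backward) is the pair returned as memory_cost.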
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {2: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# compute the communication cost of this strategy during forward phase
|
||||
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'
|
||||
|
||||
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
|
||||
|
||||
# compute the computation cost of this strategy
|
||||
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||
sharding_size_weight = 1
|
||||
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
|
||||
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
|
||||
|
||||
# This strategy does not need to do all_reduce during the forward phase
|
||||
communication_cost_forward = 0
|
||||
# compute the communication cost of this strategy during backward phase
|
||||
communication_cost_backward_activation = 0
|
||||
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
memory_cost_backward_weight, 0)
|
||||
communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
|
||||
communication_cost = communication_cost_forward + communication_cost_backward
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate every possible strategy for an Embedding node and record all strategies into the strategies_vector.
|
||||
'''
|
||||
# RRS = RR x SS
|
||||
self.split_weight_both_dim(0, 1)
|
||||
self.split_weight_both_dim(1, 0)
|
||||
|
||||
# SSR = SS x RR
|
||||
self.split_input_both_dim(0, 1)
|
||||
self.split_input_both_dim(1, 0)
|
||||
|
||||
return self.strategies_vector
|
|
@@ -1,241 +0,0 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
|
||||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
generate_sharding_size,
|
||||
ignore_sharding_exception,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['LayerNormHandler']
|
||||
|
||||
|
||||
class LayerNormHandler(OperatorHandler):
|
||||
"""
|
||||
An OperatorHandler which deals with the sharding strategies of normalization.
|
||||
|
||||
Note: To keep the math consistent, LayerNorm does not allow sharding on the hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_data = self.predecessor_node[0]._meta_data
|
||||
self.weight = self.module_named_parameters['weight']
|
||||
self.bias = self.module_named_parameters['bias']
|
||||
self.output_data = self.node._meta_data
|
||||
|
||||
def _generate_compute_cost(self, total_sharding_size):
|
||||
'''
|
||||
Compute the computation cost per device with this specific strategy.
|
||||
|
||||
Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
|
||||
|
||||
Argument:
total_sharding_size(int): The total number of devices that this strategy shards the computation over.
|
||||
|
||||
Return:
|
||||
compute_cost(float): Computation cost per device with this specific strategy
|
||||
'''
|
||||
# TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# TODO: a constant coefficient needs to be added.
|
||||
|
||||
norm_kernel_size = self.weight.shape
|
||||
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
|
||||
input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
|
||||
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
|
||||
norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
|
||||
forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
# Computing the gradient of one norm kernel element requires input_batch_product operations, so
|
||||
# the total cost is input_batch_product * norm_kernel_product
|
||||
backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
|
||||
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
|
||||
compute_cost = forward_compute_cost + backward_compute_cost
|
||||
return compute_cost
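A quick illustrative check (not part of the removed file) of how the batch shape is derived above: for an input of shape [B, S, H] normalized over the hidden dimension, norm_kernel_size is (H,) and every leading dimension counts as a batch dimension.

# Sketch with hypothetical shapes:
import torch
import torch.nn as nn

layer_norm = nn.LayerNorm(768)
x = torch.randn(8, 512, 768)
norm_kernel_size = layer_norm.weight.shape              # torch.Size([768])
input_batch_shape = x.shape[:-len(norm_kernel_size)]    # torch.Size([8, 512])
assert tuple(input_batch_shape) == (8, 512)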
|
||||
|
||||
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward partitions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation partitions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight partitions.
|
||||
|
||||
Return:
|
||||
memory_cost(Tuple[float]): Memory cost per device with this
|
||||
specific strategy, the first element of this tuple is forward
|
||||
memory cost, and the second element of this tuple is backward
|
||||
memory cost.
|
||||
memory_cost_forward(float): Memory cost of forward activation per
|
||||
device with this specific strategy.
|
||||
memory_cost_backward_activation(float): Memory cost of backward activation
|
||||
per device with this specific strategy.
|
||||
'''
|
||||
# compute the memory cost of this strategy
|
||||
dtype = self.input_data.dtype
|
||||
numel_output = self.output_data.numel()
|
||||
# this operation will not change the shape of input
|
||||
numel_input = numel_output
|
||||
numel_weight = self.weight.numel()
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# forward memory_cost
|
||||
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
|
||||
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
|
||||
|
||||
# backward memory_cost
|
||||
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
|
||||
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
|
||||
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
|
||||
|
||||
# memory_cost pair
|
||||
memory_cost = (memory_cost_forward, memory_cost_backward)
|
||||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
|
||||
|
||||
def _generate_strategy_with_dim_partition(self, dim_partition):
|
||||
dim_partition_dict_for_input = dim_partition
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = dim_partition
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
|
||||
# compute the computation cost of this strategy
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
|
||||
sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
|
||||
sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
|
||||
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
|
||||
sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
total_mesh_dim_list = []
|
||||
for mesh_dim_list in dim_partition.values():
|
||||
total_mesh_dim_list.extend(mesh_dim_list)
|
||||
|
||||
# This strategy does not need an all_reduce operation for activations
|
||||
communication_cost_forward_activation = 0
|
||||
communication_cost_backward_activation = 0
|
||||
if len(total_mesh_dim_list) == 1:
|
||||
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight,
|
||||
total_mesh_dim_list[0])
|
||||
else:
|
||||
assert len(total_mesh_dim_list) == 2, 'temporarily we only support a 2D device mesh.'
|
||||
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
|
||||
memory_cost_backward_weight, 0)
|
||||
communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight
|
||||
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch_single_mesh_dim(self, mesh_dim_0):
|
||||
batch_dimension_length = self.input_data.dim() - self.weight.dim()
|
||||
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
|
||||
for dim_partition in dim_partition_list:
|
||||
self._generate_strategy_with_dim_partition(dim_partition)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
batch_dimension_length = self.input_data.dim() - self.weight.dim()
|
||||
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
|
||||
for dim_partition in dim_partition_list:
|
||||
self._generate_strategy_with_dim_partition(dim_partition)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def non_split(self):
|
||||
name = f'RR = RR x R'
|
||||
|
||||
dim_partition_dict_for_input = {}
|
||||
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||
|
||||
dim_partition_dict_for_weight = {}
|
||||
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||
|
||||
dim_partition_dict_for_output = {}
|
||||
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
|
||||
|
||||
# generate resharding cost for this strategy
|
||||
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
|
||||
|
||||
total_sharding_size = 1
|
||||
# compute the computation cost of this strategy
|
||||
compute_cost = self._generate_compute_cost(total_sharding_size)
|
||||
|
||||
# compute the memory cost of this strategy
|
||||
sharding_size_forward = 1
|
||||
sharding_size_backward_activation = 1
|
||||
sharding_size_weight = 1
|
||||
memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
|
||||
sharding_size_weight)
|
||||
|
||||
# This strategy does not need an all_reduce operation
|
||||
communication_cost = 0
|
||||
sharding_strategies = ShardingStrategy(name,
|
||||
output_sharding_spec=sharding_spec_for_output,
|
||||
compute_cost=compute_cost,
|
||||
communication_cost=communication_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate every possible strategy for a LayerNorm node and record all strategies into the strategies_vector.
|
||||
|
||||
Example:
|
||||
norm_handler = LayerNormHandler(node, strategies_vector,
|
||||
self.shape_consistency_manager)
|
||||
norm_handler.register_strategy()
|
||||
for strategy in norm_handler.strategies_vector:
|
||||
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
|
||||
|
||||
Output:
|
||||
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
|
||||
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
|
||||
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
|
||||
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
|
||||
'''
|
||||
|
||||
# SR = SR x R with single mesh dim on batch dimensions
|
||||
self.split_input_batch_single_mesh_dim(0)
|
||||
self.split_input_batch_single_mesh_dim(1)
|
||||
|
||||
# SR = SR x R with both mesh dims on batch dimensions
|
||||
self.split_input_batch_both_mesh_dim(0, 1)
|
||||
|
||||
# RR = RR x R
|
||||
self.non_split()
|
||||
|
||||
return self.strategies_vector
|
|
@@ -1,149 +0,0 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .._utils import generate_resharding_costs, generate_sharding_spec
|
||||
from ..sharding_strategy import StrategiesVector
|
||||
|
||||
__all__ = ['OperatorHandler']
|
||||
|
||||
|
||||
class OperatorHandler(ABC):
|
||||
'''
|
||||
The OperatorHandler is an abstract class used to generate every possible strategy for an operator node.
|
||||
|
||||
Args:
|
||||
node (Node): the input node in node argument list.
|
||||
device_mesh (DeviceMesh): A logical view of a physical mesh.
|
||||
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
|
||||
handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
node: Node,
|
||||
device_mesh: DeviceMesh,
|
||||
strategies_vector: StrategiesVector,
|
||||
handle_backward: bool = True):
|
||||
self.node = node
|
||||
self.predecessor_node = list(node._input_nodes.keys())
|
||||
self.successor_node = list(node.users.keys())
|
||||
self.device_mesh = device_mesh
|
||||
self.strategies_vector = strategies_vector
|
||||
self.handle_backward = handle_backward
|
||||
|
||||
# find the module and its parameters associated with this node
|
||||
# this can be used to compute the compute/communication/sharding cost
|
||||
if self.node.op == 'call_module':
|
||||
module = node.graph.owning_module.get_submodule(node.target)
|
||||
named_parameters = list(module.named_parameters(recurse=False))
|
||||
# convert named parameters from list to dict
|
||||
named_parameters = {k: v for k, v in named_parameters}
|
||||
elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
|
||||
module = None
|
||||
parameters = list(self.node.args)[1]
|
||||
if isinstance(parameters, Node):
|
||||
named_parameters = {'weight': parameters._meta_data}
|
||||
else:
|
||||
named_parameters = {}
|
||||
else:
|
||||
module = None
|
||||
named_parameters = None
|
||||
self.module = module
|
||||
self.module_named_parameters = named_parameters
|
||||
|
||||
@abstractmethod
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
"""
|
||||
Register all possible sharding strategies for this node into the strategies vector.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
|
||||
sharding_spec_for_input):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
|
||||
Argument:
|
||||
dim_partition_dict_for_output(Dict[int, List[int]]): The key is the dimension of the output to be sharded,
and the value describes which logical mesh axes will shard that dimension.
dim_partition_dict_for_weight(Dict[int, List[int]]): The key is the dimension of the weight to be sharded,
and the value describes which logical mesh axes will shard that dimension.
|
||||
Return:
|
||||
total_memory_cost(float): total memory cost per device with this specific strategy
|
||||
activation_cost(float): the memory cost of activation per device with this specific strategy
|
||||
weight_memory_cost(float): the memory cost of weight per device with this specific strategy
|
||||
'''
|
||||
# compute the size of one element with specific dtype
|
||||
dtype = self.input_data.dtype
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
# compute the memory cost of activation
|
||||
activation_numel = self.output_data.numel()
|
||||
output_mesh_dims = []
|
||||
for sharding_dim, mesh_dims in dim_partition_dict_for_output.items():
|
||||
output_mesh_dims.extend(mesh_dims)
|
||||
activation_sharding_size = 1
|
||||
for mesh_dim in output_mesh_dims:
|
||||
activation_sharding_size *= self.device_mesh.shape[mesh_dim]
|
||||
activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes
|
||||
|
||||
# compute the memory cost of weight
|
||||
weight_numel = self.weight.numel()
|
||||
weight_sharding_size = 1
|
||||
weight_mesh_dims = []
|
||||
for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items():
|
||||
weight_mesh_dims.extend(mesh_dims)
|
||||
for mesh_dim in weight_mesh_dims:
|
||||
weight_sharding_size *= self.device_mesh.shape[mesh_dim]
|
||||
weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
|
||||
|
||||
# compute the memory cost of input grad
|
||||
input_grad_numel = self.input_data.numel()
|
||||
input_grad_sharding_size = 1
|
||||
input_grad_mesh_dims = []
|
||||
for sharding_dim, mesh_dims in sharding_spec_for_input.items():
|
||||
input_grad_mesh_dims.extend(mesh_dims)
|
||||
for mesh_dim in input_grad_mesh_dims:
|
||||
input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
|
||||
input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes
|
||||
|
||||
memory_cost_forward = activation_memory_cost + weight_memory_cost
|
||||
memory_cost_backward = input_grad_memory_cost + weight_memory_cost
|
||||
|
||||
return (memory_cost_forward,
|
||||
memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost
|
||||
|
||||
def _generate_resharding_costs(self, sharding_specs):
|
||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||
if hasattr(self.node._meta_data, 'dtype'):
|
||||
dtype = self.node._meta_data.dtype
|
||||
else:
|
||||
assert isinstance(self.node._meta_data,
|
||||
tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor is expected'
|
||||
dtype = self.node._meta_data[0].dtype
|
||||
|
||||
nodes = self.predecessor_node
|
||||
return generate_resharding_costs(nodes=nodes,
|
||||
sharding_specs=sharding_specs,
|
||||
count_backward=self.handle_backward,
|
||||
dtype=dtype)
|
||||
|
||||
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
|
||||
return generate_sharding_spec(input_=input_,
|
||||
device_mesh=self.device_mesh,
|
||||
dim_partition_dict=dim_partition_dict)
|
||||
|
||||
@abstractmethod
|
||||
def _generate_compute_cost(self, *args, **kwargs):
|
||||
"""
|
||||
Compute the flops involved in the node.
|
||||
"""
|
||||
pass
|
|
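A minimal numeric sketch of the per-device memory cost computed by _generate_memory_cost above. The mesh shape, tensor size and dtype below are illustrative assumptions, not values taken from the handler.

import torch

# assumed example: a (2, 4) logical mesh and an fp32 activation of shape (64, 1024)
mesh_shape = (2, 4)
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()    # 4 bytes
activation_numel = 64 * 1024
dim_partition_dict_for_output = {0: [0], 1: [1]}    # dim 0 sharded over mesh axis 0, dim 1 over mesh axis 1

sharding_size = 1
for mesh_dims in dim_partition_dict_for_output.values():
    for mesh_dim in mesh_dims:
        sharding_size *= mesh_shape[mesh_dim]

# each device holds numel / sharding_size elements of the activation
activation_memory_cost = activation_numel / sharding_size * size_per_elem_bytes
print(activation_memory_cost)    # 32768.0 bytes per device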
@@ -1,89 +0,0 @@
import colorsys
import math
import warnings
from copy import deepcopy

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from ..constants import INFINITY_COST
from .operator_handler import OperatorHandler


class ReshapeHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of Reshape Operator, such as torch.reshape, torch.flatten, etc.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, *args, **kwargs):
        return super()._generate_compute_cost(*args, **kwargs)

    @ignore_sharding_exception
    def register_strategy(self):
        # TODO: add strategies with more output sharding specs other than only fully replicated.
        input_node = self.strategies_vector.predecessor_nodes[0]
        # For the reshape function, we keep the sharding spec of the input fully replicated to preserve
        # computing correctness. In addition, we keep the output in replica status and let the successor
        # node choose how to reshard the output. Therefore, different strategies of the input node with
        # the same output sharding spec will generate the same strategy for the reshape function.
        sharding_spec_checklist = []
        for strategy in input_node.strategies_vector:
            # It looks a little bit confusing, the input of the processing node
            # is the output of the input_node.
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
            if input_sharding_spec in sharding_spec_checklist:
                continue
            sharding_spec_checklist.append(input_sharding_spec)
            dim_partition_dict_for_output = {}
            if isinstance(self.output_data, tuple):
                dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
            try:
                if isinstance(self.output_data, tuple):
                    output_sharding_spec = []
                    for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
                        output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
                else:
                    output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
            except AssertionError as e:
                warnings.warn(f'{e}')
                continue
            name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
            # TODO: use meta_info_prop to profile memory cost and compute cost
            compute_cost = 0
            # consider node._meta_data is in type of tuple
            memory_cost = 0

            # compute the communication cost; for a reshape op, the communication happens while casting the input sharding spec to fully replicated.
            dim_partition_dict_for_replicate_input = {}
            replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data,
                                                                         dim_partition_dict_for_replicate_input)
            # shape consistency manager is a singleton class
            shape_consistency_manager = ShapeConsistencyManager()
            _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
                                                                                   replicate_input_sharding_spec)
            communication_cost = communication_cost["total"]

            # generate resharding cost
            resharding_costs = self._generate_resharding_costs([input_sharding_spec])

            # to prevent the resharding from happening, set their resharding cost to inf.
            resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
            sharding_strategy = ShardingStrategy(name,
                                                 output_sharding_spec,
                                                 compute_cost=compute_cost,
                                                 communication_cost=communication_cost,
                                                 memory_cost=memory_cost,
                                                 resharding_costs=resharding_costs,
                                                 input_shardings=[input_sharding_spec])
            self.strategies_vector.append(sharding_strategy)
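The resharding-cost masking used by register_strategy above can be shown in isolation: only the input strategies that already match the required spec (cost 0) stay selectable, everything else is pushed to infinity. The cost values and the INFINITY_COST constant below are stand-ins for illustration.

INFINITY_COST = 1e13    # assumed placeholder for the constant imported above

resharding_costs_for_input = [0.0, 12.5, 7.3, 0.0]    # invented per-strategy costs
masked = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs_for_input]
print(masked)    # [0, 1e13, 1e13, 0]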
@@ -1,45 +0,0 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Dict
from colossalai.device.device_mesh import DeviceMesh

__all__ = ['IntermediateStrategy', 'StrategyGenerator']


@dataclass
class IntermediateStrategy:
    """
    IntermediateStrategy contains a subset of the meta information of a ShardingStrategy. It only stores
    the essential information about the tensor sharding and leaves the other meta information to OperatorHandler.

    Args:
        name (str): name of the sharding strategy.
        dim_partition_dict (Dict[str, Dict[int, List[int]]]): maps each tensor to its dim partition dict.
        all_reduce_axis (List[int]): stores the mesh axes which require an all-reduce operation.
    """
    name: str
    dim_partition_dict: Dict[str, Dict[int, List[int]]]
    all_reduce_axis: List[int] = None


class StrategyGenerator(ABC):
    """
    StrategyGenerator is used to generate a group of sharding strategies of the same kind.
    """

    def __init__(self, device_mesh: DeviceMesh):
        self.device_mesh = device_mesh

    @abstractmethod
    def generate(self) -> List[IntermediateStrategy]:
        """
        Generate all the sharding strategies covered by this generator.
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs) -> bool:
        """
        Validate if the operands are of the desired shape.
        If True, this generator can be used for the current operation.
        """
        pass
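A hypothetical subclass sketch showing how IntermediateStrategy and StrategyGenerator above are meant to be combined; the generator name and the partition layouts are invented for the example.

from typing import List

class ReplicaOnlyGenerator(StrategyGenerator):
    # hypothetical generator that only proposes the fully replicated layout

    def validate(self, *args, **kwargs) -> bool:
        # any operand shape is acceptable for full replication
        return True

    def generate(self) -> List[IntermediateStrategy]:
        # an empty dim_partition_dict means no tensor dimension is sharded
        return [IntermediateStrategy(name='RR = RR x RR', dim_partition_dict={'input': {}, 'output': {}})]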
@@ -1,88 +0,0 @@
import math
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List

import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
    ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
    INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .operator_handler import OperatorHandler

__all__ = ['UnaryElementwiseHandler']


class UnaryElementwiseHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of UnaryElementwiseOp.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.node.op == 'call_module':
            target = self.node.target
            submod = self.node.graph.owning_module.get_submodule(target)
            submod_type = type(submod)
            if submod_type == torch.nn.Dropout:
                print(f'predecessor nodes of dropout node are {self.predecessor_node}')
        input_nodes_len = 0
        for check_node in self.predecessor_node:
            if isinstance(check_node._meta_data, torch.Tensor):
                input_nodes_len += 1
        assert input_nodes_len == 1, f'Temporarily, we just support single input element-wise op, node name is {self.node}, node args is {self.node.args}.'
        self.input_data = self.predecessor_node[0]._meta_data
        self.input_node = self.predecessor_node[0]
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, *args, **kwargs):
        return super()._generate_compute_cost(*args, **kwargs)

    @ignore_sharding_exception
    def register_strategy(self):
        # TODO: integrate element-wise func and module together
        # create sharding strategy for element-wise function

        # For an element-wise function, we keep the sharding spec of the output node the same as
        # the input. Therefore, different strategies of the input node with the same
        # output sharding spec will generate the same strategy for the element-wise function.

        for index, strategy in enumerate(self.input_node.strategies_vector):
            # It looks a little bit confusing, the input of the processing node
            # is the output of the input_node.
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'

            dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
            except AssertionError as e:
                warnings.warn(f'{e}')
                continue
            # add index into the name to pass the duplicated check
            # we keep the same strategies with different names for node merging, and it will not increase the search space,
            # because in the solver, this node will be merged into other nodes, and the solver will not create a new variable for this node.
            name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
            # TODO: use meta_info_prop to profile memory cost and compute cost
            compute_cost = self.output_data.numel()
            memory_cost = 0

            resharding_costs = self._generate_resharding_costs([input_sharding_spec])

            # to prevent the resharding from happening, set their resharding cost to inf.
            resharding_costs[self.input_node] = [
                0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node]
            ]
            sharding_strategy = ShardingStrategy(name,
                                                 output_sharding_spec,
                                                 compute_cost=compute_cost,
                                                 memory_cost=memory_cost,
                                                 resharding_costs=resharding_costs,
                                                 input_shardings=[input_sharding_spec])
            self.strategies_vector.append(sharding_strategy)
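The core of the element-wise rule above is that the output reuses the input partitioning unchanged, so no communication is required; a tiny sketch with an invented dim partition dict:

from copy import deepcopy

input_dim_partition_dict = {0: [0]}    # example: batch dimension sharded over mesh axis 0
output_dim_partition_dict = deepcopy(input_dim_partition_dict)
assert output_dim_partition_dict == input_dim_partition_dict    # same layout, zero communication cost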
@@ -1,186 +0,0 @@
import math
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
                                                                     enumerate_all_possible_2d_sharding,
                                                                     ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

from .operator_handler import OperatorHandler

__all__ = ['WhereHandler']


class WhereHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of torch.where.
    """

    def __init__(self, *args, **kwargs):
        # TODO: x or y could be scalar
        super().__init__(*args, **kwargs)
        assert len(self.predecessor_node) == 3
        self.condition_data = self.predecessor_node[0]._meta_data
        self.x_data = self.predecessor_node[1]._meta_data
        self.y_data = self.predecessor_node[2]._meta_data
        self.condition = self.predecessor_node[0]
        self.x = self.predecessor_node[1]
        self.y = self.predecessor_node[2]
        self.output_data = self.node._meta_data

    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
        shape = list(input_.shape)

        # pad the shape to the same length as output_data
        while len(shape) < self.output_data.dim():
            shape.insert(0, 1)
        shape = torch.Size(shape)

        # if the sharding happens on a size-one dimension, we should record it as R.
        processed_dim_partition_dict = deepcopy(dim_partition_dict)
        for dim_index, _ in dim_partition_dict.items():
            if shape[dim_index] == 1:
                processed_dim_partition_dict.pop(dim_index)
        for dim_index, sharding_index_list in processed_dim_partition_dict.items():
            sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
            sharding_size = reduce(operator.mul, sharding_list, 1)
            assert shape[
                dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=shape,
                                     dim_partition_dict=processed_dim_partition_dict)

        return sharding_spec

    def _generate_compute_cost(self, total_sharding_size):
        lhs_matrix_shape = self.lhs_data.shape[-2:]
        rhs_matrix_shape = self.rhs_data.shape[-2:]
        batch_dimensions_shape = self.output_data.shape[:-2]
        batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
        compute_cost = reduce(
            operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
        return compute_cost

    def _generate_resharding_costs(self, sharding_specs):
        # The resharding_cost of weight is counted due to sharing weight cases.
        dtype = self.node._meta_data.dtype
        nodes = self.predecessor_node
        resharding_costs = {}
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # shape consistency manager is a singleton class
        shape_consistency_manager = ShapeConsistencyManager()

        for input_node, input_spec in zip(nodes, sharding_specs):
            resharding_costs[input_node] = []
            for strategy in input_node.strategies_vector:
                input_sharding_spec = strategy.output_sharding_spec
                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
                # if the input shape is smaller than the target input, we will pad the input to the same length as the target.
                # Then, use the padded input sharding spec to compute the resharding cost.
                if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
                    new_entire_shape = list(input_sharding_spec.entire_shape)
                    while len(new_entire_shape) < len(input_spec.entire_shape):
                        new_entire_shape.insert(0, 1)
                    new_entire_shape = torch.Size(new_entire_shape)
                    new_device_mesh = input_sharding_spec.device_mesh
                    new_dim_partition_dict = input_sharding_spec.dim_partition_dict
                    input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
                                                       entire_shape=new_entire_shape,
                                                       dim_partition_dict=new_dim_partition_dict)

                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)
                total_resharding_cost = total_resharding_cost['total']
                # we need to multiply by the size of the elem dtype to get the correct communication cost
                resharding_cost = total_resharding_cost * size_per_elem_bytes
                resharding_costs[input_node].append(resharding_cost)

        return resharding_costs

    def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):

        sharding_spec_list = []
        check_duplicated_list = []
        for output_dim_partition_dict in dim_partition_list:
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
            except AssertionError as e:
                warnings.warn(f'{e}')
                break
            sharding_seq = output_sharding_spec.sharding_sequence
            if sharding_seq not in check_duplicated_list:
                check_duplicated_list.append(sharding_seq)
                sharding_spec_list.append(output_sharding_spec)

        return sharding_spec_list

    def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
        # use mesh_dim_0, mesh_dim_1 instead of the constants 0, 1 here for N-D device mesh scalability.

        output_dim_partition_list = []
        dim_size = self.output_data.dim()
        # enumerate all the 2D sharding cases
        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_2d)

        # enumerate all the 1D sharding cases
        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_1)

        # add an empty dict for the fully replicated case
        output_dim_partition_list.append({})
        output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)

        return output_sharding_spec_list

    @ignore_sharding_exception
    def _register_strategy(self, output_sharding_spec):
        dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
        sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
        sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input)
        sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input)

        name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}'
        dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs(
            [sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y])

        # compute the computation cost of this strategy
        sharding_dims = []
        for mesh_dims in dim_partition_dict_for_output.values():
            for mesh_dim in mesh_dims:
                sharding_dims.append(self.device_mesh.shape[mesh_dim])
        sharding_size = reduce(operator.mul, sharding_dims, 1)
        memory_cost = self.output_data.numel() / sharding_size
        compute_cost = memory_cost
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=output_sharding_spec,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_condition, sharding_spec_for_x,
                                                                sharding_spec_for_y))

        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        MESH_DIM_LIST = [0, 1]
        output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
        for output_sharding_spec in output_sharding_specs:
            self._register_strategy(output_sharding_spec)
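The left-padding of lower-rank operands performed in _generate_sharding_spec and _generate_resharding_costs above can be shown on plain shapes; the shapes are invented.

import torch

output_shape = torch.Size((8, 4, 16))
condition_shape = torch.Size((4, 16))    # example operand with a lower rank than the output

padded = list(condition_shape)
while len(padded) < len(output_shape):
    padded.insert(0, 1)                  # prepend size-1 dims, matching broadcasting semantics
print(torch.Size(padded))                # torch.Size([1, 4, 16])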
@@ -1,11 +0,0 @@
from dataclasses import dataclass

__all__ = ['SolverOptions']


@dataclass
class SolverOptions:
    """
    SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
    """
    fast: bool = False
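A minimal usage sketch for the SolverOptions dataclass above:

options = SolverOptions(fast=True)    # restrict placeholder/get_attr/output nodes to fully replicated specs
print(options.fast)                   # True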
@@ -1,91 +0,0 @@
from copy import deepcopy
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import operator
import torch
from functools import reduce

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *

__all__ = ['ShardingStrategy', 'StrategiesVector']


@dataclass
class ShardingStrategy:
    '''
    ShardingStrategy is a structure containing the sharding strategies of the inputs and output of this node
    and the cost information used by the solver.

    Argument:
        name(str): expresses the sharding strategy as a string, such as 'S0S1 = S0R x RS1'.
        output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
        compute_cost(float): Computation cost to complete this strategy. (default to 0)
        communication_cost(float): Communication cost to complete this strategy. (default to 0)
        memory_cost(float): Memory cost of the output node using this strategy. (default to 0)
        resharding_costs(Dict[Node, List[float]]): resharding_costs[i][j] means the cost of the i-th argument in the output node argument list
                                                   transforming from its j-th strategy in its strategies_vector to the sharding spec wanted in this
                                                   strategy. (default to None)
        input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
    '''

    name: str
    # TODO: the output of an fx node, such as torch.var_mean, could be a tuple, so we cannot simply assume it is a tensor.
    output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
    compute_cost: float = 0.
    communication_cost: float = 0.
    memory_cost: float = 0.
    resharding_costs: Dict[Node, List[float]] = None
    # sometimes the input node could be a tuple of nodes, but most ops won't accept a tuple of nodes as input.
    # Therefore, we could process them at the specific op (operator.getitem).
    input_shardings: List[ShardingSpec] = None


class StrategiesVector(list):
    '''
    Each node in the fx graph will have a corresponding StrategiesVector, to store all the possible
    strategies of the node.

    Argument:
        node (Node): node for which the list of sharding strategies is generated.
    '''

    def __init__(self, node: Node):
        super().__init__()
        self.node = node
        # fetch its input and output nodes
        # TODO: placeholder input nodes
        self.predecessor_nodes = list(node._input_nodes.keys())
        if self.node.op == 'output':
            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
        self.successor_nodes = list(node.users.keys())

    def check_merge(self):
        merge_label = False
        if self.node.op == 'call_module':
            target = self.node.target
            root_module = self.node.graph.owning_module
            submod = root_module.get_submodule(target)
            submod_type = type(submod)
            # merge elementwise module node into source nodes
            # we could merge element-wise ops, because the output sharding spec is always the same as the input sharding spec.
            if submod_type in ELEMENTWISE_MODULE_OP:
                merge_label = True

        if self.node.op == 'call_function':
            # we could merge element-wise ops, because the output sharding spec is always the same as the input sharding spec.
            if self.node.target in ELEMENTWISE_FUNC_OP:
                merge_label = True
            # we could merge a bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
            if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
                merge_label = True
            # we could merge reshape ops, because the output sharding spec of a reshape op is always fully replicated.
            if self.node.target in RESHAPE_FUNC_OP:
                merge_label = True

        return merge_label
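A sketch of how a handler would assemble one ShardingStrategy record defined above; every value here is a stand-in rather than the output of a real handler.

strategy = ShardingStrategy(name='S0R = S0R x RR',
                            output_sharding_spec=None,    # a ShardingSpec instance in real use; omitted in this sketch
                            compute_cost=2.0e9,
                            communication_cost=0.0,
                            memory_cost=3.2e4,
                            resharding_costs={},          # Dict[Node, List[float]] in real use
                            input_shardings=[])
print(strategy.name, strategy.compute_cost)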
@@ -1,469 +0,0 @@
import multiprocessing
import time
import warnings
from typing import Dict

import numpy as np
from torch.fx.graph import Graph
from torch.fx.node import Node

from .constants import INFINITY_COST
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .strategies_constructor import StrategiesConstructor

try:
    import pulp
    from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
    warnings.warn('please install pulp')

__all__ = ['Solver']


class Solver:

    def __init__(self,
                 graph: Graph,
                 strategies_constructor: StrategiesConstructor,
                 cost_graph: CostGraph,
                 graph_analyser: GraphAnalyser,
                 memory_budget: float = -1.0,
                 solution_numbers: int = 1,
                 memory_increasing_coefficient: float = 1.3):
        '''
        The Solver class integrates the information provided by the components and uses an ILP solver to find a possibly optimal strategy combination for the target computing graph.

        Argument:
            graph: The computing graph to be optimized.
            strategies_constructor: It provides all the possible strategies for each node in the computing graph.
            cost_graph: A graph data structure to simplify the edge cost graph.
            graph_analyser: graph_analyser analyses the graph to obtain the variable liveness information, which will be used to generate memory constraints.
            memory_budget: Memory constraint for the solution.
            solution_numbers: If solution_numbers is larger than one, the solver will return a series of solutions based on different memory budgets.
            memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budgets.
        '''
        self.graph = graph
        self.strategies_constructor = strategies_constructor
        self.cost_graph = cost_graph
        self.graph_analyser = graph_analyser
        self.leaf_strategies = self.strategies_constructor.leaf_strategies
        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
        self.strategy_map = self.strategies_constructor.strategy_map
        self.memory_budget = memory_budget
        self.solution_numbers = solution_numbers
        if self.solution_numbers > 1:
            self.memory_increasing_coefficient = memory_increasing_coefficient
        else:
            self.memory_increasing_coefficient = 1
        self.liveness_list = self.graph_analyser.liveness_analysis()
        self.node_index_dict = self._generate_node_index_dict()
        # The last solution vector of auto sharding.
        self.last_s_val = None
        # The last objective value of the best ILP solution.
        self.last_objective = None

    def _recover_merged_node_strategy(self):
        '''
        During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOps, were merged into the previous node.
        Therefore, the strategy indices of those nodes were copied from the previous node. This method is used to recover the strategy
        index of those merged nodes.
        '''
        for node_index, node in enumerate(self.nodes):
            if node.strategies_vector.check_merge():
                # the merged node has only one input, and its strategies follow the input sharding strategy
                input_strategies_vector = node.args[0].strategies_vector
                input_best_strategy_index = self.last_s_val[node_index - 1]
                input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
                for strategy_index, strategy in enumerate(node.strategies_vector):
                    if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
                        self.last_s_val[node_index] = strategy_index
                        break

    def _generate_node_index_dict(self) -> Dict[Node, int]:
        node_index_dict = {}
        for index, strategies_vector in enumerate(self.leaf_strategies):
            node_index_dict[strategies_vector.node] = index
        return node_index_dict

    def _prepare_data_for_solver(self):
        '''
        Extract information from the components for the solver.
        '''
        node_nums = len(self.leaf_strategies)
        memory_budget = self.memory_budget

        # prepare strategies_len
        strategies_len = []
        for node in self.nodes:
            strategies_len.append(self.cost_graph.node_lens[node])
        strategies_len = np.array(strategies_len)

        # prepare following_nodes
        following_nodes = self.cost_graph.following_dict
        index_following_nodes = {}
        for src, target in following_nodes.items():
            src_index = self.node_index_dict[src]
            target_index = self.node_index_dict[target]
            index_following_nodes[src_index] = target_index
        following_nodes = index_following_nodes
        for index in range(node_nums):
            if index not in following_nodes:
                following_nodes[index] = -1

        # prepare edge_pairs and resharding costs
        edge_pairs = []
        resharding_costs = []
        for pairs, edge_cost in self.cost_graph.edge_costs.items():
            src_node = pairs[0]
            dst_node = pairs[1]
            src_node_index = self.node_index_dict[src_node]
            dst_node_index = self.node_index_dict[dst_node]
            edge_pairs.append(src_node_index)
            edge_pairs.append(dst_node_index)

            for i in range(strategies_len[src_node_index]):
                for j in range(strategies_len[dst_node_index]):
                    resharding_costs.append(edge_cost[(i, j)])
        edge_pairs = np.array(edge_pairs)
        resharding_costs = np.array(resharding_costs)

        # prepare liveness_set
        liveness_set = self.liveness_list

        # omit alias_set now
        alias_set = None
        alias_convert_costs = None

        # prepare compute_costs, communication_costs and memory_costs
        compute_costs = []
        communication_costs = []
        memory_costs = []
        extra_node_costs = self.cost_graph.extra_node_costs
        for strategies_vector in self.leaf_strategies:
            node = strategies_vector.node
            for index, strategy in enumerate(strategies_vector):
                compute_costs.append(strategy.compute_cost)
                # a node in extra_node_costs means it has some extra communication
                # cost from node merging, so we need to add that extra communication
                # cost into its communication cost.
                if node in extra_node_costs:
                    origin_communication_cost = strategy.communication_cost
                    extra_node_cost = extra_node_costs[node][index]
                    communication_cost = origin_communication_cost + extra_node_cost
                    communication_costs.append(communication_cost)
                else:
                    communication_costs.append(strategy.communication_cost)
                # temporarily we just consider the forward memory cost
                memory_cost = strategy.memory_cost
                if isinstance(memory_cost, tuple):
                    memory_costs.append(memory_cost[0])
                else:
                    memory_costs.append(memory_cost)
        compute_costs = np.array(compute_costs)
        communication_costs = np.array(communication_costs)
        memory_costs = np.array(memory_costs)

        # omit initial value for nodes
        s_init_np = None

        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np

    def _call_solver_serialized_args(self,
                                     node_nums,
                                     memory_budget,
                                     strategies_len,
                                     following_nodes,
                                     edge_pairs,
                                     alias_set,
                                     liveness_set,
                                     compute_costs,
                                     communication_costs,
                                     memory_costs,
                                     resharding_costs,
                                     alias_convert_costs,
                                     s_init_np=None):
        """
        Call the solver with serialized arguments.
        """

        tic = time.time()

        for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
            assert isinstance(x, np.ndarray)
        assert len(strategies_len) == node_nums, "strategies_len"

        def get_non_zero_index(binary_vector):
            """
            Get the index of the non-zero item in a vector.
            """
            ct = 0
            ret = None
            for i, elem in enumerate(binary_vector):
                if pulp.value(elem):
                    ret = i
                    ct += 1

            assert ct == 1
            return ret

        # 0. Unpack flattened numpy arrays
        s_follow = following_nodes

        E = edge_pairs.reshape((-1, 2))    # noqa
        r = []
        pt = 0
        edge_set = set()
        for (i, j) in E:
            prod_length = strategies_len[i] * strategies_len[j]

            if (i, j) in edge_set:
                raise ValueError(f"Duplicated edges: {(i, j)}")

            edge_set.add((i, j))
            r.append(resharding_costs[pt:pt + prod_length])
            pt += prod_length
        assert pt == len(resharding_costs)

        ######################
        # omit alias set now #
        ######################

        # A = alias_set.reshape((-1, 2))  # noqa
        # for (i, j) in A:
        #     prod_length = strategies_len[i] * strategies_len[j]
        #     v.append(alias_convert_costs[pt:pt + prod_length])
        #     pt += prod_length
        # assert pt == len(alias_convert_costs)

        # L = []  # noqa
        # pt = node_nums
        # for i in range(node_nums):
        #     length = liveness_set[i]
        #     L.append(liveness_set[pt:pt + length])
        #     pt += length
        # assert pt == len(liveness_set)
        v = []
        pt = 0

        c = []
        d = []
        m = []
        pt = 0
        for i in range(node_nums):
            length = strategies_len[i]
            c.append(compute_costs[pt:pt + length])
            d.append(communication_costs[pt:pt + length])
            m.append(memory_costs[pt:pt + length])
            pt += length
        assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
        assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
        assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"

        # 1. Create variables

        #############################
        # create variables for node #
        #############################
        s = []
        num_nodes = 0
        reverse_follow_backpatch = []
        for i in range(node_nums):
            if s_follow[i] < 0:
                if strategies_len[i] == 1:
                    s.append([1])
                else:
                    num_nodes += 1
                    s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
            else:
                if s_follow[i] < len(s):
                    s.append(s[s_follow[i]])
                else:
                    s.append(None)
                    reverse_follow_backpatch.append(i)

        for i in reverse_follow_backpatch:
            s[i] = s[s_follow[i]]

        #############################
        # create variables for edge #
        #############################
        e = []
        num_edges = 0
        for (idx, (i, j)) in enumerate(E):
            if len(s[i]) == 1:
                e.append(s[j])
            elif len(s[j]) == 1:
                e.append(s[i])
            else:
                num_edges += 1
                e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
            assert len(e[idx]) == len(r[idx])
        for element in s:
            assert len(element) > 0
        # 2. Set initial value
        #######################################
        # set an initial value for warm start #
        #######################################
        if s_init_np is not None:
            s_init = s_init_np.reshape((-1, 3))
            for (idx, value, fix) in s_init:
                for i in range(len(s[idx])):
                    s[idx][i].setInitialValue(i == value)
                    if fix:
                        s[idx][i].fixValue()

        # 3. Objective
        prob = LpProblem("myProblem", LpMinimize)
        ####################################################################
        # computing the node cost (computing cost and communication cost)  #
        ####################################################################
        obj = 0
        for i in range(node_nums):
            assert len(s[i]) == len(c[i])
            assert len(s[i]) == len(d[i])

            obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])

        #############################################
        # computing the edge cost (resharding cost) #
        #############################################
        for i in range(len(E)):
            assert len(e[i]) == len(r[i])
            obj += lpDot(e[i], r[i])

        prob += obj

        # 4. Constraints
        # (a). specified by `cat="Binary"`

        # (b)
        #################################################
        # make sure each node only chooses one strategy #
        #################################################
        for i in range(node_nums):
            if s_follow[i] < 0:
                prob += lpSum(s[i]) == 1

        # (c)
        #################################################
        # compute memory consumption with liveness set  #
        #################################################
        if memory_budget > 0:
            for liveness_stage in liveness_set:
                mem = 0
                for live_variable in liveness_stage.unique_live_vars:
                    node_index = self.node_index_dict[live_variable.node]
                    mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
                prob += mem <= memory_budget

        # (d). specified by `cat="Binary"`

        for (idx, (i, j)) in enumerate(E):
            if strategies_len[i] == 1 or strategies_len[j] == 1:
                continue

            # (e)
            prob += lpSum(e[idx]) == 1

            # (f)
            for row in range(len(s[i])):
                C = len(s[j])    # noqa
                prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]

            # (g)
            for col in range(len(s[j])):
                R = len(s[i])    # noqa
                C = len(s[j])    # noqa
                prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]

        # (h)
        ######################
        # omit alias set now #
        ######################

        # alias_set = set()
        # for (idx, (i, j)) in enumerate(A):
        #     R = len(s[i])  # noqa
        #     C = len(s[j])  # noqa
        #     if (i, j) in alias_set:
        #         raise ValueError(f"Duplicated edges: {(i, j)}")

        #     alias_set.add((i, j))
        #     alias_set.add((j, i))

        #     for row in range(len(s[i])):
        #         for col in range(len(s[j])):
        #             if v[idx][row * C + col] > 0.5:
        #                 prob += s[i][row] + s[j][col] <= 1

        verbose = True

        msg = verbose
        time_limit = 600
        assert "COIN_CMD" in pulp.listSolvers(
            onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")

        solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
        # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
        prob.solve(solver)

        status = prob.status
        objective = pulp.value(prob.objective)
        objective = float(objective) if objective is not None else -1.0
        if verbose:
            print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
                  f"Time: {time.time() - tic}")
            print(f"#nodes: {num_nodes}, #edges: {num_edges}")

        if prob.status in [pulp.LpStatusInfeasible]:
            raise RuntimeError("Cannot run the function under the given memory budget. "
                               "Please increase the memory budget.")

        # Get and check results
        s_val = np.full((node_nums,), -1, dtype=np.int32)
        for i in range(node_nums):
            s_val[i] = get_non_zero_index(s[i])

        e_val = np.full((len(E),), -1, dtype=np.int32)
        for (idx, (i, j)) in enumerate(E):
            e_val[idx] = get_non_zero_index(e[idx])
            i_spec_index = e_val[idx] // len(s[j])
            j_spec_index = e_val[idx] % len(s[j])
            assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
            assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
            if verbose and r[idx][e_val[idx]] > 0:
                print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")

        self.last_s_val = list(s_val)
        self._recover_merged_node_strategy()
        self.last_objective = objective

        if objective > INFINITY_COST:
            warnings.warn("Detected unexpected behaviors in the auto-sharding pass.")

        return self.last_s_val, e_val, self.last_objective, status

    def call_solver_serialized_args(self):
        """
        Call the solver with serialized arguments and handle Python errors. Additionally,
        we could return a series of solutions with different memory budgets.
        """
        if self.solution_numbers == 1:
            args = self._prepare_data_for_solver()
            ret = self._call_solver_serialized_args(*args)

            return ret

        origin_memory_budget = self.memory_budget
        memory_budget_list = [
            origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
        ]
        ret_list = []
        for memory_budget in memory_budget_list:
            self.memory_budget = memory_budget
            args = self._prepare_data_for_solver()
            ret = self._call_solver_serialized_args(*args)
            ret_list.append(ret)

        return ret_list
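A toy version of the node and edge constraints built in _call_solver_serialized_args above, for two nodes with two strategies each. It assumes pulp and the bundled CBC solver are installed; all cost numbers are invented.

import pulp
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

node_cost = [[1.0, 3.0], [2.0, 1.0]]    # per-strategy cost of node 0 and node 1 (invented)
edge_cost = [0.0, 5.0, 5.0, 0.0]        # resharding cost of strategy pairs (i, j), row-major

s0 = LpVariable.matrix('s0', (range(2),), cat='Binary')
s1 = LpVariable.matrix('s1', (range(2),), cat='Binary')
e = LpVariable.matrix('e', (range(4),), cat='Binary')

prob = LpProblem('toy_auto_sharding', LpMinimize)
prob += lpDot(s0, node_cost[0]) + lpDot(s1, node_cost[1]) + lpDot(e, edge_cost)
prob += lpSum(s0) == 1    # each node picks exactly one strategy
prob += lpSum(s1) == 1
prob += lpSum(e) == 1     # exactly one strategy pair is active on the edge
for row in range(2):      # tie the edge variable to the node choices
    prob += lpSum(e[row * 2 + col] for col in range(2)) <= s0[row]
for col in range(2):
    prob += lpSum(e[row * 2 + col] for row in range(2)) <= s1[col]
prob.solve(pulp.PULP_CBC_CMD(msg=False))
print([pulp.value(v) for v in s0], [pulp.value(v) for v in s1])    # expected: [1.0, 0.0] [1.0, 0.0]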
@ -1,426 +0,0 @@
|
|||
import builtins
|
||||
import math
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from ._utils import generate_resharding_costs, generate_sharding_spec
|
||||
from .constants import *
|
||||
from .op_handler import *
|
||||
from .options import SolverOptions
|
||||
from .sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
__all__ = ['StrategiesConstructor']
|
||||
|
||||
|
||||
class StrategiesConstructor:
|
||||
"""
|
||||
StrategiesConstructor is used to construct the parallelization plan for the model execution.
|
||||
|
||||
Args:
|
||||
graph (Graph): a Graph object used for analysis and strategy generation.
|
||||
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
|
||||
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
|
||||
self.graph = graph
|
||||
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
|
||||
self.root_module = self.graph.owning_module
|
||||
self.nodes = list(graph.nodes)
|
||||
self.device_mesh = device_mesh
|
||||
self.leaf_strategies = []
|
||||
self.strategy_map = {}
|
||||
self.solver_options = solver_options
|
||||
|
||||
def remove_duplicated_strategy(self, strategies_vector):
|
||||
'''
|
||||
In build_strategies_and_cost method, we may produce some duplicated strategies.
|
||||
In this method, we will remove the duplicated strategies depending on the strategies name.
|
||||
'''
|
||||
name_checklist = []
|
||||
remove_list = []
|
||||
for strategy in strategies_vector:
|
||||
if strategy.name not in name_checklist:
|
||||
name_checklist.append(strategy.name)
|
||||
else:
|
||||
remove_list.append(strategy)
|
||||
|
||||
for strategy in remove_list:
|
||||
strategies_vector.remove(strategy)
|
||||
|
||||
def _is_bcast_matmul(self, node):
|
||||
is_bcast_matmul = False
|
||||
if node.target is torch.matmul and len(node.args) == 2:
|
||||
lhs_data = node.args[0]._meta_data
|
||||
rhs_data = node.args[1]._meta_data
|
||||
if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
|
||||
is_bcast_matmul = True
|
||||
return is_bcast_matmul
|
||||
|
||||
def build_strategies_and_cost(self):
|
||||
for node in self.nodes:
|
||||
strategies_vector = StrategiesVector(node)
|
||||
input_nodes_len = 0
|
||||
for check_node in strategies_vector.predecessor_nodes:
|
||||
if isinstance(check_node._meta_data, torch.Tensor):
|
||||
input_nodes_len += 1
|
||||
# input_nodes_len = len(strategies_vector.predecessor_nodes)
|
||||
# placeholder node
|
||||
if node.op == 'placeholder':
|
||||
# For placeholder nodes, if solver_options.fast is True, we just let them in
|
||||
# fully replicate status, then strategies of following node will be treated equally due
|
||||
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
|
||||
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
|
||||
|
||||
if self.solver_options.fast:
|
||||
# create sharding strategy for placeholder
|
||||
name = 'Replica Placeholder'
|
||||
dim_partition_dict = {}
|
||||
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
|
||||
# TODO: use meta_info_prop to profile memory cost
|
||||
memory_cost = 0
|
||||
sharding_strategy_placeholder = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
memory_cost=memory_cost)
|
||||
strategies_vector.append(sharding_strategy_placeholder)
|
||||
|
||||
# get_attr node
|
||||
if node.op == 'get_attr':
|
||||
# Same as placeholder nodes, if solver_options.fast is True, we just let them in
|
||||
# fully replicate status, then strategies of following node will be treated equally due
|
||||
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
|
||||
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
|
||||
if self.solver_options.fast:
|
||||
# create sharding strategy for get_attr
|
||||
name = 'Replica Attribute'
|
||||
dim_partition_dict = {}
|
||||
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
|
||||
# TODO: use meta_info_prop to profile memory cost
|
||||
memory_cost = 0
|
||||
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
|
||||
strategies_vector.append(sharding_strategy_attribute)
|
||||
|
||||
# call_module node
|
||||
if node.op == 'call_module':
|
||||
|
||||
target = node.target
|
||||
submod = self.root_module.get_submodule(target)
|
||||
submod_type = type(submod)
|
||||
|
||||
# conv module
|
||||
if submod_type in CONV_MODULE_OP:
|
||||
# use ConvHandler to create sharding strategies for conv module node
|
||||
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
|
||||
conv_handler.register_strategy()
|
||||
|
||||
# linear module
|
||||
elif submod_type in LINEAR_MODULE_OP:
|
||||
# use DotHandler to create sharding strategies for linear module node
|
||||
dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
|
||||
dot_handler.register_strategy()
|
||||
|
||||
# element-wise module
|
||||
elif submod_type in ELEMENTWISE_MODULE_OP:
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
unary_elementwise_handler.register_strategy()
|
||||
|
||||
# BatchNormNd module
|
||||
elif submod_type in BATCHNORM_MODULE_OP:
|
||||
# create sharding strategy for element-wise module
|
||||
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
|
||||
norm_handler.register_strategy()
|
||||
# for strategy in norm_handler.strategies_vector:
|
||||
# print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
|
||||
# assert False
|
||||
|
||||
# MaxPool module
|
||||
elif submod_type in POOL_MODULE_OP:
|
||||
# TODO: add sharding constraints on image dimension
|
||||
# e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
|
||||
|
||||
# create sharding strategy for element-wise module
|
||||
assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.'
|
||||
input_node = strategies_vector.predecessor_nodes[0]
|
||||
# For element-wise module, we keep the sharding spec of output node same as
|
||||
# the input. Therefore, the different strategies of input node with same
|
||||
# output sharding spec will generate same strategy for element-wise module.
|
||||
sharding_spec_checklist = []
|
||||
for strategy in input_node.strategies_vector:
|
||||
# It looks a little bit confusing, the input of the processing node
|
||||
# is the output of the input_node.
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec,
|
||||
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
if input_sharding_spec in sharding_spec_checklist:
|
||||
continue
|
||||
|
||||
sharding_spec_checklist.append(input_sharding_spec)
|
||||
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
|
||||
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
|
||||
|
||||
# TODO: use meta_info_prop to profile memory cost and compute cost
|
||||
compute_cost = node._meta_data.numel()
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# embedding module
|
||||
elif submod_type in EMBEDDING_MODULE_OP:
|
||||
embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
|
||||
embedding_handler.register_strategy()
|
||||
|
||||
# layernorm module
|
||||
elif submod_type in LAYERNORM_MODULE_OP:
|
||||
layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
|
||||
layernorm_handler.register_strategy()
|
||||
# other module
|
||||
else:
|
||||
raise RuntimeError(f'{submod_type} module is NOT supported now.')
|
||||
|
||||
# call_function node
|
||||
if node.op == 'call_function':
|
||||
target = node.target
|
||||
# conv function
|
||||
if target in CONV_FUNC_OP:
|
||||
# use ConvHandler to create sharding strategies for conv node
|
||||
# TODO: the operator_handler does NOT support function node processing now.
|
||||
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
|
||||
conv_handler.register_strategy()
|
||||
|
||||
# linear function
|
||||
elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
|
||||
# use DotHandler to create sharding strategies for linear node
|
||||
# TODO: the operator_handler does NOT support function node processing now.
|
||||
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
|
||||
linear_handler.register_strategy()
|
||||
|
||||
# where function
|
||||
elif target == torch.where:
|
||||
if input_nodes_len == 1:
|
||||
# both of x and y are scalar
|
||||
pass
|
||||
|
||||
elif input_nodes_len == 2:
|
||||
# one of x or y is type of scalar
|
||||
pass
|
||||
|
||||
else:
|
||||
# general case
|
||||
where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
|
||||
where_handler.register_strategy()
|
||||
|
||||
# reshape function
|
||||
elif target in RESHAPE_FUNC_OP:
|
||||
# use ReshapeHandler to create sharding strategies for rehsape node
|
||||
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
|
||||
reshape_handler.register_strategy()
|
||||
|
||||
# element-wise function
|
||||
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
unary_elementwise_handler.register_strategy()
|
||||
|
||||
# bcast op
|
||||
elif target in BCAST_FUNC_OP:
|
||||
if isinstance(node._meta_data, torch.Tensor):
|
||||
bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
|
||||
bcast_op_handler.register_strategy()
|
||||
|
||||
# torch.var_mean
|
||||
elif target == torch.var_mean:
|
||||
dim = node.kwargs['dim']
|
||||
input_tensor_node = strategies_vector.predecessor_nodes[0]
|
||||
for strategy in input_tensor_node.strategies_vector:
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
assert isinstance(input_sharding_spec,
|
||||
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
|
||||
entire_shape_input = input_sharding_spec.entire_shape
|
||||
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
|
||||
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
|
||||
if dim in dim_partition_dict_input:
|
||||
# We need to make the action dimension in replicate status
|
||||
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
|
||||
dim_partition_dict_for_input.pop(dim)
|
||||
new_input_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_input,
|
||||
dim_partition_dict=dim_partition_dict_for_input)
|
||||
entire_shape_output = deepcopy(entire_shape_input)
|
||||
entire_shape_output.pop(dim)
|
||||
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_input)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[new_input_sharding_spec])
|
||||
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[new_input_sharding_spec])
|
||||
|
||||
else:
|
||||
entire_shape_output = deepcopy(entire_shape_input)
|
||||
entire_shape_output.pop(dim)
|
||||
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partion_dict=dim_partition_dict_input)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_sharding_spec])
|
||||
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# operator.getitem
|
||||
elif target == operator.getitem:
|
||||
index = node.args[1]
|
||||
input_tensor_node = strategies_vector.predecessor_nodes[0]
|
||||
for strategy in input_tensor_node.strategies_vector:
|
||||
if isinstance(strategy.output_sharding_spec, ShardingSpec):
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
else:
|
||||
input_sharding_spec = strategy.output_sharding_spec[index]
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), 'The sharding spec selected by getitem must be a single ShardingSpec.'
|
||||
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec],
|
||||
index=index)
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_tensor_node] = [
|
||||
cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
|
||||
]
|
||||
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[strategy.output_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# torch.arange function
|
||||
elif target == torch.arange:
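# torch.arange has no tensor inputs, so only a fully replicated strategy is registered;
# its memory cost is simply the number of elements it produces.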
|
||||
name = 'FULLY REPLICATED ARANGE'
|
||||
entire_shape_output = node._meta_data.shape
|
||||
dim_partition_dict_for_output = {}
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
memory_cost = node._meta_data.numel()
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=0,
|
||||
memory_cost=memory_cost)
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# op list to be processed to support gpt2
|
||||
elif target in (builtins.getattr, operator.le, torch.addmm):
|
||||
pass
|
||||
# other function
|
||||
else:
|
||||
raise RuntimeError(f'{target} function is NOT supported now.')
|
||||
|
||||
# call_method node
|
||||
if node.op == 'call_method':
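# resolve the unbound method from the meta tensor's class so it can be matched
# against the element-wise and reshape method tables below.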
|
||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||
if method in (torch.Tensor.size,):
|
||||
pass
|
||||
elif method in ELEMENTWISE_METHOD_OP:
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
unary_elementwise_handler.register_strategy()
|
||||
|
||||
elif method in RESHAPE_METHOD_OP:
|
||||
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
|
||||
reshape_handler.register_strategy()
|
||||
# print(strategies_vector)
|
||||
# if len(strategies_vector) == 0:
|
||||
# print(node)
|
||||
# assert False
|
||||
else:
|
||||
raise RuntimeError(f'{method} method is NOT supported now.')
|
||||
|
||||
# output node
|
||||
if node.op == 'output':
|
||||
if self.solver_options.fast:
|
||||
# create sharding strategy for output
|
||||
name = 'Replica Output'
|
||||
input_nodes = strategies_vector.predecessor_nodes
|
||||
input_sharding_specs = []
|
||||
for input_node in input_nodes:
|
||||
dim_partition_dict_for_input = {}
|
||||
entire_shape = input_node._meta_data.shape
|
||||
sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape,
|
||||
dim_partition_dict=dim_partition_dict_for_input)
|
||||
input_sharding_specs.append(sharding_spec)
|
||||
|
||||
dim_partition_dict = {}
|
||||
output_sharding_spec = input_sharding_specs
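# i.e. the output node simply re-exposes its inputs' replicated specs as its own output sharding spec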
|
||||
# TODO: use meta_info_prop to profile memory cost
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
input_sharding_specs)
|
||||
|
||||
# clear the resharding cost for the output node
|
||||
# TODO: we may remove this in final version
|
||||
for prev_node, resharding_cost_list in resharding_costs.items():
|
||||
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
|
||||
|
||||
sharding_strategy_attribute = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=tuple(input_sharding_specs))
|
||||
strategies_vector.append(sharding_strategy_attribute)
|
||||
|
||||
self.remove_duplicated_strategy(strategies_vector)
|
||||
setattr(node, 'strategies_vector', strategies_vector)
|
||||
self.leaf_strategies.append(strategies_vector)
|
||||
self.strategy_map[node] = strategies_vector
|
||||
|
||||
# remove no strategy nodes
|
||||
remove_list = []
|
||||
for strategies_vector in self.leaf_strategies:
|
||||
if len(strategies_vector) == 0:
|
||||
remove_list.append(strategies_vector.node)
|
||||
for node in remove_list:
|
||||
if node.strategies_vector in self.leaf_strategies:
|
||||
self.leaf_strategies.remove(node.strategies_vector)
|
||||
if node in self.strategy_map:
|
||||
self.strategy_map.pop(node)
|
|
@ -1,96 +0,0 @@
|
|||
from copy import deepcopy
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.conv1(x)
|
||||
x = x / 2
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_cost_graph():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
|
||||
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {})
|
||||
# %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {})
|
||||
# return relu
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
# (x, mul):{(0, 0): 0}
|
||||
# (mul, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002}
|
||||
# (conv1, truediv):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): inf, (11, 0): inf, (12, 0): inf, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): 0, (6, 1): inf, (7, 1): inf, (8, 1): inf, (9, 1): inf, (10, 1): inf, (11, 1): inf, (12, 1): inf, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): inf, (11, 2): inf, (12, 2): inf, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): inf, (11, 3): inf, (12, 3): inf, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): inf, (11, 4): inf, (12, 4): inf, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): inf, (11, 5): inf, (12, 5): inf, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): inf, (11, 7): inf, (12, 7): inf, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): inf, (9, 8): inf, (10, 8): inf, (11, 8): inf, (12, 8): inf, (13, 8): inf, (14, 8): 0}
|
||||
# (truediv, relu):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): inf, (6, 1): inf, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): inf, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): inf, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): 0}
|
||||
# (relu, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002}
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
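# CostGraph stores pairwise resharding costs as edge costs keyed by (src_node, dst_node) pairs,
# which is what the assertions below rely on.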
|
||||
|
||||
# construct all node pairs
|
||||
all_node_pairs = []
|
||||
|
||||
for node in graph.nodes:
|
||||
if node.op == 'output':
|
||||
continue
|
||||
for child in node.users.keys():
|
||||
all_node_pairs.append((node, child))
|
||||
|
||||
for node_pair in all_node_pairs:
|
||||
assert node_pair in cost_graph.edge_costs
|
||||
|
||||
# construct merged node pairs
|
||||
merged_node_pairs = []
|
||||
node_list = list(graph.nodes)
|
||||
# add (conv1_weight, conv2d), (conv1_bias, view), (conv2d, add), (view, add), (add, output), (x, conv2d) into check node pairs
|
||||
merged_node_pairs.append((node_list[0], node_list[4]))
|
||||
merged_node_pairs.append((node_list[2], node_list[4]))
|
||||
merged_node_pairs.append((node_list[3], node_list[5]))
|
||||
merged_node_pairs.append((node_list[5], node_list[6]))
|
||||
merged_node_pairs.append((node_list[4], node_list[6]))
|
||||
merged_node_pairs.append((node_list[6], node_list[-1]))
|
||||
cost_graph.simplify_graph()
|
||||
for node_pair in all_node_pairs:
|
||||
if node_pair in merged_node_pairs:
|
||||
assert node_pair in cost_graph.edge_costs
|
||||
else:
|
||||
assert node_pair not in cost_graph.edge_costs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cost_graph()
|
|
@ -1,118 +0,0 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.batch_norm_handler import BatchNormHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
class BNModel(nn.Module):
|
||||
|
||||
def __init__(self, c):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm2d(c)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_bn_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = BNModel(16)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {})
|
||||
# return bn
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
# [x, mul, bn, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
|
||||
# find the sharding strategies for the input node of the bn node
|
||||
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
|
||||
strategies_vector_for_input = StrategiesVector(nodes[1])
|
||||
sharding_option = (None, 0, 1)
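# None keeps a tensor dim replicated; 0 and 1 are the two axes of the 2x2 device mesh.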
|
||||
for first_sharding_index in sharding_option:
|
||||
for second_sharding_index in sharding_option:
|
||||
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
|
||||
continue
|
||||
if first_sharding_index is None:
|
||||
first_dim_spec = _DimSpec([])
|
||||
else:
|
||||
first_dim_spec = _DimSpec([first_sharding_index])
|
||||
|
||||
if second_sharding_index is None:
|
||||
second_dim_spec = _DimSpec([])
|
||||
else:
|
||||
second_dim_spec = _DimSpec([second_sharding_index])
|
||||
|
||||
replica_dim_spec = _DimSpec([])
|
||||
sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
|
||||
sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=entire_shape,
|
||||
sharding_sequence=sharding_sequence)
|
||||
strategy_name = str(sharding_spec.sharding_sequence)
|
||||
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
|
||||
strategies_vector_for_input.append(sharding_strategy)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
# generate bn strategy
|
||||
strategies_vector = StrategiesVector(node=nodes[2])
|
||||
bn_handler = BatchNormHandler(
|
||||
node=nodes[2],
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=strategies_vector,
|
||||
)
|
||||
bn_handler.register_strategy()
|
||||
# ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
|
||||
# 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']
|
||||
strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector]
|
||||
|
||||
# RS = RS x S and strategies based on it, such as
|
||||
# SS = RS x S
|
||||
assert 'RS0 = RS0 x S0' in strategy_name_list
|
||||
assert 'S1S0 = RS0 x S0' in strategy_name_list
|
||||
assert 'RS1 = RS1 x S1' in strategy_name_list
|
||||
assert 'S0S1 = RS1 x S1' in strategy_name_list
|
||||
|
||||
# RR = RR x R and strategies based on it, such as
|
||||
# SR = SR x R
|
||||
assert 'RR = RR x R' in strategy_name_list
|
||||
assert 'S0R = RR x R' in strategy_name_list
|
||||
assert 'S1R = RR x R' in strategy_name_list
|
||||
assert 'S01R = RR x R' in strategy_name_list
|
||||
|
||||
# RS01 = RS01 x S01
|
||||
assert 'RS01 = RS01 x S01' in strategy_name_list
|
||||
|
||||
# SR = SR x R WITH SYNC_BN
|
||||
assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
|
||||
assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
# SS = SS x S WITH SYNC_BN
|
||||
assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
|
||||
assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
# S01R = S01R x R WITH SYNC_BN
|
||||
assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_bn_handler()
|
|
@ -1,75 +0,0 @@
|
|||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x2 = x1 + 1
|
||||
x1 = torch.reshape(x1, [1, -1, 64, 1])
|
||||
x3 = self.conv2(x1)
|
||||
x3 = torch.reshape(x3, [4, 1, 64, -1])
|
||||
x = x1 + x3
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_conv_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {})
|
||||
# %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {})
|
||||
# %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {})
|
||||
# %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {})
|
||||
# %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {})
|
||||
# %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {})
|
||||
# return add_1
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
# [x, conv1, add, reshape, conv2, reshape_1, add_1, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
strategy_map = strategies_constructor.strategy_map
|
||||
# check a tensor add with a scalar case
|
||||
conv1_strategies = strategy_map[nodes[1]]
|
||||
add_strategies = strategy_map[nodes[2]]
|
||||
add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies]
|
||||
for strategy in conv1_strategies:
|
||||
assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list
|
||||
|
||||
# check two tensors element-wise add case
|
||||
add_1_strategies = strategy_map[nodes[6]]
|
||||
assert len(add_1_strategies) == 25
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conv_handler()
|
|
@ -1,54 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class MatmulModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x1, x2):
|
||||
x = torch.matmul(x1, x2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_matmul_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = MatmulModel()
|
||||
input_sample = {'x1': torch.rand(4, 4, 8).to('meta'), 'x2': torch.rand(4, 1, 8, 4).to('meta')}
|
||||
# graph():
|
||||
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
|
||||
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
|
||||
# %matmul : [#users=1] = call_function[target=torch.matmul](args = (%x1, %x2), kwargs = {})
|
||||
# return matmul
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
# [x1, x2, matmul, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
strategy_map = strategies_constructor.strategy_map
|
||||
matmul_strategies = strategy_map[nodes[2]]
|
||||
assert len(matmul_strategies) == 30
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_matmul_handler()
|
|
@ -1,90 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_conv_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
|
||||
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
|
||||
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
|
||||
# return add
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
conv_node = list(graph.nodes)[4]
|
||||
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
|
||||
strategy_name_list = [strategy.name for strategy in conv_node.strategies_vector]
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
assert 'RS1 = RS0 x S0S1' in strategy_name_list
|
||||
|
||||
# RS = RR x RS
|
||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||
|
||||
# RR= RR x RR
|
||||
assert 'RR = RR x RR' in strategy_name_list
|
||||
|
||||
# SR = SR x RR
|
||||
assert 'S0R = S0R x RR' in strategy_name_list
|
||||
assert 'S1R = S1R x RR' in strategy_name_list
|
||||
assert 'S01R = S01R x RR' in strategy_name_list
|
||||
|
||||
# RR = RS x SR
|
||||
assert 'RR = RS0 x S0R' in strategy_name_list
|
||||
assert 'RR = RS1 x S1R' in strategy_name_list
|
||||
assert 'RR = RS01 x S01R' in strategy_name_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conv_handler()
|
|
@ -1,83 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
|
||||
|
||||
class LinearModel(nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_features, out_features)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.skip('F.linear is not supported in deprecated handler')
|
||||
def test_dot_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 8))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = LinearModel(8, 16)
|
||||
input_sample = {'x': torch.rand(4, 8).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
|
||||
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
|
||||
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %linear_weight), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
|
||||
# return add
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
linear_node = list(graph.nodes)[4]
|
||||
|
||||
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
|
||||
strategy_name_list = [strategy.name for strategy in linear_node.strategies_vector]
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
assert 'RS1 = RS0 x S0S1' in strategy_name_list
|
||||
|
||||
# RR = RS x SR
|
||||
assert 'RR = RS0 x S0R' in strategy_name_list
|
||||
assert 'RR = RS1 x S1R' in strategy_name_list
|
||||
|
||||
# RS= RR x RS
|
||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dot_handler()
|
|
@ -1,70 +0,0 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated import sharding_strategy
|
||||
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.layer_norm_handler import LayerNormHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
class LNModel(nn.Module):
|
||||
|
||||
def __init__(self, c):
|
||||
super().__init__()
|
||||
self.ln = nn.LayerNorm(c)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.ln(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_ln_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 4, 128))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = LNModel(128)
|
||||
input_sample = {'x': torch.rand(4, 4, 128).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {})
|
||||
# return ln
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
# [x, mul, ln, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {})
|
||||
sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input)
|
||||
strategies_vector_for_input = StrategiesVector(nodes[1])
|
||||
strategies_vector_for_input.append(sharding_strategy_for_input)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
# generate layer norm strategy
|
||||
strategies_vector = StrategiesVector(node=nodes[2])
|
||||
ln_handler = LayerNormHandler(
|
||||
node=nodes[2],
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=strategies_vector,
|
||||
)
|
||||
ln_handler.register_strategy()
|
||||
# ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]',
|
||||
# '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R']
|
||||
strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector]
|
||||
|
||||
assert len(strategy_name_list) == 9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ln_handler()
|
|
@ -1,59 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = torch.flatten(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_conv_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
|
||||
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
|
||||
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
|
||||
# %flatten : [#users=1] = call_function[target=torch.flatten](args = (%add,), kwargs = {})
|
||||
# return flatten
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
# [x, conv, flatten, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
strategy_map = strategies_constructor.strategy_map
|
||||
add_strategies = strategy_map[nodes[5]]
|
||||
flatten_strategies = strategy_map[nodes[6]]
|
||||
flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies]
|
||||
for strategy in add_strategies:
|
||||
assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conv_handler()
|
|
@ -1,66 +0,0 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class WhereModel(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.dim_in = dim_in
|
||||
self.dim_out = dim_out
|
||||
|
||||
def forward(self, condition, x, y):
|
||||
output = torch.where(condition, x, y)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_where_handler():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = WhereModel(16, 32)
|
||||
input_sample = {
|
||||
'condition': torch.rand(16, 32).to('meta'),
|
||||
'x': torch.rand(16, 32).to('meta'),
|
||||
'y': torch.rand(16, 32).to('meta')
|
||||
}
|
||||
# graph():
|
||||
# %condition : torch.Tensor [#users=1] = placeholder[target=condition]
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %y : torch.Tensor [#users=1] = placeholder[target=y]
|
||||
# %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
|
||||
# return where
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
|
||||
# [condition, x, y, where, output]
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
strategy_map = strategies_constructor.strategy_map
|
||||
# check the strategies generated for the where node
|
||||
where_node = strategy_map[nodes[3]]
|
||||
# ['[S0, S1] = [S0, S1] x [S0, S1] x [S0, S1]', '[S1, S0] = [S1, S0] x [S1, S0] x [S1, S0]', '[S01, R] = [S01, R] x [S01, R] x [S01, R]',
|
||||
# '[R, S01] = [R, S01] x [R, S01] x [R, S01]', '[S0, R] = [S0, R] x [S0, R] x [S0, R]', '[R, S0] = [R, S0] x [R, S0] x [R, S0]',
|
||||
# '[S1, R] = [S1, R] x [S1, R] x [S1, R]', '[R, S1] = [R, S1] x [R, S1] x [R, S1]', '[R, R] = [R, R] x [R, R] x [R, R]']
|
||||
assert len(where_node) == 9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_where_handler()
|
|
@ -1,86 +0,0 @@
|
|||
from functools import partial
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
def check_apply(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
input = torch.rand(4, 4, 4, 4).cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = torch.Size((4, 4, 8, 8))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(4, 4).cuda()
|
||||
origin_output = model(input)
|
||||
input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
|
||||
# return conv
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
ret = solver.call_solver_serialized_args()
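# ret[0] holds the chosen strategy index for every leaf node, which the annotation pass consumes below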
|
||||
solution = list(ret[0])
|
||||
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
|
||||
shape_consistency_pass(gm)
|
||||
gm.recompile()
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
# TODO: wrap the gm to avoid the influence of the user training code
|
||||
output = gm(input, sharding_spec_dict, origin_spec_dict)
|
||||
assert output.equal(origin_output)
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_apply():
|
||||
world_size = 4
|
||||
run_func = partial(check_apply, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_apply()
|
|
@ -1,79 +0,0 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3)
|
||||
self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = x / 2
|
||||
x = self.conv3(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_solver():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
|
||||
# %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {})
|
||||
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {})
|
||||
# %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {})
|
||||
# %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {})
|
||||
# return relu
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
|
||||
# [ 0 0 13 13 13 13 13 0]
|
||||
strategies_combination_list = ret[0]
|
||||
assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_solver()
|
|
@ -1,81 +0,0 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
|
||||
from copy import deepcopy
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||
import transformers
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
BATCH_SIZE = 8
|
||||
SEQ_LENGTH = 8
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_cost_graph():
|
||||
physical_mesh_id = torch.arange(0, 8)
|
||||
mesh_shape = (2, 4)
|
||||
# [[0, 1, 2, 3]
|
||||
# [4, 5, 6, 7]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
config = transformers.GPT2Config(n_positions=1024, n_layer=1, n_head=12)
|
||||
model = transformers.GPT2LMHeadModel(config=config)
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
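# tracing uses meta tensors, so no real memory is allocated for the GPT-2 inputs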
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
print(graph)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
for check_node, strategies_vector in strategies_constructor.strategy_map.items():
|
||||
print(check_node, len(strategies_vector))
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
|
||||
ret = solver.call_solver_serialized_args()
|
||||
print(ret)
|
||||
strategies_list = list(ret[0])
|
||||
print(strategies_list)
|
||||
computation_cost = 0
|
||||
communication_cost = 0
|
||||
memory_cost = 0
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
for index, node in enumerate(nodes):
|
||||
print(node.name, node.strategies_vector[strategies_list[index]].name)
|
||||
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
|
||||
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
|
||||
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
|
||||
if isinstance(node_memory_cost, tuple):
|
||||
node_memory_cost = node_memory_cost[0]
|
||||
memory_cost += node_memory_cost
|
||||
|
||||
print(f'computation cost is {computation_cost}')
|
||||
print(f'communication cost is {communication_cost}')
|
||||
print(f'memory cost is {memory_cost}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cost_graph()
|
|
@ -1,94 +0,0 @@
|
|||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
|
||||
from copy import deepcopy
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||
from torchvision.models import resnet34, resnet50
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(dim, dim * 4)
|
||||
self.linear2 = torch.nn.Linear(dim * 4, dim)
|
||||
self.dropout = torch.nn.Dropout(0)
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.dropout(x)
|
||||
x = self.relu(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_cost_graph():
|
||||
physical_mesh_id = torch.arange(0, 8)
|
||||
mesh_shape = (2, 4)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = MLP(32)
|
||||
|
||||
input_sample = {'x': torch.rand(16, 32).to('meta')}
|
||||
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
|
||||
# %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {})
|
||||
# %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {})
|
||||
# %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
|
||||
# return linear2
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
# # megatron mode if no memory constraints
|
||||
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
# all sharding on out feature dim if memory budget is not sufficient for megatron mode
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0)
|
||||
|
||||
ret = solver.call_solver_serialized_args()
|
||||
strategies_list = list(ret[0])
|
||||
computation_cost = 0
|
||||
communication_cost = 0
|
||||
memory_cost = 0
|
||||
for index, node in enumerate(graph.nodes):
|
||||
print(node.name, node.strategies_vector[strategies_list[index]].name)
|
||||
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
|
||||
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
|
||||
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
|
||||
if isinstance(node_memory_cost, tuple):
|
||||
node_memory_cost = node_memory_cost[0]
|
||||
memory_cost += node_memory_cost
|
||||
|
||||
print(f'computation cost is {computation_cost}')
|
||||
print(f'communication cost is {communication_cost}')
|
||||
print(f'memory cost is {memory_cost}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cost_graph()
|
|
@ -1,103 +0,0 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 2
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_strategies_constructor():
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
entire_shape = torch.Size((4, 16, 64, 64))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = ConvModel(16, 32)
|
||||
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
|
||||
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
|
||||
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
|
||||
# return add
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
print(graph)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
|
||||
assert strategies_constructor.leaf_strategies == []
|
||||
assert strategies_constructor.strategy_map == {}
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
# check leaf_strategies
|
||||
|
||||
# In fast mode, placeholder node only has replica strategy.
|
||||
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
|
||||
|
||||
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
|
||||
|
||||
# Third node is conv.
|
||||
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
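# every generated strategy name is removed from the expected list; an empty remainder means
# the conv node produced exactly the expected strategy set.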
|
||||
for strategy in strategies_constructor.leaf_strategies[4]:
|
||||
conv_check_list.remove(strategy.name)
|
||||
assert len(conv_check_list) == 0
|
||||
|
||||
# In fast mode, output node only has replica strategy.
|
||||
assert strategies_constructor.leaf_strategies[7][0].name == 'Replica Output'
|
||||
|
||||
# check strategy_map
|
||||
|
||||
nodes = [node for node in graph.nodes]
|
||||
# In fast mode, placeholder node only has replica strategy.
|
||||
x = nodes[0]
|
||||
assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder'
|
||||
|
||||
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
|
||||
mul = nodes[1]
|
||||
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
|
||||
|
||||
# fifth node is conv.
|
||||
conv = nodes[4]
|
||||
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
|
||||
for strategy in strategies_constructor.strategy_map[conv]:
|
||||
conv_check_list.remove(strategy.name)
|
||||
assert len(conv_check_list) == 0
|
||||
|
||||
# In fast mode, output node only has replica strategy.
|
||||
output = nodes[-1]
|
||||
assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_strategies_constructor()
|