mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] adapt solver and CostGraph with new handler (#1695)
* [autoparallel] adapt solver and CostGraph with new handler
* fix test issue
pull/1696/head
parent
42b882ef06
commit
81f7530ee7
|
@ -95,7 +95,8 @@ def exception_handler(func):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
func(*args, **kwargs)
|
rst = func(*args, **kwargs)
|
||||||
|
return rst
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
warnings.warn(f'{e}')
|
warnings.warn(f'{e}')
|
||||||
|
|
||||||
|
|
|
@ -170,3 +170,188 @@ class CostGraph:
|
||||||
for dst, src in self.following_dict.items():
|
for dst, src in self.following_dict.items():
|
||||||
reindexing_following_dict[dst] = self._reindexing_src(src)
|
reindexing_following_dict[dst] = self._reindexing_src(src)
|
||||||
self.following_dict = reindexing_following_dict
|
self.following_dict = reindexing_following_dict
|
||||||
|
|
||||||
|
|
||||||
|
class CostGraph_V2:
    '''
    A graph data structure to simplify the edge cost graph. It has two main functions:
    1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
    CostGraph, and it stores every combination of strategies for a src-dst node pair, keyed by the
    (src_strategy_index, dst_strategy_index) pair.
    2. To reduce the searching space, we merge computationally-trivial operators, such as
    element-wise operators, transpose, and reduction, into their following nodes. The merging information will
    be given by the StrategiesVector depending on the type of target node and following nodes.

    Argument:
        leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph.
        simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
        forward_only(bool, optional): If true, the ``fwd`` field of every resharding cost item is used,
            otherwise the ``total`` field is used. (default to False)
    '''

    def __init__(self, leaf_strategies, simplify=True, forward_only=False):
        self.leaf_strategies = leaf_strategies
        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
        # stores number of strategies in each node
        self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
        # extra_node_costs will store the extra costs introduced by merging nodes
        self.extra_node_costs = {}
        # maps a merged (removed) node to the node it follows
        self.following_dict = {}
        self.simplify = simplify
        self.forward_only = forward_only
        self._build_cost_graph()

    def _select_cost(self, resharding_cost_item):
        '''
        Return the forward-only or the total part of a resharding cost item, depending on self.forward_only.
        '''
        if self.forward_only:
            return resharding_cost_item.fwd
        return resharding_cost_item.total

    def _remove_invalid_node(self, node, attr_name):
        '''
        Remove, in place, every entry of ``getattr(node, attr_name)`` that is not a node of this cost graph.

        The list is mutated rather than replaced because it is the same list object that the
        StrategiesVector exposes as predecessor_nodes/successor_nodes.
        '''
        target_node_list = getattr(node, attr_name, [])
        # self.node_lens has exactly the nodes of self.nodes as keys, which gives an O(1) membership test.
        remove_list = [target_node for target_node in target_node_list if target_node not in self.node_lens]
        for element in remove_list:
            target_node_list.remove(element)

    def _build_cost_graph(self):
        '''
        This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
        set to node.

        Raises:
            AssertionError: if a strategy of a node has no resharding costs (``resharding_costs is None``).
        '''
        self.edge_costs = {}
        if self.simplify:
            self.merge_pair = []
        for strategies_vector in self.leaf_strategies:
            # build edge_cost
            dst_node = strategies_vector.node
            for src_node in strategies_vector.predecessor_nodes:
                if src_node not in self.node_lens:
                    continue
                node_pair = (src_node, dst_node)
                edge_cost = {}
                for dst_index, dst_strategy in enumerate(strategies_vector):
                    for src_index in range(len(src_node.strategies_vector)):
                        if dst_strategy.resharding_costs is None:
                            # Raised explicitly so the check survives `python -O`.
                            raise AssertionError(
                                f'resharding_costs of node {strategies_vector.node.name} is None')
                        resharding_cost_item = dst_strategy.resharding_costs[src_node][src_index]
                        # keys are (src strategy index, dst strategy index)
                        edge_cost[(src_index, dst_index)] = self._select_cost(resharding_cost_item)
                self.edge_costs[node_pair] = edge_cost
            # add parents and children attribute to node
            dst_node.parents = strategies_vector.predecessor_nodes
            dst_node.children = strategies_vector.successor_nodes
            self._remove_invalid_node(dst_node, 'parents')
            self._remove_invalid_node(dst_node, 'children')

            if self.simplify and strategies_vector.check_merge():
                for followed_node in strategies_vector.predecessor_nodes:
                    self.merge_pair.append((followed_node, dst_node))

    def get_edge_cost(self, src_node, dst_node):
        '''
        Return the edge cost dict of the (src_node, dst_node) pair. Raises KeyError if the pair has no edge.
        '''
        return self.edge_costs[(src_node, dst_node)]

    def merge_node(self, src_node, dst_node):
        '''
        Merge dst_node into src_node: src_node stays in the cost graph and takes over the outgoing
        edges of dst_node. It is done in following steps:

        1. For each strategy of src_node, pick the strategy of dst_node with the lowest resharding cost
            (ties are resolved in favour of the last such strategy). This mapping decides which row of the
            old (dst_node, child) edge cost is used for the new (src_node, child) edge cost.

        2. Accumulate the extra costs introduced by merging nodes. The extra costs contain two parts,
            one is the resharding cost between the src_node strategy and the picked dst_node strategy,
            another is the extra cost already recorded for the picked dst_node strategy.

        3. Build connections between src_node and the children of dst_node, and remove dst_node from
            the graph once it has no parents left.

        Argument:
            src_node(Node): The node which absorbs dst_node and remains in the cost graph.
            dst_node(Node): The node to be merged; recorded in following_dict when it is removed.

        Raises:
            ValueError: if src_node is not a parent of dst_node. No state is modified in this case.
        '''
        if src_node not in dst_node.parents:
            raise ValueError(f'{src_node} is not a parent of {dst_node}')

        # build merge_map: src strategy index -> picked dst strategy index
        merge_map = {}
        for src_index, _ in enumerate(src_node.strategies_vector):
            min_cost = INFINITY_COST
            lowest_cost_index = -1
            for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
                resharding_cost = self._select_cost(dst_strategy.resharding_costs[src_node][src_index])
                if resharding_cost <= min_cost:
                    min_cost = resharding_cost
                    lowest_cost_index = dst_index
            merge_map[src_index] = lowest_cost_index

        # extra_node_cost for src node
        self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
        for src_index, _ in enumerate(src_node.strategies_vector):
            target_strate_index = merge_map[src_index]
            target_strategy = dst_node.strategies_vector[target_strate_index]
            resharding_cost_item = target_strategy.resharding_costs[src_node][src_index]
            self.extra_node_costs[src_node][src_index] += self._select_cost(resharding_cost_item)
            if dst_node in self.extra_node_costs:
                self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]

        # add new node pair to cost graph
        for child_node in dst_node.children:
            new_node_pair = (src_node, child_node)
            old_node_pair = (dst_node, child_node)
            # NOTE: an already existing (src_node, child_node) edge is kept untouched; the costs routed
            # through dst_node are not accumulated onto it.
            if new_node_pair in self.edge_costs:
                continue
            old_edge_cost = self.edge_costs[old_node_pair]
            edge_cost = {}
            for i in range(self.node_lens[src_node]):
                dst_strate_index = merge_map[i]
                for j in range(self.node_lens[child_node]):
                    edge_cost[(i, j)] = old_edge_cost[(dst_strate_index, j)]
            self.edge_costs[new_node_pair] = edge_cost

        # connect src node and children of dst node
        dst_node.parents.remove(src_node)
        src_node.children.remove(dst_node)
        self.edge_costs.pop((src_node, dst_node))
        for child_node in dst_node.children:
            if child_node not in src_node.children:
                src_node.children.append(child_node)
            if src_node not in child_node.parents:
                child_node.parents.append(src_node)
            # remove dst node from cost graph when dst node has no producer.
            if len(dst_node.parents) == 0:
                child_node.parents.remove(dst_node)
                self.edge_costs.pop((dst_node, child_node))
        if len(dst_node.parents) == 0:
            self.following_dict[dst_node] = src_node
            dst_node.children = []

    def _reindexing_src(self, src):
        '''
        Follow following_dict from src until reaching a node that does not follow any other node.
        '''
        if src not in self.following_dict:
            return src
        return self._reindexing_src(self.following_dict[src])

    def simplify_graph(self):
        '''
        Merge every pair collected in merge_pair, last collected pair first, then rewrite following_dict
        so that every merged node points directly at the node that finally remains in the graph.
        Does nothing when the graph was built with simplify=False.
        '''
        if not self.simplify:
            return
        # reversed() leaves merge_pair in its original order even if a merge raises midway.
        for (src_node, dst_node) in reversed(self.merge_pair):
            self.merge_node(src_node, dst_node)
        self.following_dict = {dst: self._reindexing_src(src) for dst, src in self.following_dict.items()}
|
||||||
|
|
|
@ -9,9 +9,16 @@ from .unary_elementwise_handler import UnaryElementwiseHandler
|
||||||
from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
|
from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
|
||||||
from .layer_norm_handler_v2 import LayerNormModuleHandler
|
from .layer_norm_handler_v2 import LayerNormModuleHandler
|
||||||
from .batch_norm_handler_v2 import BatchNormModuleHandler
|
from .batch_norm_handler_v2 import BatchNormModuleHandler
|
||||||
|
from .conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler
|
||||||
|
from .where_handler_v2 import WhereHandler
|
||||||
|
from .unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2
|
||||||
|
from .reshape_handler_v2 import ReshapeHandler_V2
|
||||||
|
from .placeholder_handler import PlacehodlerHandler
|
||||||
|
from .output_handler import OuputHandler
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
|
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
|
||||||
'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler',
|
'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler',
|
||||||
'LayerNormModuleHandler', 'BatchNormModuleHandler'
|
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
|
||||||
|
'UnaryElementwiseHandler_V2', 'ReshapeHandler_V2', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler'
|
||||||
]
|
]
|
||||||
|
|
|
@ -40,7 +40,7 @@ class ConvModuleHandler(ModuleHandler):
|
||||||
|
|
||||||
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
||||||
|
|
||||||
if self.named_parameters['bias'] is not None:
|
if "bias" in self.named_parameters:
|
||||||
physical_bias_operand = OperationData(name="bias",
|
physical_bias_operand = OperationData(name="bias",
|
||||||
type=OperationDataType.PARAM,
|
type=OperationDataType.PARAM,
|
||||||
data=self.named_parameters['bias'])
|
data=self.named_parameters['bias'])
|
||||||
|
@ -53,7 +53,6 @@ class ConvModuleHandler(ModuleHandler):
|
||||||
"""
|
"""
|
||||||
for op_data, sharding_spec in strategy.input_sharding_specs.items():
|
for op_data, sharding_spec in strategy.input_sharding_specs.items():
|
||||||
if op_data.name == "weight":
|
if op_data.name == "weight":
|
||||||
assert op_data.logical_shape != op_data.data.shape
|
|
||||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||||
|
|
||||||
# switch first and second dim of the conv module weight
|
# switch first and second dim of the conv module weight
|
||||||
|
|
|
@ -6,12 +6,13 @@ from typing import List, Dict
|
||||||
from .registry import operator_registry
|
from .registry import operator_registry
|
||||||
import operator
|
import operator
|
||||||
|
|
||||||
__all__ = ['ReshapeHandler']
|
__all__ = ['ReshapeHandler_V2']
|
||||||
|
|
||||||
|
|
||||||
@operator_registry.register(torch.reshape)
|
@operator_registry.register(torch.reshape)
|
||||||
|
@operator_registry.register(torch.flatten)
|
||||||
@operator_registry.register(torch.Tensor.permute)
|
@operator_registry.register(torch.Tensor.permute)
|
||||||
class ReshapeHandler(NodeHandler):
|
class ReshapeHandler_V2(NodeHandler):
|
||||||
"""
|
"""
|
||||||
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
|
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -6,12 +6,12 @@ from typing import List, Dict
|
||||||
from .registry import operator_registry
|
from .registry import operator_registry
|
||||||
import operator
|
import operator
|
||||||
|
|
||||||
__all__ = ['UnaryElementwiseHandler']
|
__all__ = ['UnaryElementwiseHandler_V2']
|
||||||
|
|
||||||
|
|
||||||
@operator_registry.register(torch.abs)
|
@operator_registry.register(torch.abs)
|
||||||
@operator_registry.register(torch.nn.ReLU)
|
@operator_registry.register(torch.nn.ReLU)
|
||||||
class UnaryElementwiseHandler(NodeHandler):
|
class UnaryElementwiseHandler_V2(NodeHandler):
|
||||||
"""
|
"""
|
||||||
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
|
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -465,3 +465,464 @@ class Solver:
|
||||||
ret_list.append(ret)
|
ret_list.append(ret)
|
||||||
|
|
||||||
return ret_list
|
return ret_list
|
||||||
|
|
||||||
|
|
||||||
|
class Solver_V2:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
graph: Graph,
|
||||||
|
strategies_constructor: StrategiesConstructor,
|
||||||
|
cost_graph: CostGraph,
|
||||||
|
graph_analyser: GraphAnalyser,
|
||||||
|
memory_budget: float = -1.0,
|
||||||
|
solution_numbers: int = 1,
|
||||||
|
forward_only: bool = False,
|
||||||
|
memory_increasing_coefficient: float = 1.3):
|
||||||
|
'''
|
||||||
|
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
|
||||||
|
|
||||||
|
Argument:
|
||||||
|
graph: The computing graph to be optimized.
|
||||||
|
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
|
||||||
|
cost_graph: A graph data structure to simplify the edge cost graph.
|
||||||
|
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
|
||||||
|
memory_budget: Memory constraint for the solution.
|
||||||
|
solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
|
||||||
|
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
|
||||||
|
'''
|
||||||
|
self.graph = graph
|
||||||
|
self.strategies_constructor = strategies_constructor
|
||||||
|
self.cost_graph = cost_graph
|
||||||
|
self.graph_analyser = graph_analyser
|
||||||
|
self.leaf_strategies = self.strategies_constructor.leaf_strategies
|
||||||
|
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
|
||||||
|
self.strategy_map = self.strategies_constructor.strategy_map
|
||||||
|
self.memory_budget = memory_budget
|
||||||
|
self.solution_numbers = solution_numbers
|
||||||
|
self.forward_only = forward_only
|
||||||
|
if self.solution_numbers > 1:
|
||||||
|
self.memory_increasing_coefficient = memory_increasing_coefficient
|
||||||
|
else:
|
||||||
|
self.memory_increasing_coefficient = 1
|
||||||
|
self.liveness_list = self.graph_analyser.liveness_analysis()
|
||||||
|
self.node_index_dict = self._generate_node_index_dict()
|
||||||
|
# The last solution vector of auto sharding.
|
||||||
|
self.last_s_val = None
|
||||||
|
# The last objective value of the best ILP solution.
|
||||||
|
self.last_objective = None
|
||||||
|
|
||||||
|
def _recover_merged_node_strategy(self):
|
||||||
|
'''
|
||||||
|
During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
|
||||||
|
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
|
||||||
|
node.
|
||||||
|
'''
|
||||||
|
for node_index, node in enumerate(self.nodes):
|
||||||
|
if node.strategies_vector.check_merge():
|
||||||
|
# the merged node has only one input, and its strategies follow the input sharding strategy
|
||||||
|
input_strategies_vector = node.args[0].strategies_vector
|
||||||
|
input_best_strategy_index = self.last_s_val[node_index - 1]
|
||||||
|
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
|
||||||
|
for strategy_index, strategy in enumerate(node.strategies_vector):
|
||||||
|
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
|
||||||
|
self.last_s_val[node_index] = strategy_index
|
||||||
|
break
|
||||||
|
|
||||||
|
def _generate_node_index_dict(self) -> Dict[Node, int]:
|
||||||
|
node_index_dict = {}
|
||||||
|
for index, strategies_vector in enumerate(self.leaf_strategies):
|
||||||
|
node_index_dict[strategies_vector.node] = index
|
||||||
|
return node_index_dict
|
||||||
|
|
||||||
|
def _prepare_data_for_solver(self):
|
||||||
|
'''
|
||||||
|
Extract information from components for solver.
|
||||||
|
'''
|
||||||
|
node_nums = len(self.leaf_strategies)
|
||||||
|
memory_budget = self.memory_budget
|
||||||
|
|
||||||
|
# prepare strategies_len
|
||||||
|
strategies_len = []
|
||||||
|
for node in self.nodes:
|
||||||
|
strategies_len.append(self.cost_graph.node_lens[node])
|
||||||
|
strategies_len = np.array(strategies_len)
|
||||||
|
|
||||||
|
# prepare following_nodes
|
||||||
|
following_nodes = self.cost_graph.following_dict
|
||||||
|
index_following_nodes = {}
|
||||||
|
for src, target in following_nodes.items():
|
||||||
|
src_index = self.node_index_dict[src]
|
||||||
|
target_index = self.node_index_dict[target]
|
||||||
|
index_following_nodes[src_index] = target_index
|
||||||
|
following_nodes = index_following_nodes
|
||||||
|
for index in range(node_nums):
|
||||||
|
if index not in following_nodes:
|
||||||
|
following_nodes[index] = -1
|
||||||
|
|
||||||
|
# prepare edge_pairs and resharding costs
|
||||||
|
edge_pairs = []
|
||||||
|
resharding_costs = []
|
||||||
|
for pairs, edge_cost in self.cost_graph.edge_costs.items():
|
||||||
|
src_node = pairs[0]
|
||||||
|
dst_node = pairs[1]
|
||||||
|
src_node_index = self.node_index_dict[src_node]
|
||||||
|
dst_node_index = self.node_index_dict[dst_node]
|
||||||
|
edge_pairs.append(src_node_index)
|
||||||
|
edge_pairs.append(dst_node_index)
|
||||||
|
|
||||||
|
for i in range(strategies_len[src_node_index]):
|
||||||
|
for j in range(strategies_len[dst_node_index]):
|
||||||
|
resharding_costs.append(edge_cost[(i, j)])
|
||||||
|
edge_pairs = np.array(edge_pairs)
|
||||||
|
resharding_costs = np.array(resharding_costs)
|
||||||
|
|
||||||
|
# prepare liveness_set
|
||||||
|
liveness_set = self.liveness_list
|
||||||
|
|
||||||
|
# omit alias_set now
|
||||||
|
alias_set = None
|
||||||
|
alias_convert_costs = None
|
||||||
|
|
||||||
|
# prepare compute_costs, communication_costs and memory_costs
|
||||||
|
compute_costs = []
|
||||||
|
communication_costs = []
|
||||||
|
memory_costs = []
|
||||||
|
extra_node_costs = self.cost_graph.extra_node_costs
|
||||||
|
for strategies_vector in self.leaf_strategies:
|
||||||
|
node = strategies_vector.node
|
||||||
|
for index, strategy in enumerate(strategies_vector):
|
||||||
|
compute_cost_item = strategy.compute_cost
|
||||||
|
communication_cost_item = strategy.communication_cost
|
||||||
|
memory_cost_item = strategy.memory_cost
|
||||||
|
|
||||||
|
if self.forward_only:
|
||||||
|
origin_communication_cost = communication_cost_item.fwd
|
||||||
|
compute_cost = compute_cost_item.fwd
|
||||||
|
memory_cost = memory_cost_item.fwd
|
||||||
|
else:
|
||||||
|
origin_communication_cost = communication_cost_item.total
|
||||||
|
compute_cost = compute_cost_item.total
|
||||||
|
memory_cost = memory_cost_item.total
|
||||||
|
|
||||||
|
compute_costs.append(compute_cost)
|
||||||
|
# node in extra_node_costs means it has some extra communication
|
||||||
|
# cost from node merging, so we need to add those extra communication
|
||||||
|
# cost into
|
||||||
|
if node in extra_node_costs:
|
||||||
|
extra_node_cost = extra_node_costs[node][index]
|
||||||
|
communication_cost = origin_communication_cost + extra_node_cost
|
||||||
|
communication_costs.append(communication_cost)
|
||||||
|
else:
|
||||||
|
communication_costs.append(origin_communication_cost)
|
||||||
|
memory_costs.append(memory_cost)
|
||||||
|
# if isinstance(memory_cost, tuple):
|
||||||
|
# memory_costs.append(memory_cost[0])
|
||||||
|
# else:
|
||||||
|
# memory_costs.append(memory_cost)
|
||||||
|
compute_costs = np.array(compute_costs)
|
||||||
|
communication_costs = np.array(communication_costs)
|
||||||
|
memory_costs = np.array(memory_costs)
|
||||||
|
|
||||||
|
# omit initial value for nodes
|
||||||
|
s_init_np = None
|
||||||
|
|
||||||
|
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
|
||||||
|
|
||||||
|
def _call_solver_serialized_args(self,
|
||||||
|
node_nums,
|
||||||
|
memory_budget,
|
||||||
|
strategies_len,
|
||||||
|
following_nodes,
|
||||||
|
edge_pairs,
|
||||||
|
alias_set,
|
||||||
|
liveness_set,
|
||||||
|
compute_costs,
|
||||||
|
communication_costs,
|
||||||
|
memory_costs,
|
||||||
|
resharding_costs,
|
||||||
|
alias_convert_costs,
|
||||||
|
s_init_np=None):
|
||||||
|
"""
|
||||||
|
Call the solver with serialized arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
|
||||||
|
assert isinstance(x, np.ndarray)
|
||||||
|
assert len(strategies_len) == node_nums, "strategies_len"
|
||||||
|
|
||||||
|
def get_non_zero_index(binary_vector):
|
||||||
|
"""
|
||||||
|
Get the index of non-zero item in a vector.
|
||||||
|
"""
|
||||||
|
ct = 0
|
||||||
|
ret = None
|
||||||
|
for i, elem in enumerate(binary_vector):
|
||||||
|
if pulp.value(elem):
|
||||||
|
ret = i
|
||||||
|
ct += 1
|
||||||
|
|
||||||
|
assert ct == 1
|
||||||
|
return ret
|
||||||
|
|
||||||
|
# 0. Unpack flatten numpy arrays
|
||||||
|
s_follow = following_nodes
|
||||||
|
|
||||||
|
E = edge_pairs.reshape((-1, 2)) # noqa
|
||||||
|
r = []
|
||||||
|
pt = 0
|
||||||
|
edge_set = set()
|
||||||
|
for (i, j) in E:
|
||||||
|
prod_length = strategies_len[i] * strategies_len[j]
|
||||||
|
|
||||||
|
if (i, j) in edge_set:
|
||||||
|
raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||||
|
|
||||||
|
edge_set.add((i, j))
|
||||||
|
r.append(resharding_costs[pt:pt + prod_length])
|
||||||
|
pt += prod_length
|
||||||
|
assert pt == len(resharding_costs)
|
||||||
|
|
||||||
|
######################
|
||||||
|
# omit alias set now #
|
||||||
|
######################
|
||||||
|
|
||||||
|
# A = alias_set.reshape((-1, 2)) # noqa
|
||||||
|
# for (i, j) in A:
|
||||||
|
# prod_length = strategies_len[i] * strategies_len[j]
|
||||||
|
# v.append(alias_convert_costs[pt:pt + prod_length])
|
||||||
|
# pt += prod_length
|
||||||
|
# assert pt == len(alias_convert_costs)
|
||||||
|
|
||||||
|
# L = [] # noqa
|
||||||
|
# pt = node_nums
|
||||||
|
# for i in range(node_nums):
|
||||||
|
# length = liveness_set[i]
|
||||||
|
# L.append(liveness_set[pt:pt + length])
|
||||||
|
# pt += length
|
||||||
|
# assert pt == len(liveness_set)
|
||||||
|
v = []
|
||||||
|
pt = 0
|
||||||
|
|
||||||
|
c = []
|
||||||
|
d = []
|
||||||
|
m = []
|
||||||
|
pt = 0
|
||||||
|
for i in range(node_nums):
|
||||||
|
length = strategies_len[i]
|
||||||
|
c.append(compute_costs[pt:pt + length])
|
||||||
|
d.append(communication_costs[pt:pt + length])
|
||||||
|
m.append(memory_costs[pt:pt + length])
|
||||||
|
pt += length
|
||||||
|
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
|
||||||
|
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
|
||||||
|
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
|
||||||
|
|
||||||
|
# 1. Create variables
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# create variables for node #
|
||||||
|
#############################
|
||||||
|
s = []
|
||||||
|
num_nodes = 0
|
||||||
|
reverse_follow_backpatch = []
|
||||||
|
for i in range(node_nums):
|
||||||
|
if s_follow[i] < 0:
|
||||||
|
if strategies_len[i] == 1:
|
||||||
|
s.append([1])
|
||||||
|
else:
|
||||||
|
num_nodes += 1
|
||||||
|
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
|
||||||
|
else:
|
||||||
|
if s_follow[i] < len(s):
|
||||||
|
s.append(s[s_follow[i]])
|
||||||
|
else:
|
||||||
|
s.append(None)
|
||||||
|
reverse_follow_backpatch.append(i)
|
||||||
|
|
||||||
|
for i in reverse_follow_backpatch:
|
||||||
|
s[i] = s[s_follow[i]]
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# create variables for edge #
|
||||||
|
#############################
|
||||||
|
e = []
|
||||||
|
num_edges = 0
|
||||||
|
for (idx, (i, j)) in enumerate(E):
|
||||||
|
if len(s[i]) == 1:
|
||||||
|
e.append(s[j])
|
||||||
|
elif len(s[j]) == 1:
|
||||||
|
e.append(s[i])
|
||||||
|
else:
|
||||||
|
num_edges += 1
|
||||||
|
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
|
||||||
|
assert len(e[idx]) == len(r[idx])
|
||||||
|
for element in s:
|
||||||
|
assert len(element) > 0
|
||||||
|
# 2. Set initial value
|
||||||
|
######################################
|
||||||
|
# set a initial value for warm start #
|
||||||
|
######################################
|
||||||
|
if s_init_np is not None:
|
||||||
|
s_init = s_init_np.reshape((-1, 3))
|
||||||
|
for (idx, value, fix) in s_init:
|
||||||
|
for i in range(len(s[idx])):
|
||||||
|
s[idx][i].setInitialValue(i == value)
|
||||||
|
if fix:
|
||||||
|
s[idx][i].fixValue()
|
||||||
|
|
||||||
|
# 3. Objective
|
||||||
|
prob = LpProblem("myProblem", LpMinimize)
|
||||||
|
###################################################################
|
||||||
|
# computing the node cost(computing cost and communication cost) #
|
||||||
|
###################################################################
|
||||||
|
obj = 0
|
||||||
|
for i in range(node_nums):
|
||||||
|
assert len(s[i]) == len(c[i])
|
||||||
|
assert len(s[i]) == len(d[i])
|
||||||
|
|
||||||
|
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
|
||||||
|
|
||||||
|
#############################################
|
||||||
|
# computing the edge cost(resharding cost) #
|
||||||
|
#############################################
|
||||||
|
for i in range(len(E)):
|
||||||
|
assert len(e[i]) == len(r[i])
|
||||||
|
obj += lpDot(e[i], r[i])
|
||||||
|
|
||||||
|
prob += obj
|
||||||
|
|
||||||
|
# 4. Constraints
|
||||||
|
# (a). specified by `cat="Binary"`
|
||||||
|
|
||||||
|
# (b)
|
||||||
|
#################################################
|
||||||
|
# make sure each node only choose one strategy #
|
||||||
|
#################################################
|
||||||
|
for i in range(node_nums):
|
||||||
|
if s_follow[i] < 0:
|
||||||
|
prob += lpSum(s[i]) == 1
|
||||||
|
|
||||||
|
# (c)
|
||||||
|
#################################################
|
||||||
|
# compute memory consumption with liveness set #
|
||||||
|
#################################################
|
||||||
|
if memory_budget > 0:
|
||||||
|
for liveness_stage in liveness_set:
|
||||||
|
mem = 0
|
||||||
|
for live_variable in liveness_stage.unique_live_vars:
|
||||||
|
node_index = self.node_index_dict[live_variable.node]
|
||||||
|
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||||
|
prob += mem <= memory_budget
|
||||||
|
|
||||||
|
# (d). specified by `cat="Binary"`
|
||||||
|
|
||||||
|
for (idx, (i, j)) in enumerate(E):
|
||||||
|
if strategies_len[i] == 1 or strategies_len[j] == 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# (e)
|
||||||
|
prob += lpSum(e[idx]) == 1
|
||||||
|
|
||||||
|
# (f)
|
||||||
|
for row in range(len(s[i])):
|
||||||
|
C = len(s[j]) # noqa
|
||||||
|
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
|
||||||
|
|
||||||
|
# (g)
|
||||||
|
for col in range(len(s[j])):
|
||||||
|
R = len(s[i]) # noqa
|
||||||
|
C = len(s[j]) # noqa
|
||||||
|
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
|
||||||
|
|
||||||
|
# (h)
|
||||||
|
######################
|
||||||
|
# omit alias set now #
|
||||||
|
######################
|
||||||
|
|
||||||
|
# alias_set = set()
|
||||||
|
# for (idx, (i, j)) in enumerate(A):
|
||||||
|
# R = len(s[i]) # noqa
|
||||||
|
# C = len(s[j]) # noqa
|
||||||
|
# if (i, j) in alias_set:
|
||||||
|
# raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||||
|
|
||||||
|
# alias_set.add((i, j))
|
||||||
|
# alias_set.add((j, i))
|
||||||
|
|
||||||
|
# for row in range(len(s[i])):
|
||||||
|
# for col in range(len(s[j])):
|
||||||
|
# if v[idx][row * C + col] > 0.5:
|
||||||
|
# prob += s[i][row] + s[j][col] <= 1
|
||||||
|
|
||||||
|
verbose = True
|
||||||
|
|
||||||
|
msg = verbose
|
||||||
|
time_limit = 600
|
||||||
|
assert "COIN_CMD" in pulp.listSolvers(
|
||||||
|
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
|
||||||
|
|
||||||
|
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
|
||||||
|
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
|
||||||
|
prob.solve(solver)
|
||||||
|
|
||||||
|
status = prob.status
|
||||||
|
objective = pulp.value(prob.objective)
|
||||||
|
objective = float(objective) if objective is not None else -1.0
|
||||||
|
if verbose:
|
||||||
|
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
|
||||||
|
f"Time: {time.time() - tic}")
|
||||||
|
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
|
||||||
|
|
||||||
|
if prob.status in [pulp.LpStatusInfeasible]:
|
||||||
|
raise RuntimeError("Cannot run the function under the given memory budget. "
|
||||||
|
"Please increase the memory budget.")
|
||||||
|
|
||||||
|
# Get and check results
|
||||||
|
s_val = np.full((node_nums,), -1, dtype=np.int32)
|
||||||
|
for i in range(node_nums):
|
||||||
|
s_val[i] = get_non_zero_index(s[i])
|
||||||
|
|
||||||
|
e_val = np.full((len(E),), -1, dtype=np.int32)
|
||||||
|
for (idx, (i, j)) in enumerate(E):
|
||||||
|
e_val[idx] = get_non_zero_index(e[idx])
|
||||||
|
i_spec_index = e_val[idx] // len(s[j])
|
||||||
|
j_spec_index = e_val[idx] % len(s[j])
|
||||||
|
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
|
||||||
|
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
|
||||||
|
if verbose and r[idx][e_val[idx]] > 0:
|
||||||
|
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
||||||
|
|
||||||
|
self.last_s_val = list(s_val)
|
||||||
|
# self._recover_merged_node_strategy()
|
||||||
|
self.last_objective = objective
|
||||||
|
|
||||||
|
if objective > INFINITY_COST:
|
||||||
|
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
|
||||||
|
|
||||||
|
return self.last_s_val, e_val, self.last_objective, status
|
||||||
|
|
||||||
|
def call_solver_serialized_args(self):
    """
    Call the solver with serialized arguments, once or for a series of memory budgets.

    If ``self.solution_numbers`` is 1, the solver runs a single time with the data
    produced by ``self._prepare_data_for_solver()`` and its result is returned as is.

    Otherwise ``self.solution_numbers`` runs are made. For the i-th run (starting at 0),
    ``self.memory_budget`` is set to
    ``original_budget * self.memory_increasing_coefficient ** i`` before the solver data
    is prepared. The original budget is put back when the series finishes or when a run
    raises, so repeated calls start from the same budget.

    Returns:
        The return value of ``self._call_solver_serialized_args`` for a single solution,
        or a list of those return values, in budget order, for several solutions.
    """
    if self.solution_numbers == 1:
        args = self._prepare_data_for_solver()
        return self._call_solver_serialized_args(*args)

    origin_memory_budget = self.memory_budget
    memory_budget_list = [
        origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
    ]
    ret_list = []
    try:
        for memory_budget in memory_budget_list:
            # The budget is handed to the solver through the attribute because
            # _prepare_data_for_solver() takes no arguments.
            self.memory_budget = memory_budget
            args = self._prepare_data_for_solver()
            ret_list.append(self._call_solver_serialized_args(*args))
    finally:
        # Do not leak the scaled budget: without this, the solver would keep the
        # last (largest) budget and a second call would scale it again.
        self.memory_budget = origin_memory_budget

    return ret_list
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
from torch.fx import Graph, Node
|
from torch.fx import Graph, Node
|
||||||
from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
|
from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
|
||||||
from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
|
from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
|
||||||
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy_V2
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
|
from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
|
||||||
|
from colossalai.auto_parallel.solver.op_handler.placeholder_handler import PlacehodlerHandler
|
||||||
|
from colossalai.auto_parallel.solver.op_handler.output_handler import OuputHandler
|
||||||
from .options import SolverOptions
|
from .options import SolverOptions
|
||||||
from . import ShardingStrategy, StrategiesVector
|
from . import ShardingStrategy, StrategiesVector
|
||||||
from .op_handler import *
|
from .op_handler import *
|
||||||
|
@ -414,7 +417,6 @@ class StrategiesConstructor:
|
||||||
self.leaf_strategies.append(strategies_vector)
|
self.leaf_strategies.append(strategies_vector)
|
||||||
self.strategy_map[node] = strategies_vector
|
self.strategy_map[node] = strategies_vector
|
||||||
|
|
||||||
|
|
||||||
# remove no strategy nodes
|
# remove no strategy nodes
|
||||||
remove_list = []
|
remove_list = []
|
||||||
for strategies_vector in self.leaf_strategies:
|
for strategies_vector in self.leaf_strategies:
|
||||||
|
@ -456,6 +458,10 @@ class StrategiesConstructor_V2:
|
||||||
name_checklist = []
|
name_checklist = []
|
||||||
remove_list = []
|
remove_list = []
|
||||||
for strategy in strategies_vector:
|
for strategy in strategies_vector:
|
||||||
|
if strategy is None:
|
||||||
|
print(strategies_vector.node.name)
|
||||||
|
print(strategies_vector)
|
||||||
|
assert False
|
||||||
if strategy.name not in name_checklist:
|
if strategy.name not in name_checklist:
|
||||||
name_checklist.append(strategy.name)
|
name_checklist.append(strategy.name)
|
||||||
else:
|
else:
|
||||||
|
@ -469,16 +475,32 @@ class StrategiesConstructor_V2:
|
||||||
"""
|
"""
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
strategies_vector = StrategiesVector(node)
|
strategies_vector = StrategiesVector(node)
|
||||||
|
|
||||||
# placeholder node
|
# placeholder node
|
||||||
if node.op == 'placeholder':
|
if node.op == 'placeholder':
|
||||||
# TODO: implement placeholder node handler
|
placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector)
|
||||||
pass
|
placeholder_handler.register_strategy()
|
||||||
|
|
||||||
# get_attr node
|
# get_attr node
|
||||||
elif node.op == 'get_attr':
|
if node.op == 'get_attr':
|
||||||
# TODO: implement getattr node handler
|
# Same as placeholder nodes, if solver_options.fast is True, we just let them in
|
||||||
pass
|
# fully replicate status, then strategies of following node will be treated equally due
|
||||||
|
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||||
|
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
|
||||||
|
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
|
||||||
|
if self.solver_options.fast:
|
||||||
|
# create sharding strategy for get_attr
|
||||||
|
name = 'Replica Attribute'
|
||||||
|
dim_partition_dict = {}
|
||||||
|
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
|
||||||
|
# TODO: use meta_info_prop to profile memory cost
|
||||||
|
memory_cost = 0
|
||||||
|
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
|
||||||
|
strategies_vector.append(sharding_strategy_attribute)
|
||||||
|
|
||||||
|
# # get_attr node
|
||||||
|
# elif node.op == 'get_attr':
|
||||||
|
# # TODO: implement getattr node handler
|
||||||
|
# pass
|
||||||
|
|
||||||
# call_module node
|
# call_module node
|
||||||
elif node.op == 'call_module':
|
elif node.op == 'call_module':
|
||||||
|
@ -502,11 +524,13 @@ class StrategiesConstructor_V2:
|
||||||
|
|
||||||
# output node
|
# output node
|
||||||
elif node.op == 'output':
|
elif node.op == 'output':
|
||||||
# TODO: implement output node handler
|
output_handler = OuputHandler(node, self.device_mesh, strategies_vector)
|
||||||
pass
|
output_handler.register_strategy()
|
||||||
|
|
||||||
|
if len(strategies_vector) <= 0:
|
||||||
|
print(node.name)
|
||||||
|
assert len(strategies_vector) > 0
|
||||||
self.remove_duplicated_strategy(strategies_vector)
|
self.remove_duplicated_strategy(strategies_vector)
|
||||||
setattr(node, 'strategies_vector', strategies_vector)
|
setattr(node, 'strategies_vector', strategies_vector)
|
||||||
self.leaf_strategies.append(strategies_vector)
|
self.leaf_strategies.append(strategies_vector)
|
||||||
self.strategy_map[node] = strategies_vector
|
self.strategy_map[node] = strategies_vector
|
||||||
|
|
||||||
|
|
|
@ -8,10 +8,14 @@ from .layer_norm_generator import LayerNormGenerator
|
||||||
from .where_generator import WhereGenerator
|
from .where_generator import WhereGenerator
|
||||||
from .reshape_generator import ReshapeGenerator
|
from .reshape_generator import ReshapeGenerator
|
||||||
from .normal_pooling_generator import NormalPoolStrategyGenerator
|
from .normal_pooling_generator import NormalPoolStrategyGenerator
|
||||||
|
from .placeholder_generator import PlaceholderGenerator
|
||||||
|
from .output_generator import OutputGenerator
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
|
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
|
||||||
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
|
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
|
||||||
'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
|
'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
|
||||||
'TensorTupleStrategyGenerator', 'LayerNormGenerator', "WhereGenerator", 'ReshapeGenerator', 'NormalPoolStrategyGenerator'
|
'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator',
|
||||||
|
'WhereGenerator', 'ReshapeGenerator', 'NormalPoolStrategyGenerator'
|
||||||
]
|
]
|
||||||
|
|
|
@ -5,6 +5,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||||
from .strategy_generator import StrategyGenerator_V2
|
from .strategy_generator import StrategyGenerator_V2
|
||||||
from typing import List
|
from typing import List
|
||||||
from .._utils import exception_handler
|
from .._utils import exception_handler
|
||||||
|
import warnings
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,6 +101,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||||
strategy.memory_cost = memory_cost
|
strategy.memory_cost = memory_cost
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
|
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||||
|
|
||||||
|
@ -146,6 +148,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_input_batch(self, mesh_dim_0):
|
def split_input_batch(self, mesh_dim_0):
|
||||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
|
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
|
||||||
|
|
||||||
|
@ -182,6 +185,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
|
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||||
|
|
||||||
|
@ -228,6 +232,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
|
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
|
||||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||||
|
|
||||||
|
@ -267,6 +272,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
|
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
|
||||||
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
|
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
|
||||||
|
|
||||||
|
@ -297,6 +303,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_weight_out_channel(self, mesh_dim_0):
|
def split_weight_out_channel(self, mesh_dim_0):
|
||||||
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
|
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
|
||||||
|
|
||||||
|
@ -329,6 +336,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def non_split(self):
|
def non_split(self):
|
||||||
name = f'RR = RR x RR'
|
name = f'RR = RR x RR'
|
||||||
|
|
||||||
|
@ -347,6 +355,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping={})
|
communication_action_mapping={})
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
|
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
|
||||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||||
|
|
||||||
|
@ -384,6 +393,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
|
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||||
dim_partition_dict_mapping = {
|
dim_partition_dict_mapping = {
|
||||||
|
@ -413,6 +423,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@exception_handler
|
||||||
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
|
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
||||||
dim_partition_dict_mapping = {
|
dim_partition_dict_mapping = {
|
||||||
|
@ -482,10 +493,20 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
|
||||||
# RS01 = RR x RS01
|
# RS01 = RR x RS01
|
||||||
strategies.append(self.split_1d_parallel_on_out_channel(0, 1))
|
strategies.append(self.split_1d_parallel_on_out_channel(0, 1))
|
||||||
|
|
||||||
|
rm_list = [strategy for strategy in strategies if strategy is None]
|
||||||
|
for rm_element in rm_list:
|
||||||
|
strategies.remove(rm_element)
|
||||||
|
illegal_strategy_list = []
|
||||||
# update mete info on cost
|
# update mete info on cost
|
||||||
for strategy in strategies:
|
for strategy in strategies:
|
||||||
self.update_communication_cost(strategy)
|
try:
|
||||||
self.update_compute_cost(strategy)
|
self.update_communication_cost(strategy)
|
||||||
self.update_memory_cost(strategy)
|
self.update_compute_cost(strategy)
|
||||||
|
self.update_memory_cost(strategy)
|
||||||
|
except AssertionError as e:
|
||||||
|
illegal_strategy_list.append(strategy)
|
||||||
|
warnings.warn(f'{e}')
|
||||||
|
for strategy in illegal_strategy_list:
|
||||||
|
strategies.remove(strategy)
|
||||||
|
|
||||||
return strategies
|
return strategies
|
||||||
|
|
|
@ -5,8 +5,10 @@ from colossalai.fx import ColoTracer, ColoGraphModule
|
||||||
from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler
|
from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler
|
||||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("for higher testing speed")
|
||||||
def test_norm_pool_handler():
|
def test_norm_pool_handler():
|
||||||
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
|
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||||
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
|
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
|
||||||
from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler
|
from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler_V2
|
||||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
|
@ -48,9 +48,9 @@ def test_reshape_handler():
|
||||||
strategies_vector=conv_strategies_vector)
|
strategies_vector=conv_strategies_vector)
|
||||||
conv_handler.register_strategy(compute_resharding_cost=False)
|
conv_handler.register_strategy(compute_resharding_cost=False)
|
||||||
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
|
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
|
||||||
reshape_handler = ReshapeHandler(node=reshape_node,
|
reshape_handler = ReshapeHandler_V2(node=reshape_node,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
strategies_vector=reshape_strategies_vector)
|
strategies_vector=reshape_strategies_vector)
|
||||||
|
|
||||||
reshape_handler.register_strategy(compute_resharding_cost=False)
|
reshape_handler.register_strategy(compute_resharding_cost=False)
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||||
from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler
|
from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2
|
||||||
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
|
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
|
||||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
@ -50,9 +50,9 @@ def test_elementwise_handler():
|
||||||
strategies_vector=conv_strategies_vector)
|
strategies_vector=conv_strategies_vector)
|
||||||
conv_handler.register_strategy(compute_resharding_cost=False)
|
conv_handler.register_strategy(compute_resharding_cost=False)
|
||||||
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
|
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
|
||||||
relu_handler = UnaryElementwiseHandler(node=relu_mod_node,
|
relu_handler = UnaryElementwiseHandler_V2(node=relu_mod_node,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
strategies_vector=relu_strategies_vector)
|
strategies_vector=relu_strategies_vector)
|
||||||
|
|
||||||
relu_handler.register_strategy(compute_resharding_cost=False)
|
relu_handler.register_strategy(compute_resharding_cost=False)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
import torch
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
import torch.nn as nn
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor_V2
|
||||||
|
from colossalai.auto_parallel.solver.cost_graph import CostGraph_V2
|
||||||
|
from copy import deepcopy
|
||||||
|
from colossalai.auto_parallel.solver.solver import Solver_V2
|
||||||
|
from torchvision.models import resnet34, resnet50
|
||||||
|
from colossalai.auto_parallel.solver.constants import *
|
||||||
|
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||||
|
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("for higher testing speed")
def test_cost_graph():
    """
    Build strategies, the cost graph and the solver for ResNet-50, then print the costs.

    The test traces ResNet-50 on meta tensors, constructs the V2 strategies and cost
    graph, runs the solver and prints the summed compute, communication and memory
    costs of the selected strategies. It makes no assertions; it only checks that the
    pipeline runs through.
    """
    # Eight device ids arranged as a logical mesh of shape (2, 4).
    physical_mesh_id = torch.arange(0, 8)
    mesh_shape = (2, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = resnet50(num_classes=100000)
    input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')}

    # The traced graph is the usual ResNet chain:
    #   placeholder x -> conv1 -> bn1 -> relu -> maxpool
    #   -> residual stages (conv/bn/relu module calls joined by operator.add)
    #   -> avgpool -> torch.flatten -> fc -> output
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()

    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor_V2(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph_V2(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()

    solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    print(ret[0])
    print(solver.last_s_val)
    strategies_list = solver.last_s_val

    computation_cost = 0
    communication_cost = 0
    communication_cost_bn = 0
    memory_cost = 0
    for index, node in enumerate(graph.nodes):
        # Strategy the solver picked for this node.
        chosen = node.strategies_vector[strategies_list[index]]

        # Communication caused by batch-norm modules is additionally tracked on its own.
        if node.op == 'call_module':
            submod = node.graph.owning_module.get_submodule(node.target)
            if type(submod) in BATCHNORM_MODULE_OP:
                communication_cost_bn += chosen.communication_cost.total

        print(node.name, chosen.name)
        computation_cost += chosen.compute_cost.total
        communication_cost += chosen.communication_cost.total

        # The total memory cost may come as a tuple; only its first entry is counted.
        total_memory = chosen.memory_cost.total
        node_memory_cost = total_memory[0] if isinstance(total_memory, tuple) else total_memory
        memory_cost += node_memory_cost.activation + node_memory_cost.parameter

    print(f'computation cost is {computation_cost}')
    print(f'communication cost is {communication_cost}')
    print(f'memory cost is {memory_cost}')
    print(f'bn communication cost is {communication_cost_bn}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_cost_graph()
|
Loading…
Reference in New Issue