[autoparallel] adapt solver and CostGraph with new handler (#1695)

* [autoparallel] adapt solver and CostGraph with new handler

* fix test issue
pull/1696/head
YuliangLiu0306 2022-10-13 14:04:15 +08:00 committed by GitHub
parent 42b882ef06
commit 81f7530ee7
14 changed files with 834 additions and 30 deletions

View File

@ -95,7 +95,8 @@ def exception_handler(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
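This hunk makes the wrapper propagate the wrapped function's return value; previously the result was discarded, so a decorated strategy method always returned None even on success. For reference, a minimal sketch of the decorator after the change (the final `return wrapper` sits outside the hunk context and is assumed here):

import functools
import warnings

def exception_handler(func):
    # convert an AssertionError raised inside a strategy generator into a
    # warning; on failure the wrapper returns None so callers can drop the strategy
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            rst = func(*args, **kwargs)
            return rst
        except AssertionError as e:
            warnings.warn(f'{e}')
    return wrapper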

View File

@ -170,3 +170,188 @@ class CostGraph:
for dst, src in self.following_dict.items():
reindexing_following_dict[dst] = self._reindexing_src(src)
self.following_dict = reindexing_following_dict
class CostGraph_V2:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the search space, we merge computationally trivial operators, such as
element-wise operators, transpose, and reduction, into the nodes they follow, so a merged node simply
reuses its predecessor's strategy choice. Whether a node can be merged is given by
StrategiesVector.check_merge(), which depends on the type of the target node and its neighbours.
Argument:
leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node in the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is True. (defaults to True)
'''
def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
self.extra_node_costs = {}
self.following_dict = {}
self.simplify = simplify
self.forward_only = forward_only
self._build_cost_graph()
def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, [])
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
This method generates edge_cost for each adjacent node pair. Additionally, 'parents' and 'children'
attributes are set on each node.
'''
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
for strategies_vector in self.leaf_strategies:
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
if strategies_vector[i].resharding_costs is None:
print(strategies_vector.node.name)
assert False
resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
if self.forward_only:
edge_cost[(j, i)] = resharding_cost_item.fwd
else:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
self._remove_invalid_node(dst_node, 'parents')
self._remove_invalid_node(dst_node, 'children')
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
'''
To merge dst_node into src_node, we take the following steps:
1. For each strategy of src_node, pick an appropriate dst_node strategy to pair
it with. This matters because the resharding costs between the children of
dst_node and the merged node depend on which dst_node strategy each src_node
strategy is dispatched to. For example, for the graph 0->1->2, after merging node 1
into node 0, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 1, node 2)][(x, 0)],
where x is the node 1 strategy paired with strategy 0 of node 0.
2. Accumulate the extra costs introduced by merging: one part is the resharding
cost between each src_node strategy and its paired dst_node strategy, the other
is the extra cost already accumulated on that dst_node strategy.
3. Build connections between the new node pairs, and remove dst_node once all
of its consumer nodes are detached from it.
Argument:
src_node(Node): The node that dst_node will be merged into.
dst_node(Node): The node to be merged into src_node.
'''
src_node_index = dst_node.parents.index(src_node)
# build merge_map
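# merge_map[src_index] records which dst_node strategy each src_node strategy
# is paired with once dst_node is merged away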
merge_map = {}
for src_index, strategy in enumerate(src_node.strategies_vector):
min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost_item = dst_strategy.resharding_costs[src_node][src_index]
if self.forward_only:
resharding_cost = resharding_cost_item.fwd
else:
resharding_cost = resharding_cost_item.total
if resharding_cost <= min_cost:
min_cost = resharding_cost
lowest_cost_index = dst_index
merge_map[src_index] = lowest_cost_index
# extra_node_cost for src node
self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
for src_index, strategy in enumerate(src_node.strategies_vector):
target_strate_index = merge_map[src_index]
target_strategy = dst_node.strategies_vector[target_strate_index]
resharding_cost_item = target_strategy.resharding_costs[src_node][src_index]
if self.forward_only:
resharding_cost_to_add = resharding_cost_item.fwd
else:
resharding_cost_to_add = resharding_cost_item.total
self.extra_node_costs[src_node][src_index] += resharding_cost_to_add
if dst_node in self.extra_node_costs:
self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]
# add new node pair to cost graph
for child_node in dst_node.children:
new_node_pair = (src_node, child_node)
old_node_pair = (dst_node, child_node)
if new_node_pair in self.edge_costs:
continue
edge_cost = {}
for i in range(self.node_lens[src_node]):
for j in range(self.node_lens[child_node]):
dst_strate_index = merge_map[i]
# dst_strategy = dst_node.strategies_vector[dst_strate_index]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
if new_node_pair not in self.edge_costs:
self.edge_costs[new_node_pair] = edge_cost
else:
# we should accumulate the resharding costs if the args of the child node
# contain both src_node and dst_node.
for index_pair, resharding_cost in self.edge_costs[new_node_pair].items():
self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]
# connect src node and children of dst node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
self.edge_costs.pop((src_node, dst_node))
for child_node in dst_node.children:
if child_node not in src_node.children:
src_node.children.append(child_node)
if src_node not in child_node.parents:
child_node.parents.append(src_node)
# remove dst node from cost graph when dst node has no producer.
if len(dst_node.parents) == 0:
child_node.parents.remove(dst_node)
node_pair = (dst_node, child_node)
self.edge_costs.pop(node_pair)
if len(dst_node.parents) == 0:
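# dst_node is fully detached: record that it now follows src_node so that
# the solver can reuse src_node's strategy variable for it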
self.following_dict[dst_node] = src_node
dst_node.children = []
def _reindexing_src(self, src):
if src not in self.following_dict:
return src
return self._reindexing_src(self.following_dict[src])
def simplify_graph(self):
if not self.simplify:
return
self.merge_pair.reverse()
for (src_node, dst_node) in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
for dst, src in self.following_dict.items():
reindexing_following_dict[dst] = self._reindexing_src(src)
self.following_dict = reindexing_following_dict
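To make the (src, dst) indexing of edge_cost concrete, here is a hand-built toy example; the numbers are hypothetical and only illustrate how the quadratic cost table for one src->dst pair is keyed:

# toy illustration only: 2 strategies on src, 2 on dst;
# resharding[j][i] = cost of resharding src strategy j to fit dst strategy i
resharding = [[0.0, 4.0],
              [4.0, 0.0]]
edge_cost = {}
for i in range(2):        # dst strategy index
    for j in range(2):    # src strategy index
        edge_cost[(j, i)] = resharding[j][i]
# edge_cost == {(0, 0): 0.0, (1, 0): 4.0, (0, 1): 4.0, (1, 1): 0.0};
# the solver later flattens such dicts into the 1D resharding_costs array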

View File

@ -9,9 +9,16 @@ from .unary_elementwise_handler import UnaryElementwiseHandler
from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
from .layer_norm_handler_v2 import LayerNormModuleHandler
from .batch_norm_handler_v2 import BatchNormModuleHandler
from .conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler
from .where_handler_v2 import WhereHandler
from .unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2
from .reshape_handler_v2 import ReshapeHandler_V2
from .placeholder_handler import PlacehodlerHandler
from .output_handler import OuputHandler
__all__ = [
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler'
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler_V2', 'ReshapeHandler_V2', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler'
]

View File

@ -40,7 +40,7 @@ class ConvModuleHandler(ModuleHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.named_parameters['bias'] is not None:
if "bias" in self.named_parameters:
physical_bias_operand = OperationData(name="bias",
type=OperationDataType.PARAM,
data=self.named_parameters['bias'])
@ -53,7 +53,6 @@ class ConvModuleHandler(ModuleHandler):
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == "weight":
assert op_data.logical_shape != op_data.data.shape
dim_partition_dict = sharding_spec.dim_partition_dict
# switch first and second dim of the conv module weight

View File

@ -6,12 +6,13 @@ from typing import List, Dict
from .registry import operator_registry
import operator
__all__ = ['ReshapeHandler']
__all__ = ['ReshapeHandler_V2']
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.permute)
class ReshapeHandler(NodeHandler):
class ReshapeHandler_V2(NodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""

View File

@ -6,12 +6,12 @@ from typing import List, Dict
from .registry import operator_registry
import operator
__all__ = ['UnaryElementwiseHandler']
__all__ = ['UnaryElementwiseHandler_V2']
@operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU)
class UnaryElementwiseHandler(NodeHandler):
class UnaryElementwiseHandler_V2(NodeHandler):
"""
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
"""

View File

@ -465,3 +465,464 @@ class Solver:
ret_list.append(ret)
return ret_list
class Solver_V2:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3):
'''
The Solver class integrates the information provided by the other components and uses an ILP solver to find an optimal strategy combination for the target computation graph.
Argument:
graph: The computation graph to be optimized.
strategies_constructor: It provides all the possible strategies for each node in the computation graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser analyses the graph to obtain variable liveness information, which is used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will produce a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, this coefficient is used to generate new memory budgets.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
self.forward_only = forward_only
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
def _recover_merged_node_strategy(self):
'''
During cost graph construction, some nodes, such as unary element-wise nodes or reshape ops, are merged into
the previous node, so their strategy indices simply mirror the previous node's choice. This method recovers
the actual strategy index of each merged node.
'''
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
input_best_strategy_index = self.last_s_val[node_index - 1]
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
for strategy_index, strategy in enumerate(node.strategies_vector):
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
self.last_s_val[node_index] = strategy_index
break
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
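# -1 marks a node that follows no other node and keeps its own strategy variable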
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = None
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_cost_item = strategy.compute_cost
communication_cost_item = strategy.communication_cost
memory_cost_item = strategy.memory_cost
if self.forward_only:
origin_communication_cost = communication_cost_item.fwd
compute_cost = compute_cost_item.fwd
memory_cost = memory_cost_item.fwd
else:
origin_communication_cost = communication_cost_item.total
compute_cost = compute_cost_item.total
memory_cost = memory_cost_item.total
compute_costs.append(compute_cost)
# a node present in extra_node_costs carries extra communication cost
# introduced by node merging, so we add that extra cost to the node's
# communication cost
if node in extra_node_costs:
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(origin_communication_cost)
memory_costs.append(memory_cost)
# if isinstance(memory_cost, tuple):
# memory_costs.append(memory_cost[0])
# else:
# memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flattened numpy arrays
s_follow = following_nodes
E = edge_pairs.reshape((-1, 2)) # noqa
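# edge_pairs was flattened as [src0, dst0, src1, dst1, ...], so E[k] gives the
# (src_index, dst_index) of the k-th edge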
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
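# a node may follow a node that appears later in the ordering; such entries are
# left as None here and patched in the second pass below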
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
for element in s:
assert len(element) > 0
# 2. Set initial value
######################################
# set an initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# compute the node cost (compute cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# compute the edge cost (resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node chooses only one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
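# (e), (f) and (g) together force e[idx] to act as the outer product of s[i]
# and s[j]: its single non-zero entry sits at (chosen row, chosen col)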
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
verbose = True
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = list(s_val)
# self._recover_merged_node_strategy()
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return self.last_s_val, e_val, self.last_objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle python errors. Additionally,
it can produce a series of solutions under different memory budgets.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
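The ILP encoding above is easier to see on a toy instance. The sketch below assumes nothing beyond the pulp package the solver already depends on, and uses made-up costs for a two-node graph with one edge: one-hot vectors s[i] select a strategy per node, e linearizes the edge's strategy pair, and constraints (b), (e), (f) and (g) tie them together:

import pulp
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

# toy instance: two nodes with 2 strategies each and a single edge between them
node_costs = [[1.0, 3.0], [2.0, 1.0]]    # per-strategy compute + communication cost
edge_costs = [0.0, 5.0, 5.0, 0.0]        # flattened 2x2 resharding cost (row-major)

prob = LpProblem("toy_autoparallel", LpMinimize)
s = [LpVariable.matrix(f"s[{i}]", (range(2),), cat="Binary") for i in range(2)]
e = LpVariable.matrix("e[0,1]", (range(4),), cat="Binary")

# objective: node costs plus the edge (resharding) cost
prob += lpDot(s[0], node_costs[0]) + lpDot(s[1], node_costs[1]) + lpDot(e, edge_costs)

# (b) each node picks exactly one strategy; (e) the edge picks exactly one pair
prob += lpSum(s[0]) == 1
prob += lpSum(s[1]) == 1
prob += lpSum(e) == 1

# (f)/(g) couple the edge variable to the node variables
R, C = 2, 2
for row in range(R):
    prob += lpSum(e[row * C + col] for col in range(C)) <= s[0][row]
for col in range(C):
    prob += lpSum(e[row * C + col] for row in range(R)) <= s[1][col]

prob.solve(pulp.PULP_CBC_CMD(msg=False))
# optimum picks strategy 0 on both nodes: cost 1 + 2 + 0 = 3, no resharding
print([int(pulp.value(v)) for vec in s for v in vec])    # -> [1, 0, 1, 0]

Scaling this to node_nums nodes, following_dict variable sharing, and the liveness-based memory constraint is essentially what _call_solver_serialized_args does above.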

View File

@ -1,10 +1,13 @@
from torch.fx import Graph, Node
from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy_V2
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
from colossalai.auto_parallel.solver.op_handler.placeholder_handler import PlacehodlerHandler
from colossalai.auto_parallel.solver.op_handler.output_handler import OuputHandler
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
@ -414,7 +417,6 @@ class StrategiesConstructor:
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
@ -456,6 +458,10 @@ class StrategiesConstructor_V2:
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy is None:
print(strategies_vector.node.name)
print(strategies_vector)
assert False
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
@ -469,16 +475,32 @@ class StrategiesConstructor_V2:
"""
for node in self.nodes:
strategies_vector = StrategiesVector(node)
# placeholder node
if node.op == 'placeholder':
# TODO: implement placeholder node handler
pass
placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector)
placeholder_handler.register_strategy()
# get_attr node
elif node.op == 'get_attr':
# TODO: implement getattr node handler
pass
if node.op == 'get_attr':
# As with placeholder nodes, if solver_options.fast is True, we simply keep
# get_attr nodes fully replicated; the strategies of following nodes are then
# treated equally, because the replicated status has no resharding cost to any
# other status. This also keeps the search space smaller than enumerating all
# the possible sharding specs for the get_attr node. Otherwise, every possible
# sharding spec for the get_attr node will be enumerated.
if self.solver_options.fast:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_attribute)
# # get_attr node
# elif node.op == 'get_attr':
# # TODO: implement getattr node handler
# pass
# call_module node
elif node.op == 'call_module':
@ -502,11 +524,13 @@ class StrategiesConstructor_V2:
# output node
elif node.op == 'output':
# TODO: implement output node handler
pass
output_handler = OuputHandler(node, self.device_mesh, strategies_vector)
output_handler.register_strategy()
if len(strategies_vector) <= 0:
print(node.name)
assert len(strategies_vector) > 0
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector

View File

@ -8,10 +8,14 @@ from .layer_norm_generator import LayerNormGenerator
from .where_generator import WhereGenerator
from .reshape_generator import ReshapeGenerator
from .normal_pooling_generator import NormalPoolStrategyGenerator
from .placeholder_generator import PlaceholderGenerator
from .output_generator import OutputGenerator
__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
'TensorTupleStrategyGenerator', 'LayerNormGenerator', "WhereGenerator", 'ReshapeGenerator', 'NormalPoolStrategyGenerator'
'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator',
'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator'
]

View File

@ -5,6 +5,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
from .._utils import exception_handler
import warnings
import copy
@ -100,6 +101,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@exception_handler
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@ -146,6 +148,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
@ -182,6 +185,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@ -228,6 +232,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@ -267,6 +272,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
@ -297,6 +303,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
@ -329,6 +336,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
def non_split(self):
name = f'RR = RR x RR'
@ -347,6 +355,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
@exception_handler
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
@ -384,6 +393,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_mapping = {
@ -413,6 +423,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = {
@ -482,10 +493,20 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
# RS01 = RR x RS01
strategies.append(self.split_1d_parallel_on_out_channel(0, 1))
rm_list = [strategy for strategy in strategies if strategy is None]
for rm_element in rm_list:
strategies.remove(rm_element)
illegal_strategy_list = []
# update meta info on cost
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
try:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
except AssertionError as e:
illegal_strategy_list.append(strategy)
warnings.warn(f'{e}')
for strategy in illegal_strategy_list:
strategies.remove(strategy)
return strategies

View File

@ -5,8 +5,10 @@ from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
import pytest
@pytest.mark.skip("for higher testing speed")
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer()

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler
from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler_V2
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
@ -48,9 +48,9 @@ def test_reshape_handler():
strategies_vector=conv_strategies_vector)
conv_handler.register_strategy(compute_resharding_cost=False)
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
reshape_handler = ReshapeHandler(node=reshape_node,
device_mesh=device_mesh,
strategies_vector=reshape_strategies_vector)
reshape_handler = ReshapeHandler_V2(node=reshape_node,
device_mesh=device_mesh,
strategies_vector=reshape_strategies_vector)
reshape_handler.register_strategy(compute_resharding_cost=False)

View File

@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler
from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
@ -50,9 +50,9 @@ def test_elementwise_handler():
strategies_vector=conv_strategies_vector)
conv_handler.register_strategy(compute_resharding_cost=False)
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
relu_handler = UnaryElementwiseHandler(node=relu_mod_node,
device_mesh=device_mesh,
strategies_vector=relu_strategies_vector)
relu_handler = UnaryElementwiseHandler_V2(node=relu_mod_node,
device_mesh=device_mesh,
strategies_vector=relu_strategies_vector)
relu_handler.register_strategy(compute_resharding_cost=False)

View File

@ -0,0 +1,99 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor_V2
from colossalai.auto_parallel.solver.cost_graph import CostGraph_V2
from copy import deepcopy
from colossalai.auto_parallel.solver.solver import Solver_V2
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions
@pytest.mark.skip("for higher testing speed")
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
# [[0, 1, 2, 3],
#  [4, 5, 6, 7]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = resnet50(num_classes=100000)
input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')}
graph = tracer.trace(root=model, meta_args=input_sample)
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
# %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
# %maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
# %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
# %layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
# %layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
# %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
# %layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
# %layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
# %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
# %layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
# %layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
# %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
# %layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
# %add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
# ...
# %avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_2_relu_1,), kwargs = {})
# %flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
# %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
# return fc
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor_V2(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph_V2(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
print(ret[0])
print(solver.last_s_val)
strategies_list = solver.last_s_val
computation_cost = 0
communication_cost = 0
communication_cost_bn = 0
memory_cost = 0
for index, node in enumerate(graph.nodes):
if node.op == 'call_module':
submod = node.graph.owning_module.get_submodule(node.target)
if type(submod) in BATCHNORM_MODULE_OP:
communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost.total
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost.activation + node_memory_cost.parameter
print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
print(f'bn communication cost is {communication_cost_bn}')
if __name__ == '__main__':
test_cost_graph()