[autoparallel] remove no strategy nodes (#1652)

* [autoparallel] remove no strategy nodes

* fix none object iteration issue
pull/1669/head
YuliangLiu0306 2022-09-29 10:43:25 +08:00 committed by GitHub
parent 50f16a2850
commit c27e701cb2
4 changed files with 84 additions and 36 deletions

View File

@@ -1,6 +1,7 @@
 from typing import List
 import math
 from torch.fx.node import Node
+from colossalai.auto_parallel.solver.constants import INFINITY_COST
 
 
 class CostGraph:
@@ -19,6 +20,7 @@ class CostGraph:
     def __init__(self, leaf_strategies, simplify=True):
         self.leaf_strategies = leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
         # stores number of strategies in each node
         self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
         # extra_node_costs will store the extra costs introduced by merging nodes
@@ -27,6 +29,15 @@ class CostGraph:
         self.simplify = simplify
         self._build_cost_graph()
 
+    def _remove_invalid_node(self, node, attr_name):
+        remove_list = []
+        target_node_list = getattr(node, attr_name, [])
+        for target_node in target_node_list:
+            if target_node not in self.nodes:
+                remove_list.append(target_node)
+        for element in remove_list:
+            target_node_list.remove(element)
+
     def _build_cost_graph(self):
         '''
         This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
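A note on the pattern used in the new _remove_invalid_node helper (and again in the strategies-constructor cleanup further down): candidates are first collected into remove_list and only removed afterwards, because deleting from a list while iterating over it makes the iterator skip the neighbour of each removed item. A minimal standalone illustration of that pitfall, with toy data that is not part of the diff:

    nodes = ['a', 'b', 'c', 'd']
    for n in nodes:
        if n in ('b', 'c'):
            nodes.remove(n)              # buggy: the list shifts underneath the iterator
    print(nodes)                         # ['a', 'c', 'd'] -- 'c' was never visited

    nodes = ['a', 'b', 'c', 'd']
    remove_list = [n for n in nodes if n in ('b', 'c')]
    for n in remove_list:                # safe: iterate over a separate list
        nodes.remove(n)
    print(nodes)                         # ['a', 'd']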
@@ -39,6 +50,8 @@ class CostGraph:
             # build edge_cost
             dst_node = strategies_vector.node
             for src_node in strategies_vector.predecessor_nodes:
+                if src_node not in self.nodes:
+                    continue
                 node_pair = (src_node, dst_node)
                 # src_index = strategies_vector.predecessor_nodes.index(src_node)
                 edge_cost = {}
@@ -49,6 +62,8 @@ class CostGraph:
             # add parents and children attribute to node
             setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
             setattr(dst_node, 'children', strategies_vector.successor_nodes)
+            self._remove_invalid_node(dst_node, 'parents')
+            self._remove_invalid_node(dst_node, 'children')
 
             if self.simplify and strategies_vector.check_merge():
                 for followed_node in strategies_vector.predecessor_nodes:
@@ -83,11 +98,11 @@ class CostGraph:
         # build merge_map
         merge_map = {}
         for src_index, strategy in enumerate(src_node.strategies_vector):
-            min_cost = math.inf
+            min_cost = INFINITY_COST
             lowest_cost_index = -1
             for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
                 resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
-                if resharding_cost < min_cost:
+                if resharding_cost <= min_cost:
                     min_cost = resharding_cost
                     lowest_cost_index = dst_index
             merge_map[src_index] = lowest_cost_index
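Why the comparison changes from < to <= here: min_cost now starts at INFINITY_COST, so a strict comparison would leave lowest_cost_index at -1 whenever every resharding cost is itself INFINITY_COST, and the merge map would then point at an invalid strategy. A small standalone sketch of that edge case (the INFINITY_COST value below is assumed for illustration; the real constant comes from colossalai.auto_parallel.solver.constants):

    INFINITY_COST = 1e13                      # assumed value, for illustration only
    costs = [INFINITY_COST, INFINITY_COST]    # every candidate strategy is "infinitely" expensive

    min_cost, lowest_cost_index = INFINITY_COST, -1
    for i, cost in enumerate(costs):
        if cost <= min_cost:                  # with '<' this branch would never fire
            min_cost, lowest_cost_index = cost, i
    print(lowest_cost_index)                  # 1: a valid index even in the all-infinite case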

View File

@@ -182,6 +182,8 @@ class StrategiesVector(list):
         # fetch its input and output nodes
         # TODO: placeholder input nodes
         self.predecessor_nodes = list(node._input_nodes.keys())
+        if self.node.op == 'output':
+            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
         self.successor_nodes = list(node.users.keys())
 
     def check_merge(self):
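The two added lines truncate the predecessor list of the fx 'output' node to its first entry. For context, an output node can have several input nodes when the traced module returns a tuple; a standalone example of that behaviour (relies only on torch.fx, the module below is made up for illustration):

    import torch
    from torch.fx import symbolic_trace

    class TwoReturns(torch.nn.Module):
        def forward(self, x):
            return x + 1, x - 1          # the traced graph's output node gets two inputs

    gm = symbolic_trace(TwoReturns())
    output_node = list(gm.graph.nodes)[-1]
    print(output_node.op)                            # 'output'
    print(list(output_node._input_nodes.keys()))     # two predecessor nodes; [:1] keeps only the first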

View File

@@ -45,8 +45,8 @@ class Solver:
         self.strategies_constructor = strategies_constructor
         self.cost_graph = cost_graph
         self.graph_analyser = graph_analyser
-        self.nodes = list(self.graph.nodes)
         self.leaf_strategies = self.strategies_constructor.leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
         self.strategy_map = self.strategies_constructor.strategy_map
         self.memory_budget = memory_budget
         self.solution_numbers = solution_numbers
@@ -67,7 +67,7 @@ class Solver:
         Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
         node.
         '''
-        for node_index, node in enumerate(self.graph.nodes):
+        for node_index, node in enumerate(self.nodes):
             if node.strategies_vector.check_merge():
                 # the merged node has only one input, and its strategies follow the input sharding strategy
                 input_strategies_vector = node.args[0].strategies_vector
@@ -297,7 +297,8 @@ class Solver:
                 num_edges += 1
                 e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
                 assert len(e[idx]) == len(r[idx])
+        for element in s:
+            assert len(element) > 0
 
         # 2. Set initial value
         ######################################
         # set a initial value for warm start #
@@ -317,12 +318,14 @@ class Solver:
         ###################################################################
         obj = 0
         for i in range(node_nums):
+            assert len(s[i]) == len(c[i])
             obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
 
         #############################################
         # computing the edge cost(resharding cost) #
         #############################################
         for i in range(len(E)):
+            assert len(e[i]) == len(r[i])
             obj += lpDot(e[i], r[i])
 
         prob += obj
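The new asserts guard the objective construction: pulp's lpDot pairs its two argument lists element-wise, so a length mismatch between the strategy variables s[i] and the cost vector c[i] (or e[i] and r[i]) would silently drop terms instead of raising an error. A minimal sketch of the pattern on a toy problem, assuming the pulp package that the solver already uses:

    from pulp import LpMinimize, LpProblem, LpVariable, lpDot

    prob = LpProblem("toy", LpMinimize)
    s = LpVariable.matrix("s", (range(3),), cat="Binary")   # one binary variable per strategy
    c = [1.0, 2.0, 3.0]                                     # per-strategy cost

    assert len(s) == len(c)            # the commit adds checks like this before each lpDot
    prob += lpDot(s, c)                # objective: total cost of the chosen strategies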

View File

@@ -214,6 +214,21 @@ class StrategiesConstructor:
                     linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
                     linear_handler.register_strategy()
 
+                # where function
+                elif target == torch.where:
+                    if input_nodes_len == 1:
+                        # both of x and y are scalar
+                        pass
+
+                    elif input_nodes_len == 2:
+                        # one of x or y is type of scalar
+                        pass
+
+                    else:
+                        # general case
+                        where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
+                        where_handler.register_strategy()
+
                 # reshape function
                 elif target in RESHAPE_FUNC_OP:
                     # use ReshapeHandler to create sharding strategies for rehsape node
@@ -222,7 +237,6 @@ class StrategiesConstructor:
                 # element-wise function
                 elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
-                    if isinstance(node._meta_data, torch.Tensor):
-                        unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
-                        unary_elementwise_handler.register_strategy()
+                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
+                    unary_elementwise_handler.register_strategy()
@@ -291,8 +305,10 @@ class StrategiesConstructor:
                elif target == operator.getitem:
                    index = node.args[1]
                    input_tensor_node = strategies_vector.predecessor_nodes[0]
-                    if isinstance(input_tensor_node, torch.Tensor):
                    for strategy in input_tensor_node.strategies_vector:
-                        input_sharding_spec = strategy.output_sharding_spec[index]
+                        if isinstance(strategy.output_sharding_spec, ShardingSpec):
+                            input_sharding_spec = strategy.output_sharding_spec
+                        else:
+                            input_sharding_spec = strategy.output_sharding_spec[index]
                        assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
                        dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
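The new isinstance branch covers both shapes that a predecessor strategy's output_sharding_spec can take: a single ShardingSpec for a single-output op, or an indexable collection with one spec per output for a multi-output op. A self-contained illustration of the same branching, with FakeSpec standing in for ShardingSpec:

    class FakeSpec:                      # stand-in for ShardingSpec, for illustration only
        pass

    def pick_input_spec(output_sharding_spec, index):
        if isinstance(output_sharding_spec, FakeSpec):
            return output_sharding_spec              # predecessor has a single output
        return output_sharding_spec[index]           # multi-output predecessor: index into it

    single = FakeSpec()
    multi = [FakeSpec(), FakeSpec()]
    assert pick_input_spec(single, 0) is single
    assert pick_input_spec(multi, 1) is multi[1]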
@@ -304,18 +320,18 @@ class StrategiesConstructor:
                        compute_cost = 0
                        memory_cost = 0
                        resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                     [input_sharding_spec])
+                                                                     [input_sharding_spec],
+                                                                     index=index)
                        # to prevent the resharding happening, set their resharding cost to inf.
                        resharding_costs[input_tensor_node] = [
-                            cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
+                            cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
                        ]
-                        sharding_strategy = ShardingStrategy(
-                            name,
-                            output_sharding_spec,
-                            compute_cost=compute_cost,
-                            memory_cost=memory_cost,
-                            resharding_costs=resharding_costs,
-                            input_shardings=[input_tensor_node.output_sharding_spec])
+                        sharding_strategy = ShardingStrategy(name,
+                                                             output_sharding_spec,
+                                                             compute_cost=compute_cost,
+                                                             memory_cost=memory_cost,
+                                                             resharding_costs=resharding_costs,
+                                                             input_shardings=[strategy.output_sharding_spec])
                        strategies_vector.append(sharding_strategy)
 
                # torch.arange function
@@ -334,8 +350,7 @@ class StrategiesConstructor:
                    strategies_vector.append(sharding_strategy)
 
                # op list to be processed to support gpt2
-                elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
-                                torch.nn.functional.softmax, torch.pow, torch.tanh):
+                elif target in (builtins.getattr, operator.le, torch.addmm):
                    pass
                # other function
                else:
@@ -344,7 +359,7 @@ class StrategiesConstructor:
            # call_method node
            if node.op == 'call_method':
                method = getattr(node.args[0]._meta_data.__class__, node.target)
-                if method in (torch.Tensor.size, torch.Tensor.contiguous):
+                if method in (torch.Tensor.size,):
                    pass
                elif method in ELEMENTWISE_METHOD_OP:
                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
@@ -400,6 +415,18 @@ class StrategiesConstructor:
            self.strategy_map[node] = strategies_vector
 
+        # remove no strategy nodes
+        remove_list = []
+        for strategies_vector in self.leaf_strategies:
+            if len(strategies_vector) == 0:
+                remove_list.append(strategies_vector.node)
+
+        for node in remove_list:
+            if node.strategies_vector in self.leaf_strategies:
+                self.leaf_strategies.remove(node.strategies_vector)
+            if node in self.strategy_map:
+                self.strategy_map.pop(node)
+
 
class StrategiesConstructor_V2:
    """
    StrategiesConstructor is used to construct the parallelization plan for the model execution.
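What the cleanup above removes: nodes whose handlers were skipped (for example the pass branches for torch.Tensor.size calls or the ops still parked in the gpt2 op list) end up with an empty StrategiesVector, and keeping them would hand the solver a zero-length decision vector, which the new length asserts in Solver are there to catch. A toy sketch of the filtering effect, with DummyVector standing in for StrategiesVector:

    class DummyVector(list):             # stand-in for StrategiesVector
        def __init__(self, node):
            super().__init__()
            self.node = node

    leaf_strategies = [DummyVector('size_node'), DummyVector('matmul_node')]
    leaf_strategies[1].append('S0R = S0R x RR')      # only matmul_node registered a strategy

    remove_list = [sv.node for sv in leaf_strategies if len(sv) == 0]
    leaf_strategies = [sv for sv in leaf_strategies if sv.node not in remove_list]
    assert [sv.node for sv in leaf_strategies] == ['matmul_node']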
@@ -482,3 +509,4 @@ class StrategiesConstructor_V2:
                setattr(node, 'strategies_vector', strategies_vector)
                self.leaf_strategies.append(strategies_vector)
                self.strategy_map[node] = strategies_vector