[autoparallel] remove no strategy nodes (#1652)

* [autoparallel] remove no strategy nodes

* fix None object iteration issue
YuliangLiu0306 2022-09-29 10:43:25 +08:00 committed by GitHub
parent 50f16a2850
commit c27e701cb2
4 changed files with 84 additions and 36 deletions

View File

@@ -1,6 +1,7 @@
from typing import List
import math
from torch.fx.node import Node
from colossalai.auto_parallel.solver.constants import INFINITY_COST
class CostGraph:
@@ -19,6 +20,7 @@ class CostGraph:
def __init__(self, leaf_strategies, simplify=True):
self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
@@ -27,6 +29,15 @@
self.simplify = simplify
self._build_cost_graph()
def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, [])
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)
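
Stale references are possible here because pruning (in StrategiesConstructor, below) can drop nodes that still appear in another node's 'parents' or 'children' list. A minimal self-contained sketch of the same two-phase filtering, with a hypothetical DummyNode standing in for torch.fx nodes:

class DummyNode:
    def __init__(self, name):
        self.name = name

a, b, c = DummyNode('a'), DummyNode('b'), DummyNode('c')
valid_nodes = [a, b]          # c carries no strategy, so it was pruned
a.children = [b, c]

remove_list = [n for n in a.children if n not in valid_nodes]
for n in remove_list:         # collect first, then remove: never mutate
    a.children.remove(n)      # a list while iterating over it
assert a.children == [b]
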
def _build_cost_graph(self):
'''
This method will generate the edge_cost for each adjacent node pair. Additionally, the 'parents' and 'children' attributes will be
@@ -39,6 +50,8 @@
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
@@ -49,6 +62,8 @@
# add 'parents' and 'children' attributes to the node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
self._remove_invalid_node(dst_node, 'parents')
self._remove_invalid_node(dst_node, 'children')
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
@@ -83,11 +98,11 @@
# build merge_map
merge_map = {}
for src_index, strategy in enumerate(src_node.strategies_vector):
min_cost = math.inf
min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
if resharding_cost < min_cost:
if resharding_cost <= min_cost:
min_cost = resharding_cost
lowest_cost_index = dst_index
merge_map[src_index] = lowest_cost_index
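
With math.inf as the starting bound and a strict '<', a source strategy whose resharding costs are all infinite would leave lowest_cost_index at -1; the finite INFINITY_COST together with '<=' guarantees a valid destination index is always chosen. A minimal sketch of the selection, assuming resharding_costs is a plain list of per-destination cost lists and using a large finite constant in place of colossalai's INFINITY_COST:

INFINITY_COST = 1e13    # stand-in value, assumed for this sketch

def build_merge_map(resharding_costs):
    merge_map = {}
    for src_index, costs in enumerate(resharding_costs):
        min_cost, lowest_cost_index = INFINITY_COST, -1
        for dst_index, cost in enumerate(costs):
            # '<=' still matches when every cost equals INFINITY_COST
            if cost <= min_cost:
                min_cost, lowest_cost_index = cost, dst_index
        merge_map[src_index] = lowest_cost_index
    return merge_map

# even an all-infinite row yields a usable destination index
assert build_merge_map([[INFINITY_COST, INFINITY_COST]]) == {0: 1}
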

View File

@@ -182,6 +182,8 @@ class StrategiesVector(list):
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
if self.node.op == 'output':
self.predecessor_nodes = list(node._input_nodes.keys())[:1]
self.successor_nodes = list(node.users.keys())
def check_merge(self):
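
For an 'output' node, _input_nodes collects every returned value, so the slice [:1] deliberately keeps only the first predecessor (the commit message's None-iteration fix). A quick torch.fx probe, using the same private _input_nodes attribute, showing where the extra predecessors come from:

import torch
from torch.fx import symbolic_trace

class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        return x + 1, x - 1

gm = symbolic_trace(TwoOutputs())
output_node = list(gm.graph.nodes)[-1]            # the graph's single output node
print(list(output_node._input_nodes.keys()))      # [add, sub]
print(list(output_node._input_nodes.keys())[:1])  # [add] -- what the patch keeps
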

View File

@@ -45,8 +45,8 @@ class Solver:
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.nodes = list(self.graph.nodes)
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
@@ -67,7 +67,7 @@ class Solver:
Therefore, the indices of those strategies are copied from the previous node. This method is used to recover the strategy indices of those merged
nodes.
'''
for node_index, node in enumerate(self.graph.nodes):
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
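
A minimal sketch of that recovery with hypothetical names: merged nodes own no decision variables, so after solving, the chosen index is copied over from each merged node's single input:

solution = {"conv1": 3}                 # strategy index picked by the solver
merged_follows = {"relu1": "conv1"}     # merged node -> its only input
for merged, source in merged_follows.items():
    solution[merged] = solution[source]
assert solution["relu1"] == 3
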
@@ -297,7 +297,8 @@ class Solver:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
for element in s:
assert len(element) > 0
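
The new assertion catches nodes whose strategies_vector is empty: such a node would produce an empty variable vector, and its objective term would silently vanish instead of failing loudly. A small sketch of that failure mode, using pulp (the LP library already in use here) and hypothetical per-node strategy counts:

from pulp import LpVariable

strategy_counts = [3, 0, 2]     # the middle node kept no strategies
s = [LpVariable.matrix(f"s[{i}]", (range(n),), cat="Binary")
     for i, n in enumerate(strategy_counts)]
try:
    for element in s:
        assert len(element) > 0
except AssertionError:
    print("a node with no strategies reached the solver")
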
# 2. Set initial value
######################################
# set an initial value for warm start #
@@ -317,12 +318,14 @@
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost (resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])
prob += obj
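
A self-contained sketch of the objective assembled above, with made-up sizes and costs: one lpDot term per node for compute (c) and memory (d) costs, and one per edge for resharding costs (r); every name and value here is hypothetical:

from pulp import LpMinimize, LpProblem, LpVariable, lpDot

prob = LpProblem("objective_sketch", LpMinimize)
s = [LpVariable.matrix(f"s[{i}]", (range(2),), cat="Binary") for i in range(2)]
c = [[1.0, 4.0], [2.0, 3.0]]    # per-strategy compute costs
d = [[0.5, 0.5], [1.0, 1.0]]    # per-strategy memory costs
e = [LpVariable.matrix("e[0,1]", (range(4),), cat="Binary")]
r = [[0.0, 5.0, 5.0, 0.0]]      # flattened pairwise resharding costs
obj = 0
for i in range(2):
    assert len(s[i]) == len(c[i])    # the same sanity check as above
    obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
for i in range(len(e)):
    assert len(e[i]) == len(r[i])
    obj += lpDot(e[i], r[i])
prob += obj                          # minimize total cost
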

View File

@@ -214,6 +214,21 @@ class StrategiesConstructor:
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
linear_handler.register_strategy()
# where function
elif target == torch.where:
if input_nodes_len == 1:
# both x and y are scalars
pass
elif input_nodes_len == 2:
# one of x and y is a scalar
pass
else:
# general case
where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
where_handler.register_strategy()
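
The three branches are distinguished by how many of torch.where's arguments arrive as fx input nodes (i.e. tensors), since Python scalars never become nodes. A quick illustration (scalar x/y arguments require a reasonably recent PyTorch):

import torch

cond = torch.tensor([True, False])
t = torch.tensor([1.0, 2.0])
torch.where(cond, 0.0, 1.0)    # 1 input node: both x and y are scalars
torch.where(cond, t, 1.0)      # 2 input nodes: one of x and y is a scalar
torch.where(cond, t, t - 1)    # 3 input nodes: general case -> WhereHandler
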
# reshape function
elif target in RESHAPE_FUNC_OP:
# use ReshapeHandler to create sharding strategies for the reshape node
@@ -222,9 +237,8 @@ class StrategiesConstructor:
# element-wise function
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
if isinstance(node._meta_data, torch.Tensor):
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# bcast op
elif target in BCAST_FUNC_OP:
@@ -291,32 +305,34 @@ class StrategiesConstructor:
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
if isinstance(input_tensor_node, torch.Tensor):
for strategy in input_tensor_node.strategies_vector:
for strategy in input_tensor_node.strategies_vector:
if isinstance(strategy.output_sharding_spec, ShardingSpec):
input_sharding_spec = strategy.output_sharding_spec
else:
input_sharding_spec = strategy.output_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), 'This assertion is used to debug.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# TODO: use meta_info_prop to profile the original memory and compute costs, then divide them depending on the sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent resharding from happening, set their non-zero resharding costs to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(
name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_tensor_node.output_sharding_spec])
strategies_vector.append(sharding_strategy)
assert isinstance(input_sharding_spec, ShardingSpec), 'This assertion is used to debug.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# TODO: use meta_info_prop to profile the original memory and compute costs, then divide them depending on the sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec],
index=index)
# to prevent resharding from happening, set their non-zero resharding costs to INFINITY_COST.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[strategy.output_sharding_spec])
strategies_vector.append(sharding_strategy)
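
The pricing trick above is worth noting: any producer strategy with a non-zero resharding cost is re-priced at INFINITY_COST, which forces the solver to consume the getitem input with its sharding unchanged. A sketch of just that rule on plain data, with a stand-in constant:

INFINITY_COST = 1e13            # assumed stand-in for the solver constant
costs = [0, 7.5, 0, 3.2]        # hypothetical per-strategy resharding costs
costs = [cost if cost == 0 else INFINITY_COST for cost in costs]
assert costs == [0, INFINITY_COST, 0, INFINITY_COST]
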
# torch.arange function
elif target == torch.arange:
@@ -334,8 +350,7 @@ class StrategiesConstructor:
strategies_vector.append(sharding_strategy)
# op list to be processed to support gpt2
elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
torch.nn.functional.softmax, torch.pow, torch.tanh):
elif target in (builtins.getattr, operator.le, torch.addmm):
pass
# other function
else:
@@ -344,7 +359,7 @@ class StrategiesConstructor:
# call_method node
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
if method in (torch.Tensor.size, torch.Tensor.contiguous):
if method in (torch.Tensor.size,):
pass
elif method in ELEMENTWISE_METHOD_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
@@ -400,6 +415,18 @@ class StrategiesConstructor:
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
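
A self-contained sketch of this cleanup pass, the change the commit is named after, using plain stand-ins for torch.fx nodes and strategies vectors:

class FakeNode:
    pass

class FakeVector(list):
    def __init__(self, node, strategies=()):
        super().__init__(strategies)
        self.node = node

a, b = FakeNode(), FakeNode()
a.strategies_vector = FakeVector(a, ["S0", "S1"])
b.strategies_vector = FakeVector(b)     # no handler registered a strategy
leaf_strategies = [a.strategies_vector, b.strategies_vector]
strategy_map = {a: a.strategies_vector, b: b.strategies_vector}

remove_list = [v.node for v in leaf_strategies if len(v) == 0]
for node in remove_list:
    if node.strategies_vector in leaf_strategies:
        leaf_strategies.remove(node.strategies_vector)
    if node in strategy_map:
        strategy_map.pop(node)
assert b not in strategy_map and len(leaf_strategies) == 1
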
class StrategiesConstructor_V2:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
@@ -482,3 +509,4 @@ class StrategiesConstructor_V2:
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector