mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] remove no strategy nodes (#1652)
* [autoparallel] remove no strategy nodes
* fix none object iteration issue

Branch: pull/1669/head
parent 50f16a2850
commit c27e701cb2
@@ -1,6 +1,7 @@
 from typing import List
 import math
 from torch.fx.node import Node
+from colossalai.auto_parallel.solver.constants import INFINITY_COST
 
 
 class CostGraph:
@@ -19,6 +20,7 @@ class CostGraph:
 
     def __init__(self, leaf_strategies, simplify=True):
         self.leaf_strategies = leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
         # stores number of strategies in each node
         self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
         # extra_node_costs will store the extra costs introduced by merging nodes
@@ -27,6 +29,15 @@ class CostGraph:
         self.simplify = simplify
         self._build_cost_graph()
 
+    def _remove_invalid_node(self, node, attr_name):
+        remove_list = []
+        target_node_list = getattr(node, attr_name, [])
+        for target_node in target_node_list:
+            if target_node not in self.nodes:
+                remove_list.append(target_node)
+        for element in remove_list:
+            target_node_list.remove(element)
+
     def _build_cost_graph(self):
         '''
         This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
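Note on the new helper: it collects stale entries into `remove_list` first and deletes them in a second pass, because removing from `target_node_list` while iterating it would skip elements. A minimal stand-alone sketch of the same two-pass pattern, with plain strings standing in for fx nodes:

    valid_nodes = {'conv', 'relu'}            # nodes that kept strategies
    parents = ['conv', 'dropout', 'relu']     # 'dropout' was pruned

    remove_list = []
    for target_node in parents:
        if target_node not in valid_nodes:
            remove_list.append(target_node)
    for element in remove_list:
        parents.remove(element)

    print(parents)                            # ['conv', 'relu']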
@@ -39,6 +50,8 @@ class CostGraph:
             # build edge_cost
             dst_node = strategies_vector.node
             for src_node in strategies_vector.predecessor_nodes:
+                if src_node not in self.nodes:
+                    continue
                 node_pair = (src_node, dst_node)
                 # src_index = strategies_vector.predecessor_nodes.index(src_node)
                 edge_cost = {}
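Because pruned nodes never enter `self.nodes`, the `continue` above keeps any (pruned, dst) pair out of `edge_cost`, which is presumably the "none object iteration issue" named in the commit message. The same filter in isolation, on stand-in string nodes:

    nodes = ['conv', 'relu']                  # nodes that kept strategies
    predecessor_nodes = ['conv', 'dropout']   # 'dropout' was pruned
    edges = [(src, 'relu') for src in predecessor_nodes if src in nodes]
    print(edges)                              # [('conv', 'relu')]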
@@ -49,6 +62,8 @@ class CostGraph:
             # add parents and children attribute to node
             setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
             setattr(dst_node, 'children', strategies_vector.successor_nodes)
+            self._remove_invalid_node(dst_node, 'parents')
+            self._remove_invalid_node(dst_node, 'children')
 
             if self.simplify and strategies_vector.check_merge():
                 for followed_node in strategies_vector.predecessor_nodes:
@@ -83,11 +98,11 @@ class CostGraph:
         # build merge_map
         merge_map = {}
         for src_index, strategy in enumerate(src_node.strategies_vector):
-            min_cost = math.inf
+            min_cost = INFINITY_COST
             lowest_cost_index = -1
             for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
                 resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
-                if resharding_cost < min_cost:
+                if resharding_cost <= min_cost:
                     min_cost = resharding_cost
                     lowest_cost_index = dst_index
             merge_map[src_index] = lowest_cost_index
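Two behavioural details change in this hunk: the sentinel is now `INFINITY_COST`, a large but finite constant from `constants.py` (assumed to be on the order of 1e13), instead of `math.inf`, and the comparison becomes `<=`. With `<`, a source strategy whose resharding costs all equal `INFINITY_COST` would never beat the sentinel and `lowest_cost_index` would stay -1; with `<=`, every row maps to a valid index. A small sketch of the selection:

    INFINITY_COST = 1e13    # assumed value; what matters is that it is finite

    def build_merge_map(resharding_costs):
        # resharding_costs[src_index][dst_index]: cost of following dst's strategy
        merge_map = {}
        for src_index, row in enumerate(resharding_costs):
            min_cost, lowest_cost_index = INFINITY_COST, -1
            for dst_index, cost in enumerate(row):
                if cost <= min_cost:    # '<=' so an all-INFINITY row still maps somewhere
                    min_cost, lowest_cost_index = cost, dst_index
            merge_map[src_index] = lowest_cost_index
        return merge_map

    print(build_merge_map([[3.0, 0.0], [INFINITY_COST, INFINITY_COST]]))    # {0: 1, 1: 1}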
@@ -182,6 +182,8 @@ class StrategiesVector(list):
         # fetch its input and output nodes
         # TODO: placeholder input nodes
         self.predecessor_nodes = list(node._input_nodes.keys())
+        if self.node.op == 'output':
+            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
         self.successor_nodes = list(node.users.keys())
 
     def check_merge(self):
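For an fx `output` node, `_input_nodes` contains every node referenced by the return value, so a model returning a tuple gives the output node several predecessors; the new branch keeps only the first. An illustration via the public `all_input_nodes` view, which has the same ordering:

    import torch
    import torch.fx

    class TwoOutputs(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            b = x * 2
            return a, b

    gm = torch.fx.symbolic_trace(TwoOutputs())
    output_node = list(gm.graph.nodes)[-1]
    print(output_node.all_input_nodes)        # [add, mul]: both feed the output
    print(output_node.all_input_nodes[:1])    # [add]: what the new branch keeps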
@@ -45,8 +45,8 @@ class Solver:
         self.strategies_constructor = strategies_constructor
         self.cost_graph = cost_graph
         self.graph_analyser = graph_analyser
-        self.nodes = list(self.graph.nodes)
         self.leaf_strategies = self.strategies_constructor.leaf_strategies
+        self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
         self.strategy_map = self.strategies_constructor.strategy_map
         self.memory_budget = memory_budget
         self.solution_numbers = solution_numbers
@@ -67,7 +67,7 @@ class Solver:
         Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
         node.
         '''
-        for node_index, node in enumerate(self.graph.nodes):
+        for node_index, node in enumerate(self.nodes):
             if node.strategies_vector.check_merge():
                 # the merged node has only one input, and its strategies follow the input sharding strategy
                 input_strategies_vector = node.args[0].strategies_vector
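Both Solver hunks replace `self.graph.nodes` with the node list derived from `leaf_strategies`. The fx graph still contains the pruned nodes, so iterating `self.graph.nodes` would visit nodes with empty strategy vectors and misalign indices; deriving `self.nodes` from `leaf_strategies` keeps the node at position `i` in lockstep with `leaf_strategies[i]`. Schematically, with a stand-in vector class:

    class Vec(list):    # stand-in for StrategiesVector
        def __init__(self, node, items):
            super().__init__(items)
            self.node = node

    leaf_strategies = [Vec('x', ['S0', 'R']), Vec('linear', ['R'])]    # 'dropout' pruned
    nodes = [sv.node for sv in leaf_strategies]
    print(nodes)    # ['x', 'linear']: index i matches leaf_strategies[i], unlike graph.nodes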
@@ -297,7 +297,8 @@ class Solver:
                 num_edges += 1
                 e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
                 assert len(e[idx]) == len(r[idx])
+        for element in s:
+            assert len(element) > 0
         # 2. Set initial value
         ######################################
         # set a initial value for warm start #
@@ -317,12 +318,14 @@ class Solver:
         ###################################################################
         obj = 0
         for i in range(node_nums):
+            assert len(s[i]) == len(c[i])
             obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
 
         #############################################
         # computing the edge cost(resharding cost) #
         #############################################
         for i in range(len(E)):
+            assert len(e[i]) == len(r[i])
             obj += lpDot(e[i], r[i])
 
         prob += obj
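The new asserts fail fast if a pruned node left a strategy vector `s[i]` and its cost vector `c[i]` (or `e[i]` and `r[i]`) with different lengths before they enter the objective. A minimal self-contained sketch of this objective construction with pulp, assuming two nodes with two strategies each, one edge, and made-up costs; the real solver also links `e` to `s` with linearization constraints, omitted here:

    from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

    c = [[1.0, 3.0], [2.0, 0.5]]    # per-node, per-strategy compute costs (made up)
    r = [[0.0, 9.0, 9.0, 0.0]]      # resharding costs for the 2x2 strategy pairs

    s = [LpVariable.matrix(f"s[{i}]", (range(2),), cat="Binary") for i in range(2)]
    e = [LpVariable.matrix("e[0,1]", (range(4),), cat="Binary")]

    prob = LpProblem("objective_sketch", LpMinimize)
    obj = 0
    for i in range(2):
        assert len(s[i]) == len(c[i])    # the check this commit adds
        obj += lpDot(s[i], c[i])
    for i in range(len(e)):
        assert len(e[i]) == len(r[i])
        obj += lpDot(e[i], r[i])
    prob += obj

    for i in range(2):
        prob += lpSum(s[i]) == 1         # pick exactly one strategy per node
    prob += lpSum(e[0]) == 1
    prob.solve()
    print([v.value() for v in s[0]], [v.value() for v in s[1]])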
@@ -214,6 +214,21 @@ class StrategiesConstructor:
                 linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
                 linear_handler.register_strategy()
 
+            # where function
+            elif target == torch.where:
+                if input_nodes_len == 1:
+                    # both of x and y are scalar
+                    pass
+
+                elif input_nodes_len == 2:
+                    # one of x or y is type of scalar
+                    pass
+
+                else:
+                    # general case
+                    where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
+                    where_handler.register_strategy()
+
             # reshape function
             elif target in RESHAPE_FUNC_OP:
                 # use ReshapeHandler to create sharding strategies for rehsape node
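`input_nodes_len` counts the node's tensor predecessors, and Python scalars passed to `torch.where(condition, x, y)` never become graph inputs, so the count tells the constructor how many of `x`/`y` are scalars. Only the all-tensor case is wired to `WhereHandler`; the scalar cases are left as stubs. A quick way to observe the count under fx tracing:

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, cond, x):
            return torch.where(cond, x, 1.0)    # y is a scalar, not a graph node

    gm = torch.fx.symbolic_trace(M())
    where_node = next(n for n in gm.graph.nodes if n.target is torch.where)
    print(len(where_node.all_input_nodes))      # 2: the 'one of x or y is scalar' branch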
@@ -222,7 +237,6 @@ class StrategiesConstructor:
 
             # element-wise function
             elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
-                if isinstance(node._meta_data, torch.Tensor):
                 unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
                 unary_elementwise_handler.register_strategy()
@@ -291,8 +305,10 @@ class StrategiesConstructor:
             elif target == operator.getitem:
                 index = node.args[1]
                 input_tensor_node = strategies_vector.predecessor_nodes[0]
-                if isinstance(input_tensor_node, torch.Tensor):
                 for strategy in input_tensor_node.strategies_vector:
+                    if isinstance(strategy.output_sharding_spec, ShardingSpec):
+                        input_sharding_spec = strategy.output_sharding_spec
+                    else:
                         input_sharding_spec = strategy.output_sharding_spec[index]
                     assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
                     dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
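Two fixes land in this hunk: the `isinstance(input_tensor_node, torch.Tensor)` guard is dropped (an fx `Node` is never a `torch.Tensor`, so the old loop body was unreachable), and the spec lookup now accepts producers whose `output_sharding_spec` is either a single `ShardingSpec` or, for tuple outputs, a list of specs indexed by `index`. The added branch, factored into a hypothetical helper:

    from colossalai.tensor.sharding_spec import ShardingSpec    # assumed import path

    def pick_input_spec(strategy, index):
        # hypothetical helper mirroring the branch above
        spec = strategy.output_sharding_spec
        if isinstance(spec, ShardingSpec):    # producer yields a single tensor
            return spec
        return spec[index]                    # tuple output: take that element's spec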
@@ -304,18 +320,18 @@ class StrategiesConstructor:
                     compute_cost = 0
                     memory_cost = 0
                     resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                 [input_sharding_spec])
+                                                                 [input_sharding_spec],
+                                                                 index=index)
                     # to prevent the resharding happening, set their resharding cost to inf.
                     resharding_costs[input_tensor_node] = [
-                        cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
+                        cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
                     ]
-                    sharding_strategy = ShardingStrategy(
-                        name,
+                    sharding_strategy = ShardingStrategy(name,
                                                          output_sharding_spec,
                                                          compute_cost=compute_cost,
                                                          memory_cost=memory_cost,
                                                          resharding_costs=resharding_costs,
-                                                         input_shardings=[input_tensor_node.output_sharding_spec])
+                                                         input_shardings=[strategy.output_sharding_spec])
                     strategies_vector.append(sharding_strategy)
 
             # torch.arange function
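Pinning every non-zero resharding cost to `INFINITY_COST` makes any strategy pairing that would reshard the getitem input prohibitively expensive to the ILP while keeping the coefficient finite, which is friendlier to LP solvers than `math.inf`. The comprehension in isolation:

    INFINITY_COST = 1e13    # assumed value of the constant; finite on purpose

    costs = [0.0, 42.0, 0.0]
    pinned = [cost if cost == 0 else INFINITY_COST for cost in costs]
    print(pinned)           # [0.0, 1e13, 0.0]: only no-reshard pairings stay affordable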
@@ -334,8 +350,7 @@ class StrategiesConstructor:
                     strategies_vector.append(sharding_strategy)
 
             # op list to be processed to support gpt2
-            elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
-                            torch.nn.functional.softmax, torch.pow, torch.tanh):
+            elif target in (builtins.getattr, operator.le, torch.addmm):
                 pass
             # other function
             else:
@@ -344,7 +359,7 @@ class StrategiesConstructor:
         # call_method node
         if node.op == 'call_method':
             method = getattr(node.args[0]._meta_data.__class__, node.target)
-            if method in (torch.Tensor.size, torch.Tensor.contiguous):
+            if method in (torch.Tensor.size,):
                 pass
             elif method in ELEMENTWISE_METHOD_OP:
                 unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
@@ -400,6 +415,18 @@ class StrategiesConstructor:
             self.strategy_map[node] = strategies_vector
 
 
+        # remove no strategy nodes
+        remove_list = []
+        for strategies_vector in self.leaf_strategies:
+            if len(strategies_vector) == 0:
+                remove_list.append(strategies_vector.node)
+        for node in remove_list:
+            if node.strategies_vector in self.leaf_strategies:
+                self.leaf_strategies.remove(node.strategies_vector)
+            if node in self.strategy_map:
+                self.strategy_map.pop(node)
+
+
 class StrategiesConstructor_V2:
     """
     StrategiesConstructor is used to construct the parallelization plan for the model execution.
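This block is the commit's headline change: once every handler has run, any node whose `StrategiesVector` stayed empty is dropped from both `leaf_strategies` and `strategy_map`, so the cost graph and the solver never see it. A stand-alone sketch with stub classes in place of fx nodes:

    class FakeNode:
        def __init__(self, name):
            self.name = name

    class StrategiesVector(list):
        def __init__(self, node):
            super().__init__()
            self.node = node

    a, b = FakeNode('a'), FakeNode('b')
    a.strategies_vector, b.strategies_vector = StrategiesVector(a), StrategiesVector(b)
    a.strategies_vector.append('S0R')    # 'a' got one strategy, 'b' got none

    leaf_strategies = [a.strategies_vector, b.strategies_vector]
    strategy_map = {a: a.strategies_vector, b: b.strategies_vector}

    # same two-pass shape as the patch: collect first, then remove
    remove_list = [sv.node for sv in leaf_strategies if len(sv) == 0]
    for node in remove_list:
        if node.strategies_vector in leaf_strategies:
            leaf_strategies.remove(node.strategies_vector)
        if node in strategy_map:
            strategy_map.pop(node)

    print([sv.node.name for sv in leaf_strategies])    # ['a']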
@@ -482,3 +509,4 @@ class StrategiesConstructor_V2:
             setattr(node, 'strategies_vector', strategies_vector)
             self.leaf_strategies.append(strategies_vector)
             self.strategy_map[node] = strategies_vector
+