mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] remove no strategy nodes (#1652)
* [autoparallel] remove no strategy nodes * fix none object iteration issuepull/1669/head
parent
50f16a2850
commit
c27e701cb2
|
@ -1,6 +1,7 @@
|
|||
from typing import List
|
||||
import math
|
||||
from torch.fx.node import Node
|
||||
from colossalai.auto_parallel.solver.constants import INFINITY_COST
|
||||
|
||||
|
||||
class CostGraph:
|
||||
|
@ -19,6 +20,7 @@ class CostGraph:
|
|||
|
||||
def __init__(self, leaf_strategies, simplify=True):
|
||||
self.leaf_strategies = leaf_strategies
|
||||
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
|
||||
# stores number of strategies in each node
|
||||
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
|
||||
# extra_node_costs will store the extra costs introduced by merging nodes
|
||||
|
@ -27,6 +29,15 @@ class CostGraph:
|
|||
self.simplify = simplify
|
||||
self._build_cost_graph()
|
||||
|
||||
def _remove_invalid_node(self, node, attr_name):
|
||||
remove_list = []
|
||||
target_node_list = getattr(node, attr_name, [])
|
||||
for target_node in target_node_list:
|
||||
if target_node not in self.nodes:
|
||||
remove_list.append(target_node)
|
||||
for element in remove_list:
|
||||
target_node_list.remove(element)
|
||||
|
||||
def _build_cost_graph(self):
|
||||
'''
|
||||
This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
|
||||
|
@ -39,6 +50,8 @@ class CostGraph:
|
|||
# build edge_cost
|
||||
dst_node = strategies_vector.node
|
||||
for src_node in strategies_vector.predecessor_nodes:
|
||||
if src_node not in self.nodes:
|
||||
continue
|
||||
node_pair = (src_node, dst_node)
|
||||
# src_index = strategies_vector.predecessor_nodes.index(src_node)
|
||||
edge_cost = {}
|
||||
|
@ -49,6 +62,8 @@ class CostGraph:
|
|||
# add parents and children attribute to node
|
||||
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
|
||||
setattr(dst_node, 'children', strategies_vector.successor_nodes)
|
||||
self._remove_invalid_node(dst_node, 'parents')
|
||||
self._remove_invalid_node(dst_node, 'children')
|
||||
|
||||
if self.simplify and strategies_vector.check_merge():
|
||||
for followed_node in strategies_vector.predecessor_nodes:
|
||||
|
@ -83,11 +98,11 @@ class CostGraph:
|
|||
# build merge_map
|
||||
merge_map = {}
|
||||
for src_index, strategy in enumerate(src_node.strategies_vector):
|
||||
min_cost = math.inf
|
||||
min_cost = INFINITY_COST
|
||||
lowest_cost_index = -1
|
||||
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
|
||||
resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
|
||||
if resharding_cost < min_cost:
|
||||
if resharding_cost <= min_cost:
|
||||
min_cost = resharding_cost
|
||||
lowest_cost_index = dst_index
|
||||
merge_map[src_index] = lowest_cost_index
|
||||
|
|
|
@ -182,6 +182,8 @@ class StrategiesVector(list):
|
|||
# fetch its input and output nodes
|
||||
# TODO: placeholder input nodes
|
||||
self.predecessor_nodes = list(node._input_nodes.keys())
|
||||
if self.node.op == 'output':
|
||||
self.predecessor_nodes = list(node._input_nodes.keys())[:1]
|
||||
self.successor_nodes = list(node.users.keys())
|
||||
|
||||
def check_merge(self):
|
||||
|
|
|
@ -45,8 +45,8 @@ class Solver:
|
|||
self.strategies_constructor = strategies_constructor
|
||||
self.cost_graph = cost_graph
|
||||
self.graph_analyser = graph_analyser
|
||||
self.nodes = list(self.graph.nodes)
|
||||
self.leaf_strategies = self.strategies_constructor.leaf_strategies
|
||||
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
|
||||
self.strategy_map = self.strategies_constructor.strategy_map
|
||||
self.memory_budget = memory_budget
|
||||
self.solution_numbers = solution_numbers
|
||||
|
@ -67,7 +67,7 @@ class Solver:
|
|||
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
|
||||
node.
|
||||
'''
|
||||
for node_index, node in enumerate(self.graph.nodes):
|
||||
for node_index, node in enumerate(self.nodes):
|
||||
if node.strategies_vector.check_merge():
|
||||
# the merged node has only one input, and its strategies follow the input sharding strategy
|
||||
input_strategies_vector = node.args[0].strategies_vector
|
||||
|
@ -297,7 +297,8 @@ class Solver:
|
|||
num_edges += 1
|
||||
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
|
||||
assert len(e[idx]) == len(r[idx])
|
||||
|
||||
for element in s:
|
||||
assert len(element) > 0
|
||||
# 2. Set initial value
|
||||
######################################
|
||||
# set a initial value for warm start #
|
||||
|
@ -317,12 +318,14 @@ class Solver:
|
|||
###################################################################
|
||||
obj = 0
|
||||
for i in range(node_nums):
|
||||
assert len(s[i]) == len(c[i])
|
||||
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
|
||||
|
||||
#############################################
|
||||
# computing the edge cost(resharding cost) #
|
||||
#############################################
|
||||
for i in range(len(E)):
|
||||
assert len(e[i]) == len(r[i])
|
||||
obj += lpDot(e[i], r[i])
|
||||
|
||||
prob += obj
|
||||
|
|
|
@ -214,6 +214,21 @@ class StrategiesConstructor:
|
|||
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
|
||||
linear_handler.register_strategy()
|
||||
|
||||
# where function
|
||||
elif target == torch.where:
|
||||
if input_nodes_len == 1:
|
||||
# both of x and y are scalar
|
||||
pass
|
||||
|
||||
elif input_nodes_len == 2:
|
||||
# one of x or y is type of scalar
|
||||
pass
|
||||
|
||||
else:
|
||||
# general case
|
||||
where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
|
||||
where_handler.register_strategy()
|
||||
|
||||
# reshape function
|
||||
elif target in RESHAPE_FUNC_OP:
|
||||
# use ReshapeHandler to create sharding strategies for rehsape node
|
||||
|
@ -222,9 +237,8 @@ class StrategiesConstructor:
|
|||
|
||||
# element-wise function
|
||||
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
|
||||
if isinstance(node._meta_data, torch.Tensor):
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
unary_elementwise_handler.register_strategy()
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
unary_elementwise_handler.register_strategy()
|
||||
|
||||
# bcast op
|
||||
elif target in BCAST_FUNC_OP:
|
||||
|
@ -291,32 +305,34 @@ class StrategiesConstructor:
|
|||
elif target == operator.getitem:
|
||||
index = node.args[1]
|
||||
input_tensor_node = strategies_vector.predecessor_nodes[0]
|
||||
if isinstance(input_tensor_node, torch.Tensor):
|
||||
for strategy in input_tensor_node.strategies_vector:
|
||||
for strategy in input_tensor_node.strategies_vector:
|
||||
if isinstance(strategy.output_sharding_spec, ShardingSpec):
|
||||
input_sharding_spec = strategy.output_sharding_spec
|
||||
else:
|
||||
input_sharding_spec = strategy.output_sharding_spec[index]
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
|
||||
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec])
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_tensor_node] = [
|
||||
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(
|
||||
name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[input_tensor_node.output_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
|
||||
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
|
||||
output_sharding_spec = ShardingSpec(self.device_mesh,
|
||||
entire_shape_output,
|
||||
dim_partition_dict=dim_partition_dict_for_output)
|
||||
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
|
||||
compute_cost = 0
|
||||
memory_cost = 0
|
||||
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
|
||||
[input_sharding_spec],
|
||||
index=index)
|
||||
# to prevent the resharding happening, set their resharding cost to inf.
|
||||
resharding_costs[input_tensor_node] = [
|
||||
cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
|
||||
]
|
||||
sharding_strategy = ShardingStrategy(name,
|
||||
output_sharding_spec,
|
||||
compute_cost=compute_cost,
|
||||
memory_cost=memory_cost,
|
||||
resharding_costs=resharding_costs,
|
||||
input_shardings=[strategy.output_sharding_spec])
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# torch.arange function
|
||||
elif target == torch.arange:
|
||||
|
@ -334,8 +350,7 @@ class StrategiesConstructor:
|
|||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
# op list to be processed to support gpt2
|
||||
elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
|
||||
torch.nn.functional.softmax, torch.pow, torch.tanh):
|
||||
elif target in (builtins.getattr, operator.le, torch.addmm):
|
||||
pass
|
||||
# other function
|
||||
else:
|
||||
|
@ -344,7 +359,7 @@ class StrategiesConstructor:
|
|||
# call_method node
|
||||
if node.op == 'call_method':
|
||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||
if method in (torch.Tensor.size, torch.Tensor.contiguous):
|
||||
if method in (torch.Tensor.size,):
|
||||
pass
|
||||
elif method in ELEMENTWISE_METHOD_OP:
|
||||
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
|
||||
|
@ -400,6 +415,18 @@ class StrategiesConstructor:
|
|||
self.strategy_map[node] = strategies_vector
|
||||
|
||||
|
||||
# remove no strategy nodes
|
||||
remove_list = []
|
||||
for strategies_vector in self.leaf_strategies:
|
||||
if len(strategies_vector) == 0:
|
||||
remove_list.append(strategies_vector.node)
|
||||
for node in remove_list:
|
||||
if node.strategies_vector in self.leaf_strategies:
|
||||
self.leaf_strategies.remove(node.strategies_vector)
|
||||
if node in self.strategy_map:
|
||||
self.strategy_map.pop(node)
|
||||
|
||||
|
||||
class StrategiesConstructor_V2:
|
||||
"""
|
||||
StrategiesConstructor is used to construct the parallelization plan for the model execution.
|
||||
|
@ -482,3 +509,4 @@ class StrategiesConstructor_V2:
|
|||
setattr(node, 'strategies_vector', strategies_vector)
|
||||
self.leaf_strategies.append(strategies_vector)
|
||||
self.strategy_map[node] = strategies_vector
|
||||
|
||||
|
|
Loading…
Reference in New Issue