mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] apply repeat block to reduce solving time (#2912)
parent
a848091141
commit
197d0bf4ed
|
@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
|
|||
This method is used to solve the best solution for the given graph.
|
||||
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
|
||||
'''
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
|
||||
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
|
||||
# graph_analyser = GraphAnalyser(gm)
|
||||
# liveness_list = graph_analyser.liveness_analysis()
|
||||
cost_graph = CostGraph(strategy_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
|
||||
solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
solution = list(ret[0])
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class Solver:
|
|||
graph: Graph,
|
||||
strategies_constructor: StrategiesConstructor,
|
||||
cost_graph: CostGraph,
|
||||
graph_analyser: GraphAnalyser,
|
||||
graph_analyser: GraphAnalyser = None,
|
||||
memory_budget: float = -1.0,
|
||||
solution_numbers: int = 1,
|
||||
forward_only: bool = False,
|
||||
|
@ -63,7 +63,10 @@ class Solver:
|
|||
self.memory_increasing_coefficient = memory_increasing_coefficient
|
||||
else:
|
||||
self.memory_increasing_coefficient = 1
|
||||
self.liveness_list = self.graph_analyser.liveness_analysis()
|
||||
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
|
||||
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
|
||||
# self.liveness_list = self.graph_analyser.liveness_analysis()
|
||||
self.liveness_list = self.nodes
|
||||
self.node_index_dict = self._generate_node_index_dict()
|
||||
# The last solution vector of auto sharding.
|
||||
self.last_s_val = None
|
||||
|
@ -140,7 +143,7 @@ class Solver:
|
|||
liveness_set = self.liveness_list
|
||||
|
||||
# omit alias_set now
|
||||
alias_set = None
|
||||
alias_set = self.strategies_constructor.alias_set
|
||||
alias_convert_costs = None
|
||||
|
||||
# prepare compute_costs, communication_costs and memory_costs
|
||||
|
@ -230,6 +233,7 @@ class Solver:
|
|||
|
||||
# 0. Unpack flatten numpy arrays
|
||||
s_follow = following_nodes
|
||||
s_alias = alias_set
|
||||
|
||||
E = edge_pairs.reshape((-1, 2)) # noqa
|
||||
r = []
|
||||
|
@ -294,8 +298,11 @@ class Solver:
|
|||
if strategies_len[i] == 1:
|
||||
s.append([1])
|
||||
else:
|
||||
num_nodes += 1
|
||||
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
|
||||
if i not in s_alias:
|
||||
num_nodes += 1
|
||||
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
|
||||
else:
|
||||
s.append(s[s_alias[i]])
|
||||
else:
|
||||
if s_follow[i] < len(s):
|
||||
s.append(s[s_follow[i]])
|
||||
|
@ -311,15 +318,20 @@ class Solver:
|
|||
#############################
|
||||
e = []
|
||||
num_edges = 0
|
||||
map_edge_to_idx = {}
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if len(s[i]) == 1:
|
||||
e.append(s[j])
|
||||
elif len(s[j]) == 1:
|
||||
e.append(s[i])
|
||||
else:
|
||||
num_edges += 1
|
||||
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
|
||||
if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
|
||||
e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
|
||||
else:
|
||||
num_edges += 1
|
||||
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
|
||||
assert len(e[idx]) == len(r[idx])
|
||||
map_edge_to_idx[(i, j)] = idx
|
||||
for element in s:
|
||||
assert len(element) > 0
|
||||
# 2. Set initial value
|
||||
|
@ -371,13 +383,12 @@ class Solver:
|
|||
# compute memory consumption with liveness set #
|
||||
#################################################
|
||||
if memory_budget > 0:
|
||||
for liveness_stage in liveness_set:
|
||||
mem = 0
|
||||
for live_variable in liveness_stage.unique_live_vars:
|
||||
if live_variable.node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[live_variable.node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
mem = 0
|
||||
for node in liveness_set:
|
||||
if node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
prob += mem <= memory_budget
|
||||
|
||||
# (d). specified by `cat="Binary"`
|
||||
|
|
|
@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
|
|||
)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
|
||||
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
|
||||
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
from ..options import DataloaderOption, SolverOptions
|
||||
|
@ -42,6 +43,7 @@ class StrategiesConstructor:
|
|||
self.strategy_map = {}
|
||||
self.solver_options = solver_options
|
||||
self.no_strategy_nodes = []
|
||||
self.alias_set = None
|
||||
|
||||
def remove_duplicated_strategy(self, strategies_vector):
|
||||
'''
|
||||
|
@ -59,6 +61,22 @@ class StrategiesConstructor:
|
|||
for strategy in remove_list:
|
||||
strategies_vector.remove(strategy)
|
||||
|
||||
def generate_alias_set(self):
|
||||
|
||||
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
|
||||
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
|
||||
|
||||
repeat_block_nums = len(common_blocks)
|
||||
alias_set = {}
|
||||
|
||||
if repeat_block_nums == 0:
|
||||
return alias_set
|
||||
|
||||
for index, common_node in enumerate(common_blocks[0]):
|
||||
for i in range(1, repeat_block_nums):
|
||||
alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)
|
||||
return alias_set
|
||||
|
||||
def build_strategies_and_cost(self):
|
||||
"""
|
||||
This method is to build the strategy vector for each node in the computation graph.
|
||||
|
@ -175,3 +193,6 @@ class StrategiesConstructor:
|
|||
self.leaf_strategies.remove(node.strategies_vector)
|
||||
if node in self.strategy_map:
|
||||
self.strategy_map.pop(node)
|
||||
|
||||
alias_set = self.generate_alias_set()
|
||||
self.alias_set = alias_set
|
||||
|
|
|
@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2
|
|||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGTH = 32
|
||||
HIDDEN_DIM = 768
|
||||
HIDDEN_DIM = 384
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
|
||||
def test_self_attention_block(model_cls):
|
||||
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
|
||||
config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
|
||||
if model_cls == GPT2MLP:
|
||||
model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
|
||||
else:
|
||||
|
@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
|
|||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
print(gm.graph)
|
||||
gm.recompile()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
solver_options = SolverOptions()
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
strategies_list = solver.last_s_val
|
||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||
|
|
|
@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre
|
|||
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
|
||||
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
|
@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
# solution construction
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
solution = list(ret[0])
|
||||
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
|
||||
|
|
|
@ -51,15 +51,14 @@ def test_cost_graph():
|
|||
# return fc
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
liveness_list = graph_analyser.liveness_analysis()
|
||||
|
||||
solver_options = SolverOptions()
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph)
|
||||
|
||||
ret = solver.call_solver_serialized_args()
|
||||
print(ret[0])
|
||||
|
|
Loading…
Reference in New Issue