
[autoparallel] apply repeat block to reduce solving time (#2912)

pull/2933/head · YuliangLiu0306 · commit 197d0bf4ed
Files changed:

  1. colossalai/auto_parallel/tensor_shard/initialize.py (8 lines changed)
  2. colossalai/auto_parallel/tensor_shard/solver/solver.py (39 lines changed)
  3. colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py (21 lines changed)
  4. tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py (8 lines changed)
  5. tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py (4 lines changed)
  6. tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py (5 lines changed)

colossalai/auto_parallel/tensor_shard/initialize.py (8 lines changed)

@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor
     This method is used to solve the best solution for the given graph.
     The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
     '''
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
+    # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+    # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+    # graph_analyser = GraphAnalyser(gm)
+    # liveness_list = graph_analyser.liveness_analysis()
     cost_graph = CostGraph(strategy_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
+    solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])

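The comment introduced above pins down the interim memory model: liveness analysis is skipped, every node stands in for its own liveness stage, and the forward and backward memory costs are folded into one per-node cost. A minimal standalone PuLP sketch of the resulting budget constraint; the two-node, two-strategy sizes and cost numbers are hypothetical, not taken from the repository:

from pulp import LpMinimize, LpProblem, LpVariable, lpSum

# hypothetical per-node, per-strategy memory costs (forward + backward folded together)
m = [[2.0, 1.0], [3.0, 1.5]]
memory_budget = 4.0

prob = LpProblem("memory_budget_sketch", LpMinimize)

# one binary strategy vector per node; exactly one strategy is chosen per node
s = [LpVariable.matrix(f"s[{i}]", (range(len(m[i])),), cat="Binary") for i in range(len(m))]
for row in s:
    prob += lpSum(row) == 1

# the simplified constraint: the summed cost of *all* nodes must fit the budget at once
mem = lpSum(s[i][j] * m[i][j] for i in range(len(m)) for j in range(len(m[i])))
prob += mem <= memory_budget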
colossalai/auto_parallel/tensor_shard/solver/solver.py (39 lines changed)

@@ -32,7 +32,7 @@ class Solver:
                  graph: Graph,
                  strategies_constructor: StrategiesConstructor,
                  cost_graph: CostGraph,
-                 graph_analyser: GraphAnalyser,
+                 graph_analyser: GraphAnalyser = None,
                  memory_budget: float = -1.0,
                  solution_numbers: int = 1,
                  forward_only: bool = False,
@@ -63,7 +63,10 @@
             self.memory_increasing_coefficient = memory_increasing_coefficient
         else:
             self.memory_increasing_coefficient = 1
-        self.liveness_list = self.graph_analyser.liveness_analysis()
+        # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+        # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+        # self.liveness_list = self.graph_analyser.liveness_analysis()
+        self.liveness_list = self.nodes
         self.node_index_dict = self._generate_node_index_dict()
         # The last solution vector of auto sharding.
         self.last_s_val = None
@@ -140,7 +143,7 @@
         liveness_set = self.liveness_list
         # omit alias_set now
-        alias_set = None
+        alias_set = self.strategies_constructor.alias_set
         alias_convert_costs = None
         # prepare compute_costs, communication_costs and memory_costs
@@ -230,6 +233,7 @@
         # 0. Unpack flatten numpy arrays
         s_follow = following_nodes
+        s_alias = alias_set
         E = edge_pairs.reshape((-1, 2))  # noqa
         r = []
@@ -294,8 +298,11 @@
                 if strategies_len[i] == 1:
                     s.append([1])
                 else:
-                    num_nodes += 1
-                    s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
+                    if i not in s_alias:
+                        num_nodes += 1
+                        s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
+                    else:
+                        s.append(s[s_alias[i]])
             else:
                 if s_follow[i] < len(s):
                     s.append(s[s_follow[i]])
@@ -311,15 +318,20 @@
         #############################
         e = []
         num_edges = 0
+        map_edge_to_idx = {}
         for (idx, (i, j)) in enumerate(E):
             if len(s[i]) == 1:
                 e.append(s[j])
             elif len(s[j]) == 1:
                 e.append(s[i])
             else:
-                num_edges += 1
-                e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
+                if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
+                    e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
+                else:
+                    num_edges += 1
+                    e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
             assert len(e[idx]) == len(r[idx])
+            map_edge_to_idx[(i, j)] = idx
         for element in s:
             assert len(element) > 0
         # 2. Set initial value
@@ -371,13 +383,12 @@
         # compute memory consumption with liveness set  #
         #################################################
         if memory_budget > 0:
-            for liveness_stage in liveness_set:
-                mem = 0
-                for live_variable in liveness_stage.unique_live_vars:
-                    if live_variable.node not in self.node_index_dict:
-                        continue
-                    node_index = self.node_index_dict[live_variable.node]
-                    mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
+            mem = 0
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
             prob += mem <= memory_budget
         # (d). specified by `cat="Binary"`

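The solver hunks above implement the core trick: when node i is an alias of a node in the first repeat block, it reuses that node's LpVariable vector instead of minting fresh binaries, and edge variables are deduplicated the same way through map_edge_to_idx. A standalone sketch with a hypothetical three-node alias set (sizes invented for illustration):

from pulp import LpVariable

strategies_len = [4, 4, 4]    # three nodes with four strategies each (assumed)
s_alias = {1: 0, 2: 0}        # nodes 1 and 2 are repeats of node 0

s, num_nodes = [], 0
for i in range(len(strategies_len)):
    if i not in s_alias:
        num_nodes += 1
        s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
    else:
        s.append(s[s_alias[i]])    # reuse the representative node's variables

assert s[1] is s[0] and s[2] is s[0]
assert num_nodes == 1    # only one node contributes binary variables to the ILP

Because repeated blocks contribute no new strategy or edge variables, the ILP size stays roughly constant in the number of repeats, which is where the solving-time reduction comes from.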
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py (21 lines changed)

@@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
 from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
 from colossalai.device.device_mesh import DeviceMesh

 from ..options import DataloaderOption, SolverOptions
@@ -42,6 +43,7 @@ class StrategiesConstructor:
         self.strategy_map = {}
         self.solver_options = solver_options
         self.no_strategy_nodes = []
+        self.alias_set = None

     def remove_duplicated_strategy(self, strategies_vector):
         '''
@@ -59,6 +61,22 @@ class StrategiesConstructor:
         for strategy in remove_list:
             strategies_vector.remove(strategy)

+    def generate_alias_set(self):
+        node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
+        common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
+        repeat_block_nums = len(common_blocks)
+        alias_set = {}
+
+        if repeat_block_nums == 0:
+            return alias_set
+
+        for index, common_node in enumerate(common_blocks[0]):
+            for i in range(1, repeat_block_nums):
+                alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)
+
+        return alias_set
+
     def build_strategies_and_cost(self):
         """
         This method is to build the strategy vector for each node in the computation graph.
@@ -175,3 +193,6 @@
                 self.leaf_strategies.remove(node.strategies_vector)
             if node in self.strategy_map:
                 self.strategy_map.pop(node)
+
+        alias_set = self.generate_alias_set()
+        self.alias_set = alias_set

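For intuition, here is a self-contained sketch of the mapping generate_alias_set builds. The node names and the hand-written repeat blocks are hypothetical stand-ins for what find_repeat_blocks discovers in a real traced graph:

# two repeated transformer-style layers, written out by hand for illustration
node_list = [
    'embed',
    'layer0_qkv', 'layer0_attn', 'layer0_mlp',
    'layer1_qkv', 'layer1_attn', 'layer1_mlp',
    'head',
]
common_blocks = [node_list[1:4], node_list[4:7]]

alias_set = {}
for index, common_node in enumerate(common_blocks[0]):
    for i in range(1, len(common_blocks)):
        alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)

print(alias_set)    # {4: 1, 5: 2, 6: 3}: each node in layer1 aliases its twin in layer0

Every aliased index then acts as a pointer to its representative inside the solver, which is what makes deep, regular models such as the 12-layer GPT-2 configuration in the test below cheap to solve.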
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py (8 lines changed)

@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2
 BATCH_SIZE = 1
 SEQ_LENGTH = 32
-HIDDEN_DIM = 768
+HIDDEN_DIM = 384


 @run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
 def test_self_attention_block(model_cls):
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+    config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
     if model_cls == GPT2MLP:
         model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
     else:
@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
     gm = GraphModule(model, graph, model.__class__.__name__)
     print(gm.graph)
     gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
     ret = solver.call_solver_serialized_args()
     strategies_list = solver.last_s_val
     nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]

tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py (4 lines changed)

@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
@@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     # solution construction
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(

tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py (5 lines changed)

@@ -51,15 +51,14 @@ def test_cost_graph():
     #     return fc
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph)
     ret = solver.call_solver_serialized_args()

     print(ret[0])
