diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index 4affa3789..60472eee5 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc This method is used to solve the best solution for the given graph. The solution is a list of integers, each integer represents the best strategy index of the corresponding node. ''' - graph_analyser = GraphAnalyser(gm) - liveness_list = graph_analyser.liveness_analysis() + # temporarily we use all nodes as liveness list, we count the backward memory cost together with + # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. + # graph_analyser = GraphAnalyser(gm) + # liveness_list = graph_analyser.liveness_analysis() cost_graph = CostGraph(strategy_constructor.leaf_strategies) cost_graph.simplify_graph() - solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget) + solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget) ret = solver.call_solver_serialized_args() solution = list(ret[0]) diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py index 5449fb5a1..f5c6663dc 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/solver.py +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -32,7 +32,7 @@ class Solver: graph: Graph, strategies_constructor: StrategiesConstructor, cost_graph: CostGraph, - graph_analyser: GraphAnalyser, + graph_analyser: GraphAnalyser = None, memory_budget: float = -1.0, solution_numbers: int = 1, forward_only: bool = False, @@ -63,7 +63,10 @@ class Solver: self.memory_increasing_coefficient = memory_increasing_coefficient else: self.memory_increasing_coefficient = 1 - self.liveness_list = self.graph_analyser.liveness_analysis() + # temporarily we use all nodes as liveness list, we count the backward memory cost together with + # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. + # self.liveness_list = self.graph_analyser.liveness_analysis() + self.liveness_list = self.nodes self.node_index_dict = self._generate_node_index_dict() # The last solution vector of auto sharding. self.last_s_val = None @@ -140,7 +143,7 @@ class Solver: liveness_set = self.liveness_list # omit alias_set now - alias_set = None + alias_set = self.strategies_constructor.alias_set alias_convert_costs = None # prepare compute_costs, communication_costs and memory_costs @@ -230,6 +233,7 @@ class Solver: # 0. Unpack flatten numpy arrays s_follow = following_nodes + s_alias = alias_set E = edge_pairs.reshape((-1, 2)) # noqa r = [] @@ -294,8 +298,11 @@ class Solver: if strategies_len[i] == 1: s.append([1]) else: - num_nodes += 1 - s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) + if i not in s_alias: + num_nodes += 1 + s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) + else: + s.append(s[s_alias[i]]) else: if s_follow[i] < len(s): s.append(s[s_follow[i]]) @@ -311,15 +318,20 @@ class Solver: ############################# e = [] num_edges = 0 + map_edge_to_idx = {} for (idx, (i, j)) in enumerate(E): if len(s[i]) == 1: e.append(s[j]) elif len(s[j]) == 1: e.append(s[i]) else: - num_edges += 1 - e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) + if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx: + e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]]) + else: + num_edges += 1 + e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) assert len(e[idx]) == len(r[idx]) + map_edge_to_idx[(i, j)] = idx for element in s: assert len(element) > 0 # 2. Set initial value @@ -371,13 +383,12 @@ class Solver: # compute memory consumption with liveness set # ################################################# if memory_budget > 0: - for liveness_stage in liveness_set: - mem = 0 - for live_variable in liveness_stage.unique_live_vars: - if live_variable.node not in self.node_index_dict: - continue - node_index = self.node_index_dict[live_variable.node] - mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) + mem = 0 + for node in liveness_set: + if node not in self.node_index_dict: + continue + node_index = self.node_index_dict[node] + mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) prob += mem <= memory_budget # (d). specified by `cat="Binary"` diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 40741daca..59ead1ca8 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import ( ) from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec +from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks from colossalai.device.device_mesh import DeviceMesh from ..options import DataloaderOption, SolverOptions @@ -42,6 +43,7 @@ class StrategiesConstructor: self.strategy_map = {} self.solver_options = solver_options self.no_strategy_nodes = [] + self.alias_set = None def remove_duplicated_strategy(self, strategies_vector): ''' @@ -59,6 +61,22 @@ class StrategiesConstructor: for strategy in remove_list: strategies_vector.remove(strategy) + def generate_alias_set(self): + + node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies] + common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10) + + repeat_block_nums = len(common_blocks) + alias_set = {} + + if repeat_block_nums == 0: + return alias_set + + for index, common_node in enumerate(common_blocks[0]): + for i in range(1, repeat_block_nums): + alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node) + return alias_set + def build_strategies_and_cost(self): """ This method is to build the strategy vector for each node in the computation graph. @@ -175,3 +193,6 @@ class StrategiesConstructor: self.leaf_strategies.remove(node.strategies_vector) if node in self.strategy_map: self.strategy_map.pop(node) + + alias_set = self.generate_alias_set() + self.alias_set = alias_set diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index a6be1928b..4adb4fbaf 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2 BATCH_SIZE = 1 SEQ_LENGTH = 32 -HIDDEN_DIM = 768 +HIDDEN_DIM = 384 @run_on_environment_flag(name='AUTO_PARALLEL') @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): - config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM) + config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: model = model_cls(intermediate_size=4 * config.hidden_size, config=config) else: @@ -54,15 +54,13 @@ def test_self_attention_block(model_cls): gm = GraphModule(model, graph, model.__class__.__name__) print(gm.graph) gm.recompile() - graph_analyser = GraphAnalyser(gm) - liveness_list = graph_analyser.liveness_analysis() solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph.simplify_graph() - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1) + solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1) ret = solver.call_solver_serialized_args() strategies_list = solver.last_s_val nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py index 14c8cb296..0cdfdbc9d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph -from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser from colossalai.auto_parallel.tensor_shard.solver.solver import Solver from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.tracer.tracer import ColoTracer @@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, # solution construction cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph.simplify_graph() - graph_analyser = GraphAnalyser(gm) - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False) + solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False) ret = solver.call_solver_serialized_args() solution = list(ret[0]) gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index 6f64acd52..bbfc3e1fc 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -51,15 +51,14 @@ def test_cost_graph(): # return fc gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - graph_analyser = GraphAnalyser(gm) - liveness_list = graph_analyser.liveness_analysis() + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph.simplify_graph() - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser) + solver = Solver(gm.graph, strategies_constructor, cost_graph) ret = solver.call_solver_serialized_args() print(ret[0])