From 400f63012eb288b849253efd438622d6898f4233 Mon Sep 17 00:00:00 2001 From: Ziyue Jiang <ziyue.jiang97@gmail.com> Date: Tue, 7 Mar 2023 10:34:31 +0800 Subject: [PATCH] [pipeline] Add Simplified Alpa DP Partition (#2507) * add alpa dp split * add alpa dp split * use fwd+bwd instead of fwd only --------- Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com> --- .../fx/passes/adding_split_node_pass.py | 161 ++++++++++++++++++ colossalai/fx/passes/meta_info_prop.py | 1 + colossalai/pipeline/rpc/_pipeline_base.py | 3 +- .../pipeline_parallel/train_gpt_pp.py | 47 +++-- 4 files changed, 197 insertions(+), 15 deletions(-) diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py index 0499769d8..2c7b842b5 100644 --- a/colossalai/fx/passes/adding_split_node_pass.py +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -1,4 +1,6 @@ +import numpy as np import torch +import tqdm from torch.fx import symbolic_trace from torch.fx.node import Node @@ -9,6 +11,165 @@ def pipe_split(): pass +def block_split(): + pass + + +# Construct blocks with the condition that (block_flops / total_flops) >= limit. 
+def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): + total_fwd_flop = 0 + total_bwd_flop = 0 + for node in gm.graph.nodes: + total_fwd_flop += node.fwd_flop + total_bwd_flop += node.bwd_flop + + total_flop = total_fwd_flop + total_bwd_flop + per_block_flop = total_flop * limit + accumulate_fwd_flop = 0 + accumulate_bwd_flop = 0 + block_nodes = [] + for node in gm.graph.nodes: + if 'block_split' in node.name: + continue + accumulate_fwd_flop += node.fwd_flop + accumulate_bwd_flop += node.bwd_flop + if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop: + with gm.graph.inserting_after(node): + block_node = gm.graph.create_node('call_function', block_split) + setattr(block_node, 'fwd_flop', accumulate_fwd_flop) + setattr(block_node, 'bwd_flop', accumulate_bwd_flop) + accumulate_fwd_flop = 0 + accumulate_bwd_flop = 0 + block_nodes.append(block_node) + + return block_nodes + + +def remove_blocks(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if (node.op, node.target) == ('call_function', block_split): + gm.graph.erase_node(node) + + +def get_compute_costs(node_list): + num_nodes = len(node_list) + all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64) + + for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0): + for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False): + selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)] + all_compute_cost[start, end] = sum(selected_flops) + + return all_compute_cost + + +def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_costs, max_compute_cost): + """The core implementation of the DP algorithm.""" + # Adapted from Alpa DP Formulation. 
+ # For f, node ID start from 0 + # f[number of stages, + # node id that is currently being considered] + + # record time cost(assess by fwd+bwd flop now) + f = np.full((num_stages + 1, num_nodes + 1), np.inf, dtype=np.float32) + + # record max stage compute cost among all stages in this partition. + f_stage_max = np.full((num_stages + 1, num_nodes + 1), 0.0, dtype=np.float32) + # record start node index for next stage in this partition + f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32) + f[0, num_nodes] = 0 + for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks + for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False): + for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False): + stage_cost = compute_costs[i, k - 1] + new_cost = f[s - 1, k] + stage_cost + if (stage_cost <= max_compute_cost and new_cost < f[s, i]): + f[s, i] = new_cost + f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost) + f_argmin[s, i] = k + + best_total_cost = f[num_stages, 0] + if np.isinf(best_total_cost): + return np.inf, None + + total_cost = f[num_stages, 0] + (num_microbatches - 1) * f_stage_max[num_stages, 0] + + current_s = num_stages + current_node = 0 + + res = [] + while current_s > 0 and current_node < num_nodes: + next_start_node = f_argmin[current_s, current_node] + res.append((current_node, next_start_node)) + current_s -= 1 + current_node = next_start_node + + return total_cost, res + + +def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatches: int): + # Ignore the memory cost profiling in Alpa's design for convenience. + max_compute_costs = np.sort(np.unique(compute_costs)) + best_cost = np.inf + best_solution = None + last_max_compute_cost = 0.0 + gap = 1e6 # temporary magic number, unit: flops + + for max_compute_cost in tqdm.tqdm(max_compute_costs): + # Pruning to reduce search space. 
+ if max_compute_cost * num_microbatches >= best_cost: + break + if max_compute_cost - last_max_compute_cost < gap: + continue + + cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs, + max_compute_cost) + + if cost < best_cost: + best_cost = cost + best_solution = solution + last_max_compute_cost = max_compute_cost + return best_cost, best_solution + + +# Auto DP partition based on Alpa. +# Adapted to the GPipe scheduler. +# mode: +# 'node': each fx_node is a partition unit +# 'block': several fx_nodes are grouped into one block +def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01): + assert mode in ['node', 'block'] + + # nodes or blocks will be used in partition. + node_list = [] + if mode == 'node': + for node in gm.graph.nodes: + node_list.append(node) + elif mode == 'block': + node_list = construct_blocks(gm, limit=block_limit) + else: + pass + + compute_costs = get_compute_costs(node_list) + + best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches) + + for (_, next_start_node) in best_solution: + if pp_size <= 1: + break + node = node_list[next_start_node] + with gm.graph.inserting_before(node): + split_node = gm.graph.create_node('call_function', pipe_split) + pp_size -= 1 + + # remove block node if possible + if mode == 'block': + remove_blocks(gm) + + gm.recompile() + return gm + + def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): """ In avgcompute_split_pass, we split module by the fwd flops. 
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 281cae41f..c2394a13c 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -114,6 +114,7 @@ class MetaInfoProp(torch.fx.Interpreter): # TODO: the attribute node_size should be removed in the future setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0))) setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0)) + setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0)) n.meta['type'] = type(result) # retain the autograd graph diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index 1edc1ac70..2d7e25c82 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -1115,7 +1115,8 @@ class PipelineEngineBase(ABC, nn.Module): # let each worker know global worker rref (include itself) sync_futs = [] for pp_rank in self.pp_rank_to_worker_rref: - fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref) + fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs( + self.pp_rank_to_worker_rref) sync_futs.append(fut) for fut in sync_futs: diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py index c3451c18d..ad69888b8 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -8,11 +8,16 @@ from torch import nn from tqdm import tqdm from colossalai.fx import ColoTracer -from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass +from colossalai.fx.passes.adding_split_node_pass import ( + avgnode_split_pass, + gpipe_dp_split_pass, + split_with_split_nodes_pass, +) +from 
colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam from colossalai.pipeline.middleware.adaptor import get_fx_topology -from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine +from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.pipeline.rpc.utils import rpc_run @@ -55,13 +60,25 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) -def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): +# Create an annotated model in which the split points are marked. +def get_annotated_model(model, data_kwargs, num_stages, num_microbatches): tracer = ColoTracer() meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - annotated_model = avgnode_split_pass(gm, stage_num) + interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()]) + interp = MetaInfoProp(gm) + interp.run(*interp_meta_args) + + #annotated_model = avgnode_split_pass(gm, num_stages) + annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01) + + return annotated_model + + +def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, num_microbatches): + annotated_model = get_annotated_model(model, data_kwargs, num_stages, num_microbatches) top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True) topo = get_fx_topology(top_module) for submodule in split_submodules: @@ -70,8 +87,8 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): return split_submodules[pp_rank + 1] -def partition(model, data_kwargs, pp_rank: int, chunk: int, stage_num: 
int): - module = create_partition_module(pp_rank, stage_num, model, data_kwargs) +def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int): + module = create_partition_module(pp_rank, stage_num, model, data_kwargs, num_microbatches) return module @@ -103,17 +120,19 @@ def run_master(args): warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask} # create model + logger.info(f'start model_builder') model = model_builder(model_type)(checkpoint=False) + logger.info(f'end model_builder') # set 1f1b pipeline engine - pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs), - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=1, - criterion=criterion, - metric=None, - checkpoint=False) + pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=1, + criterion=criterion, + metric=None, + checkpoint=False) partition_numels = pp_engine.remote_numels() for rank, numel in partition_numels.items():