From 400f63012eb288b849253efd438622d6898f4233 Mon Sep 17 00:00:00 2001
From: Ziyue Jiang <ziyue.jiang97@gmail.com>
Date: Tue, 7 Mar 2023 10:34:31 +0800
Subject: [PATCH] [pipeline] Add Simplified Alpa DP Partition (#2507)

* add alpa dp split

* add alpa dp split

* use fwd+bwd instead of fwd only

---------

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
---
 .../fx/passes/adding_split_node_pass.py       | 161 ++++++++++++++++++
 colossalai/fx/passes/meta_info_prop.py        |   1 +
 colossalai/pipeline/rpc/_pipeline_base.py     |   3 +-
 .../pipeline_parallel/train_gpt_pp.py         |  47 +++--
 4 files changed, 197 insertions(+), 15 deletions(-)

diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index 0499769d8..2c7b842b5 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -1,4 +1,6 @@
+import numpy as np
 import torch
+import tqdm
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
 
@@ -9,6 +11,165 @@ def pipe_split():
     pass
 
 
+def block_split():
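+    # Sentinel call_function target that temporarily marks block boundaries
+    # while FLOPs are aggregated; it is removed again by remove_blocks().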
+    pass
+
+
+# Construct blocks such that each block satisfies (block_flops / total_flops) >= limit.
+def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
+    total_fwd_flop = 0
+    total_bwd_flop = 0
+    for node in gm.graph.nodes:
+        total_fwd_flop += node.fwd_flop
+        total_bwd_flop += node.bwd_flop
+
+    total_flop = total_fwd_flop + total_bwd_flop
+    per_block_flop = total_flop * limit
+    accumulate_fwd_flop = 0
+    accumulate_bwd_flop = 0
+    block_nodes = []
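+    # Walk the graph in order and close a block whenever the accumulated
+    # fwd+bwd FLOPs reach the per-block budget; trailing nodes that never
+    # reach the budget are not followed by a block_split node.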
+    for node in gm.graph.nodes:
+        if 'block_split' in node.name:
+            continue
+        accumulate_fwd_flop += node.fwd_flop
+        accumulate_bwd_flop += node.bwd_flop
+        if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
+            with gm.graph.inserting_after(node):
+                block_node = gm.graph.create_node('call_function', block_split)
+                setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
+                setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
+            accumulate_fwd_flop = 0
+            accumulate_bwd_flop = 0
+            block_nodes.append(block_node)
+
+    return block_nodes
+
+
+def remove_blocks(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        if (node.op, node.target) == ('call_function', block_split):
+            gm.graph.erase_node(node)
+
+
+def get_compute_costs(node_list):
+    num_nodes = len(node_list)
+    all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
+
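+    # all_compute_cost[start, end] holds the total fwd+bwd FLOPs of the
+    # contiguous span node_list[start..end] (inclusive).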
+    for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
+        for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
+            selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
+            all_compute_cost[start, end] = sum(selected_flops)
+
+    return all_compute_cost
+
+
+def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_costs, max_compute_cost):
+    """The core implementation of the DP algorithm."""
+    # Adapted from the Alpa DP formulation.
+    # For f, node IDs start from 0.
+    # f[number of stages,
+    #   node id that is currently being considered]
+
+    # record the time cost (currently assessed by fwd+bwd FLOPs)
+    f = np.full((num_stages + 1, num_nodes + 1), np.inf, dtype=np.float32)
+
+    # record max stage compute cost among all stages in this partition.
+    f_stage_max = np.full((num_stages + 1, num_nodes + 1), 0.0, dtype=np.float32)
+    # record start node index for next stage in this partition
+    f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
+    f[0, num_nodes] = 0
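+    # DP recurrence: f[s, i] = min over k > i of f[s - 1, k] + compute_costs[i, k - 1],
+    # subject to compute_costs[i, k - 1] <= max_compute_cost, i.e. nodes [i, k)
+    # form the current stage and the remaining s - 1 stages cover [k, num_nodes).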
+    for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False):    # pylint: disable=too-many-nested-blocks
+        for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
+            for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
+                stage_cost = compute_costs[i, k - 1]
+                new_cost = f[s - 1, k] + stage_cost
+                if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
+                    f[s, i] = new_cost
+                    f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
+                    f_argmin[s, i] = k
+
+    best_total_cost = f[num_stages, 0]
+    if np.isinf(best_total_cost):
+        return np.inf, None
+
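+    # GPipe latency model: one micro-batch traverses every stage once, and the
+    # remaining (num_microbatches - 1) micro-batches are bounded by the slowest stage.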
+    total_cost = f[num_stages, 0] + (num_microbatches - 1) * f_stage_max[num_stages, 0]
+
+    current_s = num_stages
+    current_node = 0
+
+    res = []
+    while current_s > 0 and current_node < num_nodes:
+        next_start_node = f_argmin[current_s, current_node]
+        res.append((current_node, next_start_node))
+        current_s -= 1
+        current_node = next_start_node
+
+    return total_cost, res
+
+
+def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatches: int):
+    # For simplicity, the memory cost profiling from Alpa's design is ignored here.
+    max_compute_costs = np.sort(np.unique(compute_costs))
+    best_cost = np.inf
+    best_solution = None
+    last_max_compute_cost = 0.0
+    gap = 1e6    # temporary magic number, unit: flops
+
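+    # Enumerate candidate per-stage caps in ascending order; re-run the DP for each
+    # cap and stop once max_compute_cost * num_microbatches can no longer beat the
+    # best cost found so far.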
+    for max_compute_cost in tqdm.tqdm(max_compute_costs):
+        # Pruning to reduce search space.
+        if max_compute_cost * num_microbatches >= best_cost:
+            break
+        if max_compute_cost - last_max_compute_cost < gap:
+            continue
+
+        cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
+                                                max_compute_cost)
+
+        if cost < best_cost:
+            best_cost = cost
+            best_solution = solution
+        last_max_compute_cost = max_compute_cost
+    return best_cost, best_solution
+
+
+# Automatic DP partition based on Alpa, adapted to the GPipe scheduler.
+# mode:
+#   'node': each fx node is a partition unit
+#   'block': multiple fx nodes are grouped into a block as the partition unit
+def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
+    assert mode in ['node', 'block']
+
+    # nodes or blocks that serve as the partition units.
+    if mode == 'node':
+        node_list = list(gm.graph.nodes)
+    elif mode == 'block':
+        node_list = construct_blocks(gm, limit=block_limit)
+
+    compute_costs = get_compute_costs(node_list)
+
+    best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
+
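+    # best_solution lists (stage_start, next_stage_start) index pairs; insert a
+    # pipe_split marker right before each next stage's first node, using at most
+    # pp_size - 1 split points.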
+    for (_, next_start_node) in best_solution:
+        if pp_size <= 1:
+            break
+        node = node_list[next_start_node]
+        with gm.graph.inserting_before(node):
+            gm.graph.create_node('call_function', pipe_split)
+        pp_size -= 1
+
+    # remove the temporary block_split nodes
+    if mode == 'block':
+        remove_blocks(gm)
+
+    gm.recompile()
+    return gm
+
+
 def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     """
     In avgcompute_split_pass, we split module by the fwd flops.
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 281cae41f..c2394a13c 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -114,6 +114,7 @@ class MetaInfoProp(torch.fx.Interpreter):
         # TODO: the attribute node_size should be removed in the future
         setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
         setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
+        setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
         n.meta['type'] = type(result)
 
         # retain the autograd graph
diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py
index 1edc1ac70..2d7e25c82 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -1115,7 +1115,8 @@ class PipelineEngineBase(ABC, nn.Module):
         # let each worker know global worker rref (include itself)
         sync_futs = []
         for pp_rank in self.pp_rank_to_worker_rref:
-            fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
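+            # a timeout of 0 means this RPC call never times out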
+            fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs(
+                self.pp_rank_to_worker_rref)
             sync_futs.append(fut)
 
         for fut in sync_futs:
diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
index c3451c18d..ad69888b8 100644
--- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
+++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
@@ -8,11 +8,16 @@ from torch import nn
 from tqdm import tqdm
 
 from colossalai.fx import ColoTracer
-from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass
+from colossalai.fx.passes.adding_split_node_pass import (
+    avgnode_split_pass,
+    gpipe_dp_split_pass,
+    split_with_split_nodes_pass,
+)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.pipeline.middleware.adaptor import get_fx_topology
-from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
 from colossalai.pipeline.rpc.utils import rpc_run
 
 
@@ -55,13 +60,25 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
 
 
-def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
+# Create an annotated model that is marked with where it should be split.
+def get_annotated_model(model, data_kwargs, num_stages, num_microbatches):
     tracer = ColoTracer()
     meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
-    annotated_model = avgnode_split_pass(gm, stage_num)
 
+    interp_meta_args = tuple(v.to('meta') for v in data_kwargs.values())
+    interp = MetaInfoProp(gm)
+    interp.run(*interp_meta_args)
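+    # MetaInfoProp annotates each fx node with fwd_flop/bwd_flop metadata,
+    # which gpipe_dp_split_pass below uses as its cost model.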
+
+    # Alternative: annotated_model = avgnode_split_pass(gm, num_stages)
+    annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01)
+
+    return annotated_model
+
+
+def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, num_microbatches):
+    annotated_model = get_annotated_model(model, data_kwargs, num_stages, num_microbatches)
     top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
     topo = get_fx_topology(top_module)
     for submodule in split_submodules:
@@ -70,8 +87,8 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
     return split_submodules[pp_rank + 1]
 
 
-def partition(model, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
-    module = create_partition_module(pp_rank, stage_num, model, data_kwargs)
+def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int):
+    module = create_partition_module(pp_rank, stage_num, model, data_kwargs, num_microbatches)
     return module
 
 
@@ -103,17 +120,19 @@ def run_master(args):
     warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}
 
     # create model
+    logger.info('start model_builder')
     model = model_builder(model_type)(checkpoint=False)
+    logger.info('end model_builder')
 
-    # set 1f1b pipeline engine
+    # set fill-drain pipeline engine
-    pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs),
-                                       stage_num=stage_num,
-                                       num_microbatches=num_microbatches,
-                                       device=device,
-                                       chunk=1,
-                                       criterion=criterion,
-                                       metric=None,
-                                       checkpoint=False)
+    pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches),
+                                        stage_num=stage_num,
+                                        num_microbatches=num_microbatches,
+                                        device=device,
+                                        chunk=1,
+                                        criterion=criterion,
+                                        metric=None,
+                                        checkpoint=False)
 
     partition_numels = pp_engine.remote_numels()
     for rank, numel in partition_numels.items():