From e8acf55e8bd53b88f48e183ef360c8691a027e9a Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Fri, 15 Jul 2022 14:54:26 +0800
Subject: [PATCH] [fx] add balanced policy v2 (#1251)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [fx] add balanced policy v2

* add unittest
---
 .../fx/passes/adding_split_node_pass.py       | 34 ++++++++++++++++++-
 colossalai/fx/passes/meta_info_prop.py        | 19 ++++++++++-
 tests/test_fx/test_pipeline_passes.py         |  4 ++-
 3 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index 3a3e5ddbf..e2ea6ec70 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -10,7 +10,9 @@ def pipe_split():
 
 
 def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
-    # TODO(lyl): balanced policy V2, split module by node size(weight+bias+output)
+    """
+    In balanced_split_pass, we split the module by the size of its parameters (weights + bias).
+    """
     mod_graph = gm.graph
     total_param_amount = 0
     for param in mod_graph.owning_module.parameters():
@@ -39,6 +41,36 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     return gm
 
 
+def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
+    """
+    In balanced_split_pass_v2, we split the module by the size of its nodes (weights + bias + outputs).
+    """
+    mod_graph = gm.graph
+    # To use balanced_split_pass_v2, we need to run the meta_info_prop interpreter first.
+    # If the nodes don't have meta info, this pass falls back to the normal balanced split pass.
+    check_node = list(mod_graph.nodes)[0]
+    if 'tensor_meta' not in check_node.meta:
+        return balanced_split_pass(gm, pp_size)
+
+    total_element_size = 0
+    for node in mod_graph.nodes:
+        total_element_size += node.node_size
+
+    partition_size = total_element_size // pp_size
+    accumulate_node_size = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        accumulate_node_size += node.node_size
+        if accumulate_node_size >= partition_size:
+            accumulate_node_size = 0
+            pp_size -= 1
+            with mod_graph.inserting_after(node):
+                split_node = mod_graph.create_node('call_function', pipe_split)
+    gm.recompile()
+    return gm
+
+
 def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     mod_graph = gm.graph
     valid_children_size = 0
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 0eb7f32f4..4033cd72b 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -67,7 +67,6 @@ class MetaInfoProp(torch.fx.Interpreter):
 
     def run_node(self, n: Node) -> Any:
         result = super().run_node(n)
-
         found_tensor = False
 
         def extract_tensor_meta(obj):
@@ -83,7 +82,25 @@ class MetaInfoProp(torch.fx.Interpreter):
             n.meta['tensor_meta'] = meta
         else:
             n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
+        # counting the total size of node outputs
+        total_node_size = 0
+        if isinstance(n.meta['tensor_meta'], TensorMetadata):
+            total_node_size += n.meta['tensor_meta'].numel
+        else:
+            for element in n.meta['tensor_meta']:
+                assert isinstance(
+                    element, TensorMetadata
+                ), f"``n.meta['tensor_meta']`` should be either TensorMetadata or a tuple of TensorMetadata."
+                total_node_size += element.numel
+        # counting the total size of parameters
+        total_param_size = 0
+        if n.op == 'call_module':
+            target_module = n.graph.owning_module.get_submodule(n.target)
+            for param in target_module.parameters():
+                total_param_size += param.numel()
 
+        total_node_size += total_param_size
+        n.node_size = total_node_size
         n.meta['type'] = type(result)
         return result
 
diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py
index 54619d25c..4d9e63d0d 100644
--- a/tests/test_fx/test_pipeline_passes.py
+++ b/tests/test_fx/test_pipeline_passes.py
@@ -4,7 +4,8 @@ import colossalai
 import colossalai.nn as col_nn
 from torch.fx import symbolic_trace
 from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \
-                                                        uniform_split_pass
+                                                        uniform_split_pass, balanced_split_pass_v2
+
 import pytest
 
 MODEL_DIM = 16
@@ -43,6 +44,7 @@ def test_pipeline_passes():
     model = MLP(MODEL_DIM)
     data = torch.rand(BATCH_SIZE, MODEL_DIM)
     pipeline_pass_test_helper(model, data, balanced_split_pass)
+    pipeline_pass_test_helper(model, data, balanced_split_pass_v2)
     pipeline_pass_test_helper(model, data, uniform_split_pass)