From 189946c5c435a0c814331975021e035dc8509a46 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 6 Jul 2022 13:48:11 +0800 Subject: [PATCH] [fx]add uniform policy (#1208) * [CLI] add CLI launcher * Revert "[CLI] add CLI launcher" This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c. * [fx]add uniform policy --- .../fx/passes/adding_split_node_pass.py | 29 +++++++++++ tests/test_fx/test_pipeline_passes.py | 48 +++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 tests/test_fx/test_pipeline_passes.py diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py index 4d34c2b56..91005fe6b 100644 --- a/colossalai/fx/passes/adding_split_node_pass.py +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -32,6 +32,35 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): return gm +def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): + mod_graph = gm.graph + valid_children_size = 0 + valid_children = [] + for module in mod_graph.owning_module.children(): + valid_children_size += 1 + valid_children.append(module) + + if valid_children_size < pp_size: + # If valid children is not enough to shard, we will use balanced policy instead of uniform policy. + return balanced_split_pass(gm, pp_size) + layers_per_partition = valid_children_size // pp_size + accumulate_layer_amount = 0 + for node in mod_graph.nodes: + if pp_size <= 1: + break + if node.op == "call_module": + target_module = node.graph.owning_module.get_submodule(node.target) + if target_module in valid_children: + accumulate_layer_amount += 1 + if accumulate_layer_amount == layers_per_partition: + accumulate_layer_amount = 0 + pp_size -= 1 + with mod_graph.inserting_after(node): + split_node = mod_graph.create_node('call_function', pipe_split) + gm.recompile() + return gm + + def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule): part_idx = 0 diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py new file mode 100644 index 000000000..228fcb880 --- /dev/null +++ b/tests/test_fx/test_pipeline_passes.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import colossalai +import colossalai.nn as col_nn +from torch.fx import symbolic_trace +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \ + uniform_split_pass + +MODEL_DIM = 16 +BATCH_SIZE = 8 +PIPELINE_SIZE = 2 + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + return x + + +def pipeline_pass_test_helper(model, data, pass_func): + origin_output = model(data) + symbolic_traced = symbolic_trace(model) + annotated_model = pass_func(symbolic_traced, PIPELINE_SIZE) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + output = split_model(data) + assert output.equal(origin_output) + + +def test_pipeline_passes(): + model = MLP(MODEL_DIM) + data = torch.rand(BATCH_SIZE, MODEL_DIM) + pipeline_pass_test_helper(model, data, balanced_split_pass) + pipeline_pass_test_helper(model, data, uniform_split_pass) + + +if __name__ == '__main__': + test_pipeline_passes()