diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index 4d34c2b56..91005fe6b 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -32,6 +32,35 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     return gm
 
 
+def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
+    mod_graph = gm.graph
+    valid_children_size = 0
+    valid_children = []
+    for module in mod_graph.owning_module.children():
+        valid_children_size += 1
+        valid_children.append(module)
+
+    if valid_children_size < pp_size:
+        # If there are not enough valid children to shard, fall back to the balanced policy instead of the uniform policy.
+        return balanced_split_pass(gm, pp_size)
+    layers_per_partition = valid_children_size // pp_size
+    accumulate_layer_amount = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        if node.op == "call_module":
+            target_module = node.graph.owning_module.get_submodule(node.target)
+            if target_module in valid_children:
+                accumulate_layer_amount += 1
+                if accumulate_layer_amount == layers_per_partition:
+                    accumulate_layer_amount = 0
+                    pp_size -= 1
+                    with mod_graph.inserting_after(node):
+                        split_node = mod_graph.create_node('call_function', pipe_split)
+    gm.recompile()
+    return gm
+
+
 def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
     part_idx = 0
 
diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py
new file mode 100644
index 000000000..228fcb880
--- /dev/null
+++ b/tests/test_fx/test_pipeline_passes.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+import colossalai
+import colossalai.nn as col_nn
+from torch.fx import symbolic_trace
+from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \
+    uniform_split_pass
+
+MODEL_DIM = 16
+BATCH_SIZE = 8
+PIPELINE_SIZE = 2
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(dim, dim)
+        self.linear2 = torch.nn.Linear(dim, dim)
+        self.linear3 = torch.nn.Linear(dim, dim)
+        self.linear4 = torch.nn.Linear(dim, dim)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.linear3(x)
+        x = self.linear4(x)
+        return x
+
+
+def pipeline_pass_test_helper(model, data, pass_func):
+    origin_output = model(data)
+    symbolic_traced = symbolic_trace(model)
+    annotated_model = pass_func(symbolic_traced, PIPELINE_SIZE)
+    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
+    output = split_model(data)
+    assert output.equal(origin_output)
+
+
+def test_pipeline_passes():
+    model = MLP(MODEL_DIM)
+    data = torch.rand(BATCH_SIZE, MODEL_DIM)
+    pipeline_pass_test_helper(model, data, balanced_split_pass)
+    pipeline_pass_test_helper(model, data, uniform_split_pass)
+
+
+if __name__ == '__main__':
+    test_pipeline_passes()
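
For reference, a minimal end-to-end sketch (not part of the diff) of how the new pass behaves, using the same 4-layer MLP as the test file; the partition layout described in the comments follows from the pass's logic, and the exact submodule names produced by torch.fx are illustrative:

    import torch
    from torch.fx import symbolic_trace
    from colossalai.fx.passes.adding_split_node_pass import uniform_split_pass, split_with_split_nodes_pass

    class MLP(torch.nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.linear1 = torch.nn.Linear(dim, dim)
            self.linear2 = torch.nn.Linear(dim, dim)
            self.linear3 = torch.nn.Linear(dim, dim)
            self.linear4 = torch.nn.Linear(dim, dim)

        def forward(self, x):
            return self.linear4(self.linear3(self.linear2(self.linear1(x))))

    model = MLP(16)
    gm = symbolic_trace(model)
    # With 4 child modules and pp_size=2, each partition gets 4 // 2 = 2
    # layers, so a single pipe_split node is inserted after linear2.
    annotated = uniform_split_pass(gm, 2)
    split_gm, parts = split_with_split_nodes_pass(annotated)

    x = torch.rand(8, 16)
    # Splitting only regroups the layers into stage submodules; the
    # forward result is unchanged.
    assert split_gm(x).equal(model(x))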