From fcf55777ddf1688f0934700126a84f8bd0dcbabf Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 15 Jun 2022 16:36:46 +0800
Subject: [PATCH] [fx] add autoparallel passes (#1121)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* feature/add autoparallel passes
---
 colossalai/fx/passes/__init__.py              |   2 +
 .../fx/passes/adding_split_node_pass.py       |  54 ++++++++
 colossalai/fx/passes/shard_1d_pass.py         | 120 ++++++++++++++++++
 tests/test_fx/test_parallel_1d.py             |  63 +++++++++
 4 files changed, 239 insertions(+)
 create mode 100644 colossalai/fx/passes/__init__.py
 create mode 100644 colossalai/fx/passes/adding_split_node_pass.py
 create mode 100644 colossalai/fx/passes/shard_1d_pass.py
 create mode 100644 tests/test_fx/test_parallel_1d.py

diff --git a/colossalai/fx/passes/__init__.py b/colossalai/fx/passes/__init__.py
new file mode 100644
index 000000000..b1e95b876
--- /dev/null
+++ b/colossalai/fx/passes/__init__.py
@@ -0,0 +1,2 @@
+from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
+from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
\ No newline at end of file
diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
new file mode 100644
index 000000000..4d34c2b56
--- /dev/null
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -0,0 +1,54 @@
+import torch
+
+from torch.fx import symbolic_trace
+from torch.fx.node import Node
+from torch.fx.passes.split_module import split_module
+
+
+def pipe_split():
+    pass
+
+
+def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
+    mod_graph = gm.graph
+    total_param_amount = 0
+    for param in mod_graph.owning_module.parameters():
+        total_param_amount += param.numel()
+    params_per_partition = total_param_amount // pp_size
+    accumulate_param_amount = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        if node.op == "call_module":
+            target_module = node.graph.owning_module.get_submodule(node.target)
+            for param in target_module.parameters():
+                accumulate_param_amount += param.numel()
+            if accumulate_param_amount >= params_per_partition:
+                accumulate_param_amount = 0
+                pp_size -= 1
+                with mod_graph.inserting_after(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+    gm.recompile()
+    return gm
+
+
+def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
+    part_idx = 0
+
+    def split_callback(n: torch.fx.Node):
+        nonlocal part_idx
+        if (n.op, n.target) == ('call_function', pipe_split):
+            part_idx += 1
+        return part_idx
+
+    split_mod = split_module(annotated_gm, None, split_callback)
+    split_submodules = []
+    for name, submodule in split_mod.named_modules():
+        if isinstance(submodule, torch.fx.GraphModule):
+            for node in submodule.graph.nodes:
+                if (node.op, node.target) == ('call_function', pipe_split):
+                    submodule.graph.erase_node(node)
+            submodule.recompile()
+            split_submodules.append(submodule)
+
+    return split_mod, split_submodules
diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
new file mode 100644
index 000000000..b58544c02
--- /dev/null
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -0,0 +1,120 @@
+import torch
+from torch.fx.node import map_arg
+from torch.fx.node import Node
+from torch.fx.passes.split_module import split_module
+
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+
+
+def all_gather_function(input_):
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    tensor_list[rank] = input_
+    group = gpc.get_group(ParallelMode.PARALLEL_1D)
+    torch.distributed.all_gather(tensor_list, input_, group=group)
+    output = torch.cat(tensor_list, dim=-1).contiguous()
+    return output
+
+
+def all_reduce_function(input_):
+    if gpc.get_world_size(ParallelMode.PARALLEL_1D) == 1:
+        return input_
+    torch.distributed.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+    return input_
+
+
+def weight_split(weight, dim):
+    #TODO: this function will be refactored to use ColoTensor dist_spec once a stable reshaper feature is available.
+    num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+    shape = weight.shape
+    length = shape[dim] // num_partition
+    sharded_weight_list = []
+    for i in range(num_partition):
+        sharded_weight_list.append(weight.narrow(dim, i * length, length))
+    return sharded_weight_list[gpc.get_local_rank(ParallelMode.PARALLEL_1D)]
+
+
+def replace_all_uses_except_replaced(node, replace_node):
+    """
+    Replace all uses of ``node`` in the Graph with the Node ``replace_node``,
+    except when the user of ``node`` is ``replace_node`` itself.
+
+    Args:
+
+        replace_node (Node): The node to replace all uses of ``node`` with.
+
+    Returns:
+
+        The list of Nodes on which this change was made.
+    """
+    to_process = list(node.users)
+    for use_node in to_process:
+        if use_node == replace_node:
+            continue
+
+        def may_replace_node(n):
+            if n == node:
+                return replace_node
+            else:
+                return n
+
+        new_args = map_arg(use_node.args, may_replace_node)
+        new_kwargs = map_arg(use_node.kwargs, may_replace_node)
+        use_node._args = new_args
+        use_node._kwargs = new_kwargs
+        for old_use in use_node._input_nodes.keys():
+            old_use.users.pop(use_node)
+        use_node._input_nodes = {}
+        map_arg(use_node._args, lambda n: use_node._input_nodes.setdefault(n))
+        map_arg(use_node._kwargs, lambda n: use_node._input_nodes.setdefault(n))
+        for new_use in use_node._input_nodes.keys():
+            new_use.users.setdefault(use_node)
+    return to_process
+
+
+def column_shard_linear_pass(gm: torch.fx.GraphModule):
+    mod_graph = gm.graph
+    for node in mod_graph.nodes:
+        if node.op == "call_module":
+            target_module = node.graph.owning_module.get_submodule(node.target)
+            if isinstance(target_module, torch.nn.Linear):
+                target_module.weight.data = weight_split(target_module.weight.data, dim=0)
+                if target_module.bias is not None:
+                    target_module.bias.data = weight_split(target_module.bias.data, dim=0)
+
+                # insert communication node after the sharded linear node
+                with mod_graph.inserting_after(node):
+                    new_node = mod_graph.create_node('call_function', all_gather_function, args=(node,))
+                    replace_all_uses_except_replaced(node, new_node)
+    gm.recompile()
+    return gm
+
+
+def row_shard_linear_pass(gm: torch.fx.GraphModule):
+    mod_graph = gm.graph
+    for node in mod_graph.nodes:
+        if node.op == "call_module":
+            target_module = node.graph.owning_module.get_submodule(node.target)
+            if isinstance(target_module, torch.nn.Linear):
+                target_module.weight.data = weight_split(target_module.weight.data, dim=-1)
+
+                # insert input sharding node before the sharded linear node
+                with mod_graph.inserting_before(node):
+                    input_node_list = list(node._input_nodes.keys())
+                    assert len(input_node_list) == 1, 'linear forward must have exactly one input tensor.'
+                    input_node = input_node_list[0]
+                    new_input_node = mod_graph.create_node('call_function', weight_split, args=(input_node, -1))
+                    replace_all_uses_except_replaced(input_node, new_input_node)
+
+                # insert communication node after the sharded linear node
+                with mod_graph.inserting_after(node):
+                    new_node = mod_graph.create_node('call_function', all_reduce_function, args=(node,))
+                    replace_all_uses_except_replaced(node, new_node)
+    gm.recompile()
+    return gm
+
+
+#TODO: add an elementwise op process pass so that column and row sharding strategies can be mixed.
diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py
new file mode 100644
index 000000000..8963ba29c
--- /dev/null
+++ b/tests/test_fx/test_parallel_1d.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers
+from colossalai.initialize import launch
+from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use
+from torch.fx import symbolic_trace
+from colossalai.fx.passes import column_shard_linear_pass
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(dim, dim)
+        self.linear2 = torch.nn.Linear(dim, dim)
+        self.linear3 = torch.nn.Linear(dim, dim)
+        self.linear4 = torch.nn.Linear(dim, dim)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.linear3(x)
+        x = self.linear4(x)
+        return x
+
+
+CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=2)))
+
+
+def check_layer(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    input_tensor = torch.rand(2, 16).cuda()
+    model = MLP(16).cuda()
+    symbolic_traced = symbolic_trace(model)
+    output = model(input_tensor)
+    sharded_gm = column_shard_linear_pass(symbolic_traced)
+    new_output = sharded_gm(input_tensor)
+
+    assert output.equal(new_output)
+
+    gpc.destroy()
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_1d():
+    world_size = 2
+    run_func = partial(check_layer, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_1d()
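
Usage sketch (illustrative only, not part of the diff above): the two pipeline passes are meant to be chained, with balanced_split_pass annotating the traced graph and split_with_split_nodes_pass materialising one submodule per partition. The sketch runs in a single process, since neither pass touches the distributed context; the toy TwoLayer module and the tensor shapes are assumptions made for this example.

import torch
from torch.fx import symbolic_trace

from colossalai.fx.passes import balanced_split_pass, split_with_split_nodes_pass


class TwoLayer(torch.nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear2(self.linear1(x))


traced = symbolic_trace(TwoLayer(16))
# Insert `pipe_split` marker nodes so that each partition holds roughly the
# same number of parameters (two pipeline stages here).
annotated = balanced_split_pass(traced, pp_size=2)
# Split the annotated graph into one submodule per partition; the markers are
# erased from the resulting submodules.
split_mod, split_submodules = split_with_split_nodes_pass(annotated)
print(split_mod.graph)
# The split module should still compute the same function as the traced one.
x = torch.rand(4, 16)
assert torch.allclose(traced(x), split_mod(x))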
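
The arithmetic behind column_shard_linear_pass can be checked without any distributed setup. The sketch below emulates weight_split (narrow along the output dimension) and all_gather_function (concatenation along the last dimension) on a single process; it is a hand-written illustration, not code from the patch.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
linear = torch.nn.Linear(16, 16)
x = torch.rand(2, 16)

full_output = linear(x)

out_per_rank = linear.out_features // world_size
partial_outputs = []
for rank in range(world_size):
    # Emulate weight_split(..., dim=0): each "rank" keeps a contiguous row slice
    # of the weight and the matching slice of the bias.
    w_shard = linear.weight.narrow(0, rank * out_per_rank, out_per_rank)
    b_shard = linear.bias.narrow(0, rank * out_per_rank, out_per_rank)
    partial_outputs.append(F.linear(x, w_shard, b_shard))

# Emulate all_gather_function: concatenate the partial outputs along dim=-1.
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(full_output, gathered)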
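
A companion sketch for row_shard_linear_pass, again plain PyTorch with no distributed setup: the weight is narrowed along the input dimension, the activation is narrowed along its last dimension (what the inserted weight_split node does at runtime), and the partial products are summed, which is the role of the inserted all-reduce. Bias is left out of this sketch because the row pass above only re-assigns the weight; this is an illustration, not code from the patch.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
linear = torch.nn.Linear(16, 16, bias=False)
x = torch.rand(2, 16)

full_output = linear(x)

in_per_rank = linear.in_features // world_size
partial_outputs = []
for rank in range(world_size):
    # Emulate weight_split(weight, dim=-1) and the inserted weight_split on the input.
    w_shard = linear.weight.narrow(-1, rank * in_per_rank, in_per_rank)
    x_shard = x.narrow(-1, rank * in_per_rank, in_per_rank)
    partial_outputs.append(F.linear(x_shard, w_shard))

# Emulate all_reduce_function: sum the partial results across "ranks".
reduced = torch.stack(partial_outputs).sum(dim=0)
assert torch.allclose(full_output, reduced)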