[fx]add autoparallel passes (#1121)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* feature/add autoparallel passes
YuliangLiu0306 2022-06-15 16:36:46 +08:00 committed by GitHub
parent e127b4375b
commit fcf55777dd
4 changed files with 239 additions and 0 deletions


@@ -0,0 +1,2 @@
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass


@@ -0,0 +1,54 @@
import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from torch.fx.passes.split_module import split_module


def pipe_split():
    # no-op marker; nodes calling this function mark pipeline-stage boundaries
    pass


def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    # insert pipe_split markers so that each of the pp_size partitions holds
    # roughly the same number of parameters
    mod_graph = gm.graph
    total_param_amount = 0
    for param in mod_graph.owning_module.parameters():
        total_param_amount += param.numel()
    params_per_partition = total_param_amount // pp_size
    accumulate_param_amount = 0
    for node in mod_graph.nodes:
        if pp_size <= 1:
            break
        if node.op == "call_module":
            target_module = node.graph.owning_module.get_submodule(node.target)
            for param in target_module.parameters():
                accumulate_param_amount += param.numel()
            if accumulate_param_amount >= params_per_partition:
                accumulate_param_amount = 0
                pp_size -= 1
                with mod_graph.inserting_after(node):
                    split_node = mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm


def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
    # split the annotated graph at every pipe_split marker, then strip the markers
    # from the resulting submodules
    part_idx = 0

    def split_callback(n: torch.fx.Node):
        nonlocal part_idx
        if (n.op, n.target) == ('call_function', pipe_split):
            part_idx += 1
        return part_idx

    split_mod = split_module(annotated_gm, None, split_callback)
    split_submodules = []
    for name, submodule in split_mod.named_modules():
        if isinstance(submodule, torch.fx.GraphModule):
            for node in submodule.graph.nodes:
                if (node.op, node.target) == ('call_function', pipe_split):
                    submodule.graph.erase_node(node)
            submodule.recompile()
            split_submodules.append(submodule)
    return split_mod, split_submodules
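
For orientation, here is a minimal usage sketch of the two passes above (illustrative only, not part of this diff; the ToyNet model and pp_size=2 are assumptions): trace a model, annotate it with balanced pipe_split markers, then split it into stage submodules.

# usage sketch (not part of this diff): ToyNet and pp_size=2 are illustrative
import torch
from torch.fx import symbolic_trace

from colossalai.fx.passes import balanced_split_pass, split_with_split_nodes_pass


class ToyNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(*[torch.nn.Linear(16, 16) for _ in range(4)])

    def forward(self, x):
        return self.layers(x)


gm = symbolic_trace(ToyNet())
annotated_gm = balanced_split_pass(gm, pp_size=2)    # inserts pipe_split markers
split_gm, stages = split_with_split_nodes_pass(annotated_gm)
print(split_gm.graph)    # the top-level graph now calls one submodule per pipeline stage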


@@ -0,0 +1,120 @@
import torch
from torch.fx.node import map_arg
from torch.fx.node import Node
from torch.fx.passes.split_module import split_module

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


def all_gather_function(input_):
    # gather the last-dim shards from all 1D tensor-parallel ranks and concatenate them
    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    group = gpc.get_group(ParallelMode.PARALLEL_1D)
    torch.distributed.all_gather(tensor_list, input_, group=group)
    output = torch.cat(tensor_list, dim=-1).contiguous()
    return output


def all_reduce_function(input_):
    # sum the partial results across all 1D tensor-parallel ranks (in place)
    if gpc.get_world_size(ParallelMode.PARALLEL_1D) == 1:
        return input_
    torch.distributed.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_1D))
    return input_


def weight_split(weight, dim):
    # TODO: this function will be refactored to use ColoTensor dist_spec once a stable reshaper feature is ready.
    # evenly split `weight` along `dim` and return the shard owned by the local rank
    num_partition = gpc.get_world_size(ParallelMode.TENSOR)
    shape = weight.shape
    length = shape[dim] // num_partition
    sharded_weight_list = []
    for i in range(num_partition):
        sharded_weight_list.append(weight.narrow(dim, i * length, length))
    return sharded_weight_list[gpc.get_local_rank(ParallelMode.PARALLEL_1D)]


def replace_all_uses_except_replaced(node, replace_node):
    """
    Replace all uses of ``node`` in the Graph with the Node ``replace_node``,
    except when the user of ``node`` is ``replace_node`` itself.

    Args:
        replace_node (Node): The node to replace all uses of ``node`` with.

    Returns:
        The list of Nodes on which this change was made.
    """
    to_process = list(node.users)
    for use_node in to_process:
        if use_node == replace_node:
            continue

        def may_replace_node(n):
            if n == node:
                return replace_node
            else:
                return n

        new_args = map_arg(use_node.args, may_replace_node)
        new_kwargs = map_arg(use_node.kwargs, may_replace_node)
        use_node._args = new_args
        use_node._kwargs = new_kwargs
        # rebuild the user's input-node bookkeeping after swapping its args/kwargs
        for old_use in use_node._input_nodes.keys():
            old_use.users.pop(use_node)
        use_node._input_nodes = {}
        map_arg(use_node._args, lambda n: use_node._input_nodes.setdefault(n))
        map_arg(use_node._kwargs, lambda n: use_node._input_nodes.setdefault(n))
        for new_use in use_node._input_nodes.keys():
            new_use.users.setdefault(use_node)
    return to_process


def column_shard_linear_pass(gm: torch.fx.GraphModule):
    # shard every nn.Linear along the output dimension and all-gather its output
    mod_graph = gm.graph
    for node in mod_graph.nodes:
        if node.op == "call_module":
            target_module = node.graph.owning_module.get_submodule(node.target)
            if isinstance(target_module, torch.nn.Linear):
                target_module.weight.data = weight_split(target_module.weight.data, dim=0)
                if target_module.bias is not None:
                    target_module.bias.data = weight_split(target_module.bias.data, dim=0)

                # insert a communication node after the sharded linear node
                with mod_graph.inserting_after(node):
                    new_node = mod_graph.create_node('call_function', all_gather_function, args=(node,))
                replace_all_uses_except_replaced(node, new_node)
    gm.recompile()
    return gm


def row_shard_linear_pass(gm: torch.fx.GraphModule):
    # shard every nn.Linear along the input dimension, shard its input accordingly,
    # and all-reduce the partial outputs
    # NOTE: the bias (if any) is kept unsharded, so the following all-reduce sums it once per rank
    mod_graph = gm.graph
    for node in mod_graph.nodes:
        if node.op == "call_module":
            target_module = node.graph.owning_module.get_submodule(node.target)
            if isinstance(target_module, torch.nn.Linear):
                target_module.weight.data = weight_split(target_module.weight.data, dim=-1)

                # insert an input-sharding node before the sharded linear node
                with mod_graph.inserting_before(node):
                    input_node_list = list(node._input_nodes.keys())
                    assert len(input_node_list) == 1, 'linear forward must have and only have one input tensor.'
                    input_node = input_node_list[0]
                    new_input_node = mod_graph.create_node('call_function', weight_split, args=(input_node, -1))
                    replace_all_uses_except_replaced(input_node, new_input_node)

                # insert a communication node after the sharded linear node
                with mod_graph.inserting_after(node):
                    new_node = mod_graph.create_node('call_function', all_reduce_function, args=(node,))
                replace_all_uses_except_replaced(node, new_node)
    gm.recompile()
    return gm


# TODO: add an elementwise-op pass; then we can try a mixed column/row sharding strategy.
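
A short usage sketch for the 1D shard passes (illustrative, not part of this diff). It assumes a 1D tensor-parallel context has already been initialized on every rank via colossalai.launch with a config like the one in the test below; the Sequential model is made up, and bias is disabled because of the note in row_shard_linear_pass.

# usage sketch (not part of this diff); assumes colossalai.launch(...) with a
# parallel=dict(tensor=dict(mode='1d', size=2)) config has already run on each rank
import torch
from torch.fx import symbolic_trace

from colossalai.fx.passes import row_shard_linear_pass

model = torch.nn.Sequential(torch.nn.Linear(16, 16, bias=False)).cuda()
gm = symbolic_trace(model)

# every rank keeps one slice of the weight's input dimension; the pass also
# slices the activation at runtime and all-reduces the partial outputs
sharded_gm = row_shard_linear_pass(gm)
out = sharded_gm(torch.rand(2, 16).cuda())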


@@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import symbolic_trace

from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.fx.passes import column_shard_linear_pass


class MLP(torch.nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)
        self.linear4 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        return x


# 2-way 1D tensor parallelism
CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=2)))


def check_layer(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    input_tensor = torch.rand(2, 16).cuda()
    model = MLP(16).cuda()
    symbolic_traced = symbolic_trace(model)
    output = model(input_tensor)
    splitted_gm = column_shard_linear_pass(symbolic_traced)
    new_output = splitted_gm(input_tensor)

    # the column-sharded graph must reproduce the unsharded output
    assert output.equal(new_output)

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_1d():
    world_size = 2
    run_func = partial(check_layer, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_1d()
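
The test above only exercises column_shard_linear_pass. A hypothetical companion check for row_shard_linear_pass could follow the same structure (a sketch, not part of this diff; it reuses CONFIG, launch, and the mp.spawn pattern from test_1d, and uses an approximate comparison because the all-reduce changes the floating-point summation order):

# hypothetical companion check for row_shard_linear_pass (not part of this diff)
from colossalai.fx.passes import row_shard_linear_pass    # additional import needed


def check_row_layer(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    input_tensor = torch.rand(2, 16).cuda()
    # wrap the Linear in a Sequential so it appears as a call_module node when traced;
    # bias is disabled because the row pass leaves the bias unsharded before the all-reduce
    model = torch.nn.Sequential(torch.nn.Linear(16, 16, bias=False)).cuda()
    output = model(input_tensor)
    sharded_gm = row_shard_linear_pass(symbolic_trace(model))
    new_output = sharded_gm(input_tensor)
    assert torch.allclose(output, new_output, atol=1e-5)
    gpc.destroy()
    torch.cuda.empty_cache()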