[fx] tested the complete workflow for auto-parallel (#1336)

* [fx] tested the complete workflow for auto-parallel

* polish code

* polish code

* polish code
Frank Lee 2022-07-20 10:45:17 +08:00 committed by GitHub
parent 4631fef8a0
commit 2cc1175c76
4 changed files with 187 additions and 106 deletions

@@ -1,9 +1,16 @@
import torch
import torch.nn as nn
import operator
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import shard
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
]
ELEMENTWISE_MODULE_OP = [
torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d
]
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout,
torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d,
torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d,
torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d
]
def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
"""weight_split
@@ -21,6 +28,8 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
else:
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
return weight
def column_shard_linear_pass(gm: torch.fx.GraphModule):
# Split all the linear modules with column sharding. Currently for testing only.
mod_graph = gm.graph
@@ -48,43 +57,95 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
gm.recompile()
return gm
def transform_mlp_pass(gm: torch.fx.GraphModule):
def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
"""
This IR pass detects a transformer-MLP-like structure and annotates its linear layers with column and row sharding.
"""
# TODO: handle special cases such as x = linear(x) + linear(x)
mod_graph = gm.graph
col_shard = True
element_op = []
all_linear_name = []
linear_name = []
# Get the names of element-wise modules (e.g. torch.nn.ReLU)
# Get the names of all linear modules and of repeated linear modules
for name, func in gm.named_children():
if not isinstance(func, torch.nn.Linear):
for i in ELEMENTWISE_MODULE_OP:
if isinstance(func, i):
element_op.append(name)
break
else:
if name in all_linear_name:
if name in linear_name:
linear_name.remove(name)
else:
all_linear_name.append(name)
linear_name.append(name)
# If a linear module is called multiple times, set its dist spec to column shard
# If the module or the function/method is element-wise, col_shard stays unchanged
for node in mod_graph.nodes:
if node.target in linear_name:
target_module = node.graph.owning_module.get_submodule(node.target)
dim = 0 if col_shard else -1
target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
col_shard = not col_shard
elif node.target in all_linear_name:
target_module = node.graph.owning_module.get_submodule(node.target)
dim = 0 if col_shard else -1
target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
col_shard = not col_shard
else:
if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
col_shard = True
gm.recompile()
return gm
graph = graph_module.graph
world_size = process_group.world_size()
def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
# traverse the graph to look for consecutive linear layers
is_linear_module = False
if node.op == 'call_module':
# look for the linear layer
module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear):
is_linear_module = True
if start_tracking:
# when start_tracking = True
# it means the first linear has been found and the current module
# is the second linear
# set the current linear module to be row-sharded
annotation_record['row'] = module
for shard_type, module in annotation_record.items():
# add row sharding spec
if shard_type == 'row':
dist_spec = shard(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', dist_spec)
setattr(module.weight, 'comp_spec', comp_spec)
elif shard_type == 'col':
weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', weight_dist_spec)
setattr(module.weight, 'comp_spec', weight_comp_spec)
if module.bias is not None:
bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
setattr(module.bias, 'pg', process_group)
setattr(module.bias, 'dist_spec', bias_dist_spec)
setattr(module.bias, 'comp_spec', bias_comp_spec)
start_tracking = False
annotation_record.clear()
else:
# when start_tracking = False
# it means the current layer is the first linear
# set the linear layer to be col-sharded
start_tracking = True
annotation_record['col'] = module
if start_tracking and not is_linear_module:
# check against the whitelist
# if a non-element-wise op is found, reset the tracking
if node.op == 'call_module':
module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False
elif node.op == 'call_function' or node.op == 'call_method':
if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False
elif len(node.users.keys()) > 1:
start_tracking = False
if not start_tracking:
annotation_record.clear()
# stop tracking consecutive linears when a branch is found
# e.g.
# out1 = self.linear1(x)
# out2 = self.linear2(x)
# return out1+out2
next_nodes = list(node.users.keys())
if len(next_nodes) > 1:
start_tracking = False
annotation_record.clear()
# continue the traversal into the user nodes
for node in next_nodes:
_traverse_and_annotate(node, start_tracking, annotation_record, world_size)
placeholder_node = list(graph.nodes)[0]
annotate_record = {}
_traverse_and_annotate(placeholder_node, False, annotate_record, world_size)
return graph_module
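For reference, a minimal sketch of how the new pass might be driven, using a toy TwoLayerMLP module invented for illustration; ProcessGroup() presupposes that the distributed environment has already been initialized (e.g. via colossalai.launch, as the new test below does):

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer
from colossalai.tensor import ProcessGroup
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass

class TwoLayerMLP(nn.Module):
    # assumed toy module: two consecutive linears separated by an element-wise op
    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

model = TwoLayerMLP(16)
graph = ColoTracer().trace(model)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
# the pass annotates linear1 for column sharding and linear2 for row sharding
annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup())
annotated_gm.recompile()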

@@ -175,7 +175,7 @@ class LazyInitContext():
self._unpatch_nn_init_funcs()
self._unpatch_torch_tensor_funcs()
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
"""
Initialize the weights of the meta-tensor model.
@@ -205,6 +205,7 @@ class LazyInitContext():
# get sharding spec
dist_spec = getattr(tensor, 'dist_spec', None)
pg = getattr(tensor, 'pg', None)
comp_spec = getattr(tensor, 'comp_spec', None)
# convert the tensor from meta to materialized one
if tensor.is_meta:
@@ -224,14 +225,15 @@
else:
tensor = ColoTensor.from_torch_tensor(tensor)
# apply sharding
if dist_spec:
tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
# override the original tensor
with torch.no_grad():
setattr(module, name, tensor)
# apply sharding
if dist_spec:
tensor.process_group = pg
tensor.set_tensor_spec(dist_spec, comp_spec)
_init_recursively(model)
return model
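For context, lazy_init_parameters now expects each annotated parameter to carry pg, dist_spec and comp_spec attributes (exactly what transformer_mlp_pass attaches) and applies them through process_group and set_tensor_spec after re-registering the materialized tensor on the module. A minimal sketch of that contract, with a toy nn.Linear assumed for illustration and a column-sharding annotation copied from the pass above:

import torch.nn as nn
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import shard
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
from colossalai.utils.model.lazy_init_context import LazyInitContext

# build a meta-tensor model inside the lazy-init context
with LazyInitContext() as ctx:
    model = nn.Linear(16, 16)    # toy module, assumed for illustration

# annotate the weight the same way the pass does for a column-sharded linear
# (assumes the distributed environment has been initialized, e.g. via colossalai.launch)
pg = ProcessGroup()
setattr(model.weight, 'pg', pg)
setattr(model.weight, 'dist_spec', shard(dims=[0], num_partitions=[pg.world_size()]))
comp_spec = ComputeSpec(ComputePattern.TP1D)
comp_spec.output_replicate = False
setattr(model.weight, 'comp_spec', comp_spec)

# materialization reads the attributes above and calls
# tensor.process_group = pg and tensor.set_tensor_spec(dist_spec, comp_spec)
ctx.lazy_init_parameters(model)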

@@ -0,0 +1,77 @@
import colossalai
import torch
import torch.nn as nn
import pytest
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
from colossalai.fx import ColoTracer
from colossalai.utils.model.lazy_init_context import LazyInitContext
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup
class MLP(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.linear1 = torch.nn.Linear(dim, dim)
self.linear2 = torch.nn.Linear(dim, dim)
self.dropout = torch.nn.Dropout(0)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.linear1(x)
x = self.dropout(x)
x = self.relu(x)
x = self.linear2(x)
return x
def run_workflow(world_size):
# initialization
with LazyInitContext() as ctx:
model = MLP(16)
# tracing
tracer = ColoTracer()
graph = tracer.trace(model)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
# annotate
annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup())
annotated_gm.recompile()
# materialization and sharding
ctx.lazy_init_parameters(annotated_gm)
# check sharding
assert list(model.linear1.weight.shape) == [16 // world_size, 16]
assert list(model.linear1.bias.shape) == [16 // world_size]
assert list(model.linear2.weight.shape) == [16, 16 // world_size]
# run a forward pass to make sure the IR transform produces the same results
# as ColoTensor would normally
data = torch.rand(4, 16)
non_fx_out = model(data)
fx_out = annotated_gm(data)
assert torch.equal(non_fx_out, fx_out)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_workflow(world_size)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_complete_workflow(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_complete_workflow(2)
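As a usage note, the test declares the dist marker and the nccl backend, so it needs GPUs; it can be selected through pytest's marker filter (the file path below is hypothetical, shown only to illustrate the invocation), or launched directly with Python via the __main__ guard above, which spawns two processes:

pytest path/to/test_complete_workflow.py -m dist -v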

@@ -1,59 +0,0 @@
import torch
import torch.nn as nn
import pytest
import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass
CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
class MLP(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.linear1 = torch.nn.Linear(dim, dim)
self.linear2 = torch.nn.Linear(dim, dim)
self.linear3 = torch.nn.Linear(dim, dim)
self.linear4 = torch.nn.Linear(dim, dim)
self.dropout = torch.nn.Dropout()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.relu(self.linear1(x))
x = self.dropout(self.relu(self.linear2(x)))
x = self.linear3(x)
x = torch.nn.functional.relu(self.linear4(x))
return x
def test_out_acc():
model = MLP(16).cuda()
model.eval()
input_tensor = torch.rand(2, 16).cuda()
output = model(input_tensor)
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
splitted_gm = transform_mlp_pass(gm)
new_output = splitted_gm(input_tensor)
assert output.equal(new_output)
def test_linear_acc():
input_tensor = torch.rand(2, 16).cuda()
model = MLP(16).cuda()
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
splitted_gm = transform_mlp_pass(gm)
col_shard = True
for node in splitted_gm.graph.nodes:
if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), torch.nn.Linear):
target_module = node.graph.owning_module.get_submodule(node.target)
dim = 0 if col_shard else -1
assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs")
col_shard = not col_shard
if __name__ == "__main__":
torch.manual_seed(1)
torch.cuda.manual_seed(1)
# colossalai.launch_from_torch(config=CONFIG)
test_out_acc()
test_linear_acc()