mirror of https://github.com/hpcaitech/ColossalAI
[fx] tested the complete workflow for auto-parallel (#1336)
* [fx] tested the complete workflow for auto-parallel * polish code * polish code * polish codepull/1344/head
parent
4631fef8a0
commit
2cc1175c76
|
@ -1,9 +1,16 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import operator
|
||||
import colossalai
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.tensor.distspec import shard
|
||||
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
|
||||
|
||||
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
|
||||
ELEMENTWISE_FUNC_OP = [
|
||||
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
|
||||
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
|
||||
]
|
||||
|
||||
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d]
|
||||
ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d, torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d, torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d]
|
||||
|
||||
def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
|
||||
"""weight_split
|
||||
|
@ -21,6 +28,8 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: boo
|
|||
else:
|
||||
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
|
||||
return weight
|
||||
|
||||
|
||||
def column_shard_linear_pass(gm: torch.fx.GraphModule):
|
||||
# Split all the linear module with column shard. Currently for testing only.
|
||||
mod_graph = gm.graph
|
||||
|
@ -48,43 +57,95 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
|
|||
gm.recompile()
|
||||
return gm
|
||||
|
||||
def transform_mlp_pass(gm: torch.fx.GraphModule):
|
||||
|
||||
def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
|
||||
"""
|
||||
This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers.
|
||||
"""
|
||||
#TODO: Needs to handle special cases, like x = linear(x) + linear(x)
|
||||
mod_graph = gm.graph
|
||||
col_shard = True
|
||||
element_op = []
|
||||
all_linear_name = []
|
||||
linear_name = []
|
||||
# Get the name of element wise module(torch.nn.ReLU)
|
||||
# Get the name of all the linear modules and repeated linear modules
|
||||
for name, func in gm.named_children():
|
||||
if not isinstance(func, torch.nn.Linear):
|
||||
for i in ELEMENTWISE_MODULE_OP:
|
||||
if isinstance(func, i):
|
||||
element_op.append(name)
|
||||
break
|
||||
else:
|
||||
if name in all_linear_name:
|
||||
if name in linear_name:
|
||||
linear_name.remove(name)
|
||||
else:
|
||||
all_linear_name.append(name)
|
||||
linear_name.append(name)
|
||||
# If the linear modules is called multiple times, set the dist spec as col shard
|
||||
# If the module is element wise or the function/method is element wise, remains col_shard
|
||||
for node in mod_graph.nodes:
|
||||
if node.target in linear_name:
|
||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||
dim = 0 if col_shard else -1
|
||||
target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
|
||||
col_shard = not col_shard
|
||||
elif node.target in all_linear_name:
|
||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||
dim = 0 if col_shard else -1
|
||||
target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
|
||||
col_shard = not col_shard
|
||||
else:
|
||||
if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
|
||||
col_shard = True
|
||||
gm.recompile()
|
||||
return gm
|
||||
graph = graph_module.graph
|
||||
world_size = process_group.world_size()
|
||||
|
||||
def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
|
||||
# traverse the graph to look for consecutive linear layers
|
||||
is_linear_module = False
|
||||
|
||||
if node.op == 'call_module':
|
||||
# look for the linear layer
|
||||
module = node.graph.owning_module.get_submodule(node.target)
|
||||
if isinstance(module, nn.Linear):
|
||||
is_linear_module = True
|
||||
if start_tracking:
|
||||
# when start_tracking = True
|
||||
# it means the first linear has been found and the current module
|
||||
# is the second linear
|
||||
# set the current linear module to be row-sharded
|
||||
annotation_record['row'] = module
|
||||
|
||||
for shard_type, module in annotation_record.items():
|
||||
# add row sharding spec
|
||||
if shard_type == 'row':
|
||||
dist_spec = shard(dims=[-1], num_partitions=[world_size])
|
||||
comp_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
setattr(module.weight, 'pg', process_group)
|
||||
setattr(module.weight, 'dist_spec', dist_spec)
|
||||
setattr(module.weight, 'comp_spec', comp_spec)
|
||||
elif shard_type == 'col':
|
||||
weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
|
||||
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
weight_comp_spec.output_replicate = False
|
||||
setattr(module.weight, 'pg', process_group)
|
||||
setattr(module.weight, 'dist_spec', weight_dist_spec)
|
||||
setattr(module.weight, 'comp_spec', weight_comp_spec)
|
||||
|
||||
if module.bias is not None:
|
||||
bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
|
||||
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
bias_comp_spec.output_replicate = False
|
||||
setattr(module.bias, 'pg', process_group)
|
||||
setattr(module.bias, 'dist_spec', bias_dist_spec)
|
||||
setattr(module.bias, 'comp_spec', bias_comp_spec)
|
||||
start_tracking = False
|
||||
annotation_record.clear()
|
||||
else:
|
||||
# when start tracking = False
|
||||
# it means the current layer is the first linear
|
||||
# set the linear layer to be col-sharded
|
||||
start_tracking = True
|
||||
annotation_record['col'] = module
|
||||
|
||||
if start_tracking and not is_linear_module:
|
||||
# check against the white list
|
||||
# if non-element wise op is found, we reset the tracking
|
||||
if node.op == 'call_module':
|
||||
module = node.graph.owning_module.get_submodule(node.target)
|
||||
if module.__class__ not in ELEMENTWISE_MODULE_OP:
|
||||
start_tracking = False
|
||||
elif node.op == 'call_function' or node.op == 'call_method':
|
||||
if node.target not in ELEMENTWISE_FUNC_OP:
|
||||
start_tracking = False
|
||||
elif len(node.users.keys()) > 1:
|
||||
start_tracking = False
|
||||
|
||||
if not start_tracking:
|
||||
annotation_record.clear()
|
||||
|
||||
# stop tracking for consecutive linear when branch is found
|
||||
# e.g.
|
||||
# out1 = self.linear1(x)
|
||||
# out2 = self.linear2(x)
|
||||
# return out1+out2
|
||||
next_nodes = list(node.users.keys())
|
||||
if len(next_nodes) > 1:
|
||||
start_tracking = False
|
||||
annotation_record.clear()
|
||||
|
||||
# traverse
|
||||
for node in next_nodes:
|
||||
_traverse_and_annotate(node, start_tracking, annotation_record, world_size)
|
||||
|
||||
placeholder_node = list(graph.nodes)[0]
|
||||
annotate_record = {}
|
||||
_traverse_and_annotate(placeholder_node, False, annotate_record, world_size)
|
||||
|
||||
return graph_module
|
||||
|
|
|
@ -175,7 +175,7 @@ class LazyInitContext():
|
|||
self._unpatch_nn_init_funcs()
|
||||
self._unpatch_torch_tensor_funcs()
|
||||
|
||||
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
|
||||
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
|
||||
"""
|
||||
Initialize the weights of the meta-tensor model.
|
||||
|
||||
|
@ -205,6 +205,7 @@ class LazyInitContext():
|
|||
# get sharding spec
|
||||
dist_spec = getattr(tensor, 'dist_spec', None)
|
||||
pg = getattr(tensor, 'pg', None)
|
||||
comp_spec = getattr(tensor, 'comp_spec', None)
|
||||
|
||||
# convert the tensor from meta to materialized one
|
||||
if tensor.is_meta:
|
||||
|
@ -224,14 +225,15 @@ class LazyInitContext():
|
|||
else:
|
||||
tensor = ColoTensor.from_torch_tensor(tensor)
|
||||
|
||||
# apply sharding
|
||||
if dist_spec:
|
||||
tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
|
||||
|
||||
# override the original tensor
|
||||
with torch.no_grad():
|
||||
setattr(module, name, tensor)
|
||||
|
||||
# apply sharding
|
||||
if dist_spec:
|
||||
tensor.process_group = pg
|
||||
tensor.set_tensor_spec(dist_spec, comp_spec)
|
||||
|
||||
_init_recursively(model)
|
||||
|
||||
return model
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
import colossalai
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from functools import partial
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.utils.model.lazy_init_context import LazyInitContext
|
||||
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ProcessGroup
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(dim, dim)
|
||||
self.linear2 = torch.nn.Linear(dim, dim)
|
||||
self.dropout = torch.nn.Dropout(0)
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.dropout(x)
|
||||
x = self.relu(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
def run_workflow(world_size):
|
||||
# initailization
|
||||
with LazyInitContext() as ctx:
|
||||
model = MLP(16)
|
||||
|
||||
# tracing
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model)
|
||||
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||
|
||||
# annotate
|
||||
annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup())
|
||||
annotated_gm.recompile()
|
||||
|
||||
# materialization and sharding
|
||||
ctx.lazy_init_parameters(annotated_gm)
|
||||
|
||||
# # check sharding
|
||||
assert list(model.linear1.weight.shape) == [16 // world_size, 16]
|
||||
assert list(model.linear1.bias.shape) == [16 // world_size]
|
||||
assert list(model.linear2.weight.shape) == [16, 16 // world_size]
|
||||
|
||||
# test forward to make sure that IR transform will produce the same results
|
||||
# like how ColoTensor would do it normally
|
||||
data = torch.rand(4, 16)
|
||||
non_fx_out = model(data)
|
||||
fx_out = annotated_gm(data)
|
||||
assert torch.equal(non_fx_out, fx_out)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_workflow(world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_complete_workflow(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_complete_workflow(2)
|
|
@ -1,59 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
import colossalai
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass
|
||||
CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(dim, dim)
|
||||
self.linear2 = torch.nn.Linear(dim, dim)
|
||||
self.linear3 = torch.nn.Linear(dim, dim)
|
||||
self.linear4 = torch.nn.Linear(dim, dim)
|
||||
self.dropout = torch.nn.Dropout()
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(self.linear1(x))
|
||||
x = self.dropout(self.relu(self.linear2(x)))
|
||||
x = self.linear3(x)
|
||||
x = torch.nn.functional.relu(self.linear4(x))
|
||||
return x
|
||||
|
||||
def test_out_acc():
|
||||
model = MLP(16).cuda()
|
||||
model.eval()
|
||||
input_tensor = torch.rand(2, 16).cuda()
|
||||
output = model(input_tensor)
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
|
||||
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||
splitted_gm = transform_mlp_pass(gm)
|
||||
new_output = splitted_gm(input_tensor)
|
||||
assert output.equal(new_output)
|
||||
|
||||
def test_linear_acc():
|
||||
input_tensor = torch.rand(2, 16).cuda()
|
||||
model = MLP(16).cuda()
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
|
||||
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||
splitted_gm = transform_mlp_pass(gm)
|
||||
col_shard = True
|
||||
for node in splitted_gm.graph.nodes:
|
||||
if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), torch.nn.Linear):
|
||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||
dim = 0 if col_shard else -1
|
||||
assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs")
|
||||
col_shard = not col_shard
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(1)
|
||||
torch.cuda.manual_seed(1)
|
||||
# colossalai.launch_from_torch(config=CONFIG)
|
||||
test_out_acc()
|
||||
test_linear_acc()
|
Loading…
Reference in New Issue