From ca2d3f284fafe692f96150f337bee848110f593b Mon Sep 17 00:00:00 2001
From: XYE <92607131+Itok2000u@users.noreply.github.com>
Date: Fri, 15 Jul 2022 14:37:58 +0800
Subject: [PATCH] [fx] Add unit test and fix bugs for transform_mlp_pass (#1299)

* add test and fix bugs

* add functions back

* add comments
---
 colossalai/fx/passes/shard_1d_pass.py    | 79 +++++++++++++++++-------
 tests/test_fx/test_transform_mlp_pass.py | 59 ++++++++++++++++++
 2 files changed, 114 insertions(+), 24 deletions(-)
 create mode 100644 tests/test_fx/test_transform_mlp_pass.py

diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index 49a823076..44449ff8e 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -1,59 +1,90 @@
 import torch
-from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
+import operator
+import colossalai
 
+ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d]
+ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d, torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d, torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d]
 
-def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
+def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
     """weight_split split a nn.Parameter
 
     Args:
         weight (torch.nn.parameter.Parameter): a torch Parameter instance
         dim (int): the dimension to be sharded along with
-
+        col_normal (bool): whether the column shard gathers its output or not
     Returns:
         _type_: _description_
     """
-    # Append a Tensor spec to target_module.weight.shard
-    # Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
-    assert isinstance(weight, torch.Tensor), \
-        f'The type of the input tensor should be torch.nn.parameter' \
-        f'Your Input tensor is {type(weight)}'
-
-    # FIXME() I initialized a PG for this tensor. Only has TP comm group.
-    # we only consider the TP-only caes.
-    world_size = torch.distributed.get_world_size()
-    pg = ProcessGroup(tp_degree=world_size)
-
-    spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    # As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
-    setattr(weight, "fx_attr", spec)
+    if col_normal:
+        setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_normal"))
+    else:
+        setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
     return weight
-
-
 def column_shard_linear_pass(gm: torch.fx.GraphModule):
+    # Split all the linear modules with column shard. Currently for testing only.
     mod_graph = gm.graph
     for node in mod_graph.nodes:
         if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(target_module, torch.nn.Linear):
-                target_module.weight = weight_split(target_module.weight, dim=0)
+                target_module.weight = weight_split(target_module.weight, dim=0, col_normal=False)
                 if target_module.bias is not None:
-                    target_module.bias.data = weight_split(target_module.bias.data, dim=0)
+                    target_module.bias.data = weight_split(target_module.bias.data, dim=0, col_normal=False)
     gm.recompile()
     return gm
 
 
 def row_shard_linear_pass(gm: torch.fx.GraphModule):
+    # Split all the linear modules with row shard. Currently for testing only.
     mod_graph = gm.graph
     for node in mod_graph.nodes:
         if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(target_module, torch.nn.Linear):
-                target_module.weight = weight_split(target_module.weight, dim=-1)
+                target_module.weight = weight_split(target_module.weight, dim=-1, col_normal=False)
     gm.recompile()
     return gm
-
-#TODO: add elementwise op process pass, then we can try to use column and row mixed strategy.
+def transform_mlp_pass(gm: torch.fx.GraphModule):
+    #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
+    mod_graph = gm.graph
+    col_shard = True
+    element_op = []
+    all_linear_name = []
+    linear_name = []
+    # Get the names of the element-wise modules (e.g. torch.nn.ReLU)
+    # Get the names of all the linear modules and of the repeated linear modules
+    for name, func in gm.named_children():
+        if not isinstance(func, torch.nn.Linear):
+            for i in ELEMENTWISE_MODULE_OP:
+                if isinstance(func, i):
+                    element_op.append(name)
+                    break
+        else:
+            if name in all_linear_name:
+                if name in linear_name:
+                    linear_name.remove(name)
+            else:
+                all_linear_name.append(name)
+                linear_name.append(name)
+    # If a linear module is called multiple times, set its dist spec as col shard
+    # If the module or the function/method is element-wise, col_shard keeps its current value
+    for node in mod_graph.nodes:
+        if node.target in linear_name:
+            target_module = node.graph.owning_module.get_submodule(node.target)
+            dim = 0 if col_shard else -1
+            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
+            col_shard = not col_shard
+        elif node.target in all_linear_name:
+            target_module = node.graph.owning_module.get_submodule(node.target)
+            dim = 0 if col_shard else -1
+            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
+            col_shard = not col_shard
+        else:
+            if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
+                col_shard = True
+    gm.recompile()
+    return gm
\ No newline at end of file
diff --git a/tests/test_fx/test_transform_mlp_pass.py b/tests/test_fx/test_transform_mlp_pass.py
new file mode 100644
index 000000000..202c8ce0e
--- /dev/null
+++ b/tests/test_fx/test_transform_mlp_pass.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import pytest
+import colossalai
+from colossalai.fx import ColoTracer
+from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass
+CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(dim, dim)
+        self.linear2 = torch.nn.Linear(dim, dim)
+        self.linear3 = torch.nn.Linear(dim, dim)
+        self.linear4 = torch.nn.Linear(dim, dim)
+        self.dropout = torch.nn.Dropout()
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.relu(self.linear1(x))
+        x = self.dropout(self.relu(self.linear2(x)))
+        x = self.linear3(x)
+        x = torch.nn.functional.relu(self.linear4(x))
+        return x
+
+def test_out_acc():
+    model = MLP(16).cuda()
+    model.eval()
+    input_tensor = torch.rand(2, 16).cuda()
+    output = model(input_tensor)
+    tracer = ColoTracer()
+    graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
+    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
+    splitted_gm = transform_mlp_pass(gm)
+    new_output = splitted_gm(input_tensor)
+    assert output.equal(new_output)
+
+def test_linear_acc():
+    input_tensor = torch.rand(2, 16).cuda()
+    model = MLP(16).cuda()
+    tracer = ColoTracer()
+    graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
+    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
+    splitted_gm = transform_mlp_pass(gm)
+    col_shard = True
+    for node in splitted_gm.graph.nodes:
+        if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), torch.nn.Linear):
+            target_module = node.graph.owning_module.get_submodule(node.target)
+            dim = 0 if col_shard else -1
+            assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs")
+            col_shard = not col_shard
+
+if __name__ == "__main__":
+    torch.manual_seed(1)
+    torch.cuda.manual_seed(1)
+    # colossalai.launch_from_torch(config=CONFIG)
+    test_out_acc()
+    test_linear_acc()
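
Note: the patch reduces weight_split to a tagging step. Instead of building a ColoTensorSpec, it now only attaches an fx_attr tuple of the form (dim, "SHARD", "TP", mode) to each Linear weight, and a later stage is expected to turn that tag into a real shard spec. As a rough illustration of that hand-off, the sketch below rebuilds a spec from the tag by reusing the calls deleted from weight_split; the helper name fx_attr_to_spec is hypothetical and not part of this patch, and it assumes torch.distributed has already been initialized with the TP-only process group the old code used.

import torch
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec

def fx_attr_to_spec(weight: torch.Tensor) -> ColoTensorSpec:
    # Hypothetical helper: rebuild the 1D TP shard spec from the (dim, "SHARD", "TP", mode)
    # tuple that weight_split attaches, mirroring the construction removed by this patch.
    dim, _, _, _mode = weight.fx_attr
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)    # TP-only communication group, as in the removed FIXME block
    return ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))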