mirror of https://github.com/hpcaitech/ColossalAI
[fx] Add unit test and fix bugs for transform_mlp_pass (#1299)
* add test and fix bugs * add functions back * add commentspull/1323/head
parent
1b41686461
commit
ca2d3f284f
|
@ -1,59 +1,90 @@
|
||||||
import torch
|
import torch
|
||||||
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
|
import operator
|
||||||
|
import colossalai
|
||||||
|
|
||||||
|
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d]
|
||||||
|
ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d, torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d, torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d]
|
||||||
|
|
||||||
def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
|
def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
|
||||||
"""weight_split
|
"""weight_split
|
||||||
split a nn.Parameter
|
split a nn.Parameter
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
weight (torch.nn.parameter.Parameter): a torch Parameter instance
|
weight (torch.nn.parameter.Parameter): a torch Parameter instance
|
||||||
dim (int): the dimension to be sharded along with
|
dim (int): the dimension to be sharded along with
|
||||||
|
col_normal(bool): col shard with gather or not
|
||||||
Returns:
|
Returns:
|
||||||
_type_: _description_
|
_type_: _description_
|
||||||
"""
|
"""
|
||||||
# Append a Tensor spec to target_module.weight.shard
|
if col_normal:
|
||||||
# Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
|
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_normal"))
|
||||||
assert isinstance(weight, torch.Tensor), \
|
else:
|
||||||
f'The type of the input tensor should be torch.nn.parameter' \
|
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
|
||||||
f'Your Input tensor is {type(weight)}'
|
|
||||||
|
|
||||||
# FIXME() I initialized a PG for this tensor. Only has TP comm group.
|
|
||||||
# we only consider the TP-only caes.
|
|
||||||
world_size = torch.distributed.get_world_size()
|
|
||||||
pg = ProcessGroup(tp_degree=world_size)
|
|
||||||
|
|
||||||
spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
|
||||||
# As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
|
|
||||||
setattr(weight, "fx_attr", spec)
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|
||||||
def column_shard_linear_pass(gm: torch.fx.GraphModule):
|
def column_shard_linear_pass(gm: torch.fx.GraphModule):
|
||||||
|
# Split all the linear module with column shard. Currently for testing only.
|
||||||
mod_graph = gm.graph
|
mod_graph = gm.graph
|
||||||
for node in mod_graph.nodes:
|
for node in mod_graph.nodes:
|
||||||
if node.op == "call_module":
|
if node.op == "call_module":
|
||||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
if isinstance(target_module, torch.nn.Linear):
|
if isinstance(target_module, torch.nn.Linear):
|
||||||
target_module.weight = weight_split(target_module.weight, dim=0)
|
target_module.weight = weight_split(target_module.weight, dim=0, col_normal=False)
|
||||||
if target_module.bias is not None:
|
if target_module.bias is not None:
|
||||||
target_module.bias.data = weight_split(target_module.bias.data, dim=0)
|
target_module.bias.data = weight_split(target_module.bias.data, dim=0, col_normal=False)
|
||||||
|
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
||||||
def row_shard_linear_pass(gm: torch.fx.GraphModule):
|
def row_shard_linear_pass(gm: torch.fx.GraphModule):
|
||||||
|
# Split all the linear module with row shard. Currently for testing only.
|
||||||
mod_graph = gm.graph
|
mod_graph = gm.graph
|
||||||
for node in mod_graph.nodes:
|
for node in mod_graph.nodes:
|
||||||
if node.op == "call_module":
|
if node.op == "call_module":
|
||||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
if isinstance(target_module, torch.nn.Linear):
|
if isinstance(target_module, torch.nn.Linear):
|
||||||
target_module.weight = weight_split(target_module.weight, dim=-1)
|
target_module.weight = weight_split(target_module.weight, dim=-1, col_normal=False)
|
||||||
|
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
def transform_mlp_pass(gm: torch.fx.GraphModule):
|
||||||
#TODO: add elementwise op process pass, then we can try to use column and row mixed strategy.
|
#TODO: Needs to handle special cases, like x = linear(x) + linear(x)
|
||||||
|
mod_graph = gm.graph
|
||||||
|
col_shard = True
|
||||||
|
element_op = []
|
||||||
|
all_linear_name = []
|
||||||
|
linear_name = []
|
||||||
|
# Get the name of element wise module(torch.nn.ReLU)
|
||||||
|
# Get the name of all the linear modules and repeated linear modules
|
||||||
|
for name, func in gm.named_children():
|
||||||
|
if not isinstance(func, torch.nn.Linear):
|
||||||
|
for i in ELEMENTWISE_MODULE_OP:
|
||||||
|
if isinstance(func, i):
|
||||||
|
element_op.append(name)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if name in all_linear_name:
|
||||||
|
if name in linear_name:
|
||||||
|
linear_name.remove(name)
|
||||||
|
else:
|
||||||
|
all_linear_name.append(name)
|
||||||
|
linear_name.append(name)
|
||||||
|
# If the linear modules is called multiple times, set the dist spec as col shard
|
||||||
|
# If the module is element wise or the function/method is element wise, remains col_shard
|
||||||
|
for node in mod_graph.nodes:
|
||||||
|
if node.target in linear_name:
|
||||||
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
|
dim = 0 if col_shard else -1
|
||||||
|
target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
|
||||||
|
col_shard = not col_shard
|
||||||
|
elif node.target in all_linear_name:
|
||||||
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
|
dim = 0 if col_shard else -1
|
||||||
|
target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
|
||||||
|
col_shard = not col_shard
|
||||||
|
else:
|
||||||
|
if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
|
||||||
|
col_shard = True
|
||||||
|
gm.recompile()
|
||||||
|
return gm
|
|
@ -0,0 +1,59 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import pytest
|
||||||
|
import colossalai
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
|
from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass
|
||||||
|
CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = torch.nn.Linear(dim, dim)
|
||||||
|
self.linear2 = torch.nn.Linear(dim, dim)
|
||||||
|
self.linear3 = torch.nn.Linear(dim, dim)
|
||||||
|
self.linear4 = torch.nn.Linear(dim, dim)
|
||||||
|
self.dropout = torch.nn.Dropout()
|
||||||
|
self.relu = torch.nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.relu(self.linear1(x))
|
||||||
|
x = self.dropout(self.relu(self.linear2(x)))
|
||||||
|
x = self.linear3(x)
|
||||||
|
x = torch.nn.functional.relu(self.linear4(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_out_acc():
|
||||||
|
model = MLP(16).cuda()
|
||||||
|
model.eval()
|
||||||
|
input_tensor = torch.rand(2, 16).cuda()
|
||||||
|
output = model(input_tensor)
|
||||||
|
tracer = ColoTracer()
|
||||||
|
graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
|
||||||
|
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
splitted_gm = transform_mlp_pass(gm)
|
||||||
|
new_output = splitted_gm(input_tensor)
|
||||||
|
assert output.equal(new_output)
|
||||||
|
|
||||||
|
def test_linear_acc():
|
||||||
|
input_tensor = torch.rand(2, 16).cuda()
|
||||||
|
model = MLP(16).cuda()
|
||||||
|
tracer = ColoTracer()
|
||||||
|
graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
|
||||||
|
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
splitted_gm = transform_mlp_pass(gm)
|
||||||
|
col_shard = True
|
||||||
|
for node in splitted_gm.graph.nodes:
|
||||||
|
if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), torch.nn.Linear):
|
||||||
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
|
dim = 0 if col_shard else -1
|
||||||
|
assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs")
|
||||||
|
col_shard = not col_shard
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(1)
|
||||||
|
torch.cuda.manual_seed(1)
|
||||||
|
# colossalai.launch_from_torch(config=CONFIG)
|
||||||
|
test_out_acc()
|
||||||
|
test_linear_acc()
|
Loading…
Reference in New Issue