Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

114 lines
4.2 KiB

import torch
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import clear_cache_before_run
class LinearModel(torch.nn.Module):
    """Toy module: a single ``torch.nn.Linear`` whose output is doubled.

    Args:
        in_features: Forwarded to ``torch.nn.Linear`` as its input size.
        out_features: Forwarded to ``torch.nn.Linear`` as its output size.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # The attribute name is significant: the traced graph refers to
        # ``linear.weight`` / ``linear.bias`` by this path.
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return ``self.linear(x)`` multiplied by 2."""
        return self.linear(x) * 2
class ConvModel(torch.nn.Module):
    """Toy module: a single ``torch.nn.Conv2d`` whose output is doubled.

    Args:
        in_channels: Forwarded to ``torch.nn.Conv2d``.
        out_channels: Forwarded to ``torch.nn.Conv2d``.
        kernel_size: Forwarded to ``torch.nn.Conv2d``.
        bias: Forwarded to ``torch.nn.Conv2d``; defaults to ``True``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        # The attribute name is significant: the traced graph refers to
        # ``conv.weight`` / ``conv.bias`` by this path.
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias)

    def forward(self, x):
        """Return ``self.conv(x)`` multiplied by 2."""
        return self.conv(x) * 2
@clear_cache_before_run()
def test_linear_module():
    """Trace ``LinearModel`` on a meta input and check per-node meta data.

    Verifies that every non-output node of the traced graph has a
    ``_meta_data`` attribute and that the weight, bias, linear and add nodes
    report the expected shapes.
    """
    module = LinearModel(3, 6)
    colo_tracer = ColoTracer()

    # Expected traced graph (bias addition appears as a separate ``add`` node):
    # graph():
    #   %x : torch.Tensor [#users=1] = placeholder[target=x]
    #   %linear_weight : [#users=1] = get_attr[target=linear.weight]
    #   %linear_bias : [#users=1] = get_attr[target=linear.bias]
    #   %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
    #   %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
    #   %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    #   return mul
    meta_args = {"x": torch.rand(3, 3).to("meta")}
    traced_graph = colo_tracer.trace(root=module, meta_args=meta_args)

    # Expected generated code after recompiling:
    # def forward(self, x : torch.Tensor):
    #     linear_weight = self.linear.weight
    #     linear_bias = self.linear.bias
    #     linear = torch._C._nn.linear(x, linear_weight);  x = linear_weight = None
    #     add = linear + linear_bias;  linear = linear_bias = None
    #     mul = add * 2;  add = None
    #     return mul
    graph_module = ColoGraphModule(module, traced_graph)
    graph_module.recompile()

    nodes = list(traced_graph.nodes)

    # Every node except the output node must carry meta data.
    assert all(hasattr(n, "_meta_data") for n in nodes if n.op != "output")

    # (position in the node list, expected ``_meta_data`` shape)
    expected_shapes = (
        (1, (6, 3)),  # linear.weight
        (2, (6,)),  # linear.bias
        (3, (3, 6)),  # linear without bias
        (4, (3, 6)),  # bias addition
    )
    for position, shape in expected_shapes:
        assert nodes[position]._meta_data.shape == shape
@clear_cache_before_run()
def test_conv_module():
    """Trace ``ConvModel`` on a meta input and check per-node meta data.

    Verifies that every non-output node of the traced graph has a
    ``_meta_data`` attribute and that the weight, bias, conv, view and add
    nodes report the expected shapes.
    """
    module = ConvModel(3, 6, 2)
    colo_tracer = ColoTracer()

    # Expected traced graph (bias is reshaped by ``view`` and then added):
    # graph():
    #   %x : torch.Tensor [#users=1] = placeholder[target=x]
    #   %conv_weight : [#users=1] = get_attr[target=conv.weight]
    #   %conv_bias : [#users=1] = get_attr[target=conv.bias]
    #   %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
    #   %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
    #   %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
    #   %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    #   return mul
    meta_args = {"x": torch.rand(4, 3, 64, 64).to("meta")}
    traced_graph = colo_tracer.trace(root=module, meta_args=meta_args)

    # Expected generated code after recompiling:
    # def forward(self, x : torch.Tensor):
    #     conv_weight = self.conv.weight
    #     conv_bias = self.conv.bias
    #     conv2d = torch.conv2d(x, conv_weight);  x = conv_weight = None
    #     view = conv_bias.view([1, -1, 1, 1]);  conv_bias = None
    #     add = conv2d + view;  conv2d = view = None
    #     mul = add * 2;  add = None
    #     return mul
    graph_module = ColoGraphModule(module, traced_graph)
    graph_module.recompile()

    nodes = list(traced_graph.nodes)

    # Every node except the output node must carry meta data.
    assert all(hasattr(n, "_meta_data") for n in nodes if n.op != "output")

    # NOTE(review): the dumps above show ``view([1, -1, 1, 1])``, which would
    # give a (1, 6, 1, 1) tensor, while the view node is asserted to be
    # (6, 1, 1) below. The commented dumps may be out of date -- confirm
    # against the tracer's current output.
    # (position in the node list, expected ``_meta_data`` shape)
    expected_shapes = (
        (1, (6, 3, 2, 2)),  # conv.weight
        (2, (6,)),  # conv.bias
        (3, (4, 6, 63, 63)),  # conv2d without bias
        (4, (6, 1, 1)),  # reshaped bias
        (5, (4, 6, 63, 63)),  # bias addition
    )
    for position, shape in expected_shapes:
        assert nodes[position]._meta_data.shape == shape
if __name__ == "__main__":
    # Allow running this file directly: execute both tests in order.
    for run_test in (test_linear_module, test_conv_module):
        run_test()