2023-03-10 05:21:05 +00:00
|
|
|
import pytest
|
|
|
|
import torch
|
2023-03-22 02:40:33 +00:00
|
|
|
from packaging import version
|
2023-03-10 05:21:05 +00:00
|
|
|
from torch.utils.checkpoint import checkpoint
|
|
|
|
|
2023-04-06 06:51:35 +00:00
|
|
|
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
2023-03-22 05:38:11 +00:00
|
|
|
|
2023-03-10 05:21:05 +00:00
|
|
|
# Best-effort import: the tests below that need `symbolic_trace` are guarded by
# a torch-version skip marker, so a failed import must not break collection.
try:
    from colossalai._analyzer.fx import symbolic_trace
except Exception:
    # Deliberately swallow any import-time failure, but (unlike a bare
    # `except:`) let KeyboardInterrupt / SystemExit propagate.
    pass
|
|
|
|
|
|
|
|
|
|
|
|
class LinearModel(torch.nn.Module):
    """Thin module holding a single ``torch.nn.Linear`` layer.

    Args:
        in_features: forwarded to ``torch.nn.Linear`` as its input width.
        out_features: forwarded to ``torch.nn.Linear`` as its output width.
        bias: whether the linear layer is created with a bias term.
    """

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        """Return ``x`` passed through the linear layer."""
        return self.linear(x)
|
|
|
|
|
|
|
|
|
|
|
|
class ConvModel(torch.nn.Module):
    """Module with a ``Conv2d`` and a ``ConvTranspose2d`` built from the same settings.

    Both layers use ``padding=1``, ``stride=2``, ``dilation=2`` and ``groups=3``;
    ``forward`` picks which one to apply.

    Args:
        in_channel: number of input channels for both layers.
        out_channels: number of output channels for both layers.
        kernel_size: kernel size for both layers.
        bias: whether both layers are created with a bias term.
    """

    def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
        super().__init__()
        # Settings shared verbatim by the two layers.
        shared_kwargs = dict(bias=bias, padding=1, stride=2, dilation=2, groups=3)
        # `conv` is built before `conv_transpose` so parameter initialisation
        # happens in the same order as before.
        self.conv = torch.nn.Conv2d(in_channel, out_channels, kernel_size, **shared_kwargs)
        self.conv_transpose = torch.nn.ConvTranspose2d(in_channel, out_channels, kernel_size, **shared_kwargs)

    def forward(self, x, select=0):
        """Apply ``conv`` when ``select == 0``, otherwise ``conv_transpose``."""
        if select == 0:
            return self.conv(x)
        return self.conv_transpose(x)
|
|
|
|
|
|
|
|
|
|
|
|
class SiuModel(torch.nn.Module):
    """Composite model: a ``LinearModel`` followed by a checkpointed ``ConvModel``.

    Args:
        bias: forwarded to both sub-modules to enable or disable their bias terms.
    """

    def __init__(self, bias) -> None:
        super().__init__()
        self.linear = LinearModel(3, 3, bias)
        self.conv = ConvModel(3, 6, 3, bias)

    def forward(self, x, select=torch.Tensor([0])):
        """Run the linear stage, then the conv stage under activation checkpointing.

        A truthy ``select`` passes ``0`` to ``self.conv``; a falsy one passes ``1``.
        """
        hidden = self.linear(x)
        # Translate the truthiness of `select` into the flag `self.conv` receives.
        conv_flag = 0 if select else 1
        return checkpoint(self.conv, hidden, conv_flag)
|
|
|
|
|
|
|
|
|
|
|
|
class AddmmModel(torch.nn.Module):
    """Module whose forward pass is one ``torch.addmm`` call on its input.

    Args:
        alpha: value passed as ``alpha`` to ``torch.addmm``.
        beta: value passed as ``beta`` to ``torch.addmm``.
    """

    def __init__(self, alpha, beta) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        """Return ``torch.addmm(x, x, x)`` using the stored ``alpha`` and ``beta``."""
        return torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)
|
|
|
|
|
|
|
|
|
2023-03-22 02:40:33 +00:00
|
|
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
def test_siu_model(bias, bias_addition_split, shape, select):
    """Trace ``SiuModel`` and compare the traced module against eager execution.

    Also checks the generated code: it must contain ``'+'`` only when both
    ``bias`` and ``bias_addition_split`` are set.
    """
    model = SiuModel(bias=bias)
    x = torch.rand(shape)

    trace_options = {
        'meta_args': {'x': x},
        'concrete_args': {'select': select},
        'trace_act_ckpt': True,
        'bias_addition_split': bias_addition_split,
    }
    gm = symbolic_trace(model, **trace_options)

    # Eager model runs first, then the traced graph module.
    eager_out = model(x, select)
    traced_out = gm(x)
    assert torch.allclose(eager_out, traced_out), 'original model and traced model should be the same!'

    if not (bias and bias_addition_split):
        assert '+' not in gm.code, 'bias addition should not be split!'
    else:
        assert '+' in gm.code, 'bias addition should be split!'
|
|
|
|
|
|
|
|
|
2023-03-22 05:38:11 +00:00
|
|
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize("alpha", [1, 2])
@parameterize("beta", [1, 2])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
    """Trace ``AddmmModel`` and compare the traced module against eager execution.

    When ``alpha`` and ``beta`` are both 1, or splitting is disabled, the
    generated code must contain no ``'*'``; otherwise it must contain ``'+'``.
    """
    model = AddmmModel(alpha=alpha, beta=beta)
    x = torch.rand(shape)
    gm = symbolic_trace(
        model,
        meta_args={'x': x},
        trace_act_ckpt=True,
        bias_addition_split=bias_addition_split,
    )

    eager_out = model(x)
    traced_out = gm(x)
    assert torch.allclose(eager_out, traced_out), 'original model and traced model should be the same!'

    unit_coefficients = alpha == 1 and beta == 1
    if unit_coefficients or not bias_addition_split:
        assert '*' not in gm.code, 'bias addition should not be split!'
    else:
        # Reaching here implies bias_addition_split is truthy.
        assert '+' in gm.code, 'bias addition should be split!'
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run both parameterized tests, in order, when executed as a script.
    for run_test in (test_siu_model, test_addmm_model):
        run_test()
|