# ColossalAI/tests/test_analyzer/test_fx/test_bias_addition.py
# (source file metadata: 123 lines, 4.2 KiB, Python)

import pytest
import torch
from packaging import version
from torch.utils.checkpoint import checkpoint
from colossalai.testing.utils import clear_cache_before_run, parameterize
try:
    from colossalai._analyzer.fx import symbolic_trace
except Exception:
    # Tolerate a failed import so the module can still be collected; the
    # tests below are skipped on torch < 1.12.0, presumably because the
    # analyzer needs a newer torch — TODO confirm. ``Exception`` (rather than
    # a bare ``except``) keeps KeyboardInterrupt/SystemExit propagating, and
    # binding the name to None keeps it defined for later references.
    symbolic_trace = None
class LinearModel(torch.nn.Module):
    """Thin wrapper around a single ``torch.nn.Linear`` layer.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        bias: whether the linear layer carries a learnable bias.
    """

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        return self.linear(x)
class ConvModel(torch.nn.Module):
    """A regular and a transposed grouped 2-D convolution with one shared config.

    Both layers use padding=1, stride=2, dilation=2 and groups=3, so torch
    requires ``in_channel`` and ``out_channels`` to be divisible by 3.
    ``forward`` runs exactly one of the two layers, chosen by ``select``.
    """

    def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
        super().__init__()
        # Hyper-parameters common to both layers.
        shared = dict(bias=bias, padding=1, stride=2, dilation=2, groups=3)
        self.conv = torch.nn.Conv2d(in_channel, out_channels, kernel_size, **shared)
        self.conv_transpose = torch.nn.ConvTranspose2d(in_channel, out_channels, kernel_size, **shared)

    def forward(self, x, select=0):
        """Apply ``self.conv`` when ``select == 0``, otherwise ``self.conv_transpose``."""
        layer = self.conv if select == 0 else self.conv_transpose
        return layer(x)
class SiuModel(torch.nn.Module):
    """A ``LinearModel`` followed by an activation-checkpointed ``ConvModel``.

    ``forward`` calls ``ConvModel.forward`` with ``select=0`` (plain
    convolution) when its own ``select`` is truthy, and with ``select=1``
    (transposed convolution) otherwise.
    """

    def __init__(self, bias) -> None:
        super().__init__()
        self.linear = LinearModel(3, 3, bias)
        self.conv = ConvModel(3, 6, 3, bias)

    def forward(self, x, select=torch.Tensor([0])):
        hidden = self.linear(x)
        # NOTE: the mapping is inverted on purpose-or-not: truthy ``select``
        # selects ConvModel branch 0.
        branch = 0 if select else 1
        return checkpoint(self.conv, hidden, branch)
class AddmmModel(torch.nn.Module):
    """Compute ``beta * x + alpha * (x @ x)`` with a single ``torch.addmm`` call.

    Args:
        alpha: scale applied to the matrix product.
        beta: scale applied to the added input.
    """

    def __init__(self, alpha, beta) -> None:
        super().__init__()
        self.alpha, self.beta = alpha, beta

    def forward(self, x):
        return torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12.0')
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
def test_siu_model(bias, bias_addition_split, shape, select):
    """Trace ``SiuModel`` and check numerics and bias-addition splitting.

    The traced module must reproduce the eager output for the same ``select``.
    When the layers have a bias and splitting is requested, the generated
    code must contain an explicit ``+``; otherwise it must not.
    """
    model = SiuModel(bias=bias)
    x = torch.rand(shape)
    gm = symbolic_trace(model,
                        meta_args={'x': x},
                        concrete_args={'select': select},
                        trace_act_ckpt=True,
                        bias_addition_split=bias_addition_split)

    assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
    # A split bias is only possible when the layers actually have a bias.
    if bias and bias_addition_split:
        assert '+' in gm.code, 'bias addition should be split!'
    else:
        assert '+' not in gm.code, 'bias addition should not be split!'
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12.0')
@clear_cache_before_run()
@parameterize("alpha", [1, 2])
@parameterize("beta", [1, 2])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
    """Trace ``AddmmModel`` and check numerics and bias-addition splitting.

    The traced module must reproduce the eager output. When splitting is
    requested and a non-unit ``alpha`` or ``beta`` is used, the generated code
    must contain an explicit ``+``; in every other case it must contain no
    ``*`` (i.e. no separate scaling was emitted).
    """
    model = AddmmModel(alpha=alpha, beta=beta)
    x = torch.rand(shape)
    gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split)
    assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!'
    # Equivalent to the former ``if (alpha == 1 and beta == 1) or not split``
    # / ``elif split`` chain, whose ``elif`` condition was always true.
    if bias_addition_split and (alpha != 1 or beta != 1):
        assert '+' in gm.code, 'bias addition should be split!'
    else:
        assert '*' not in gm.code, 'bias addition should not be split!'
if __name__ == '__main__':
    # Run both suites directly without pytest; ``parameterize`` presumably
    # expands each call over every listed argument combination — verify.
    for run_suite in (test_siu_model, test_addmm_model):
        run_suite()