You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_analyzer/test_fx/test_bias_addition.py

111 lines
3.6 KiB

import pytest
import torch
from packaging import version
from torch.utils.checkpoint import checkpoint
from colossalai.testing.utils import clear_cache_before_run, parameterize
# The analyzer tracer is treated as optional at import time (presumably it can fail
# to import on older torch versions, which the skipif markers exclude) so that the
# module stays importable for test collection either way.
try:
    from colossalai._analyzer.fx import symbolic_trace
except Exception:
    # Catch ordinary errors only (not KeyboardInterrupt/SystemExit), and bind the
    # name so a failed import shows up as an explicit None instead of a NameError.
    symbolic_trace = None
class LinearModel(torch.nn.Module):
    """A single ``torch.nn.Linear`` layer, optionally with a bias term.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: whether the linear layer carries a learnable bias.
    """

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        """Apply the linear projection to ``x`` and return the result."""
        return self.linear(x)
class ConvModel(torch.nn.Module):
    """A grouped ``Conv2d`` / ``ConvTranspose2d`` pair sharing one hyper-parameter set.

    Both layers use ``padding=1``, ``stride=2``, ``dilation=2`` and ``groups=3``, so
    ``in_channel`` and ``out_channels`` must both be divisible by 3.

    Args:
        in_channel: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        bias: whether both layers carry a learnable bias.
    """

    def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
        super().__init__()
        shared_kwargs = dict(bias=bias, padding=1, stride=2, dilation=2, groups=3)
        # Creation order (conv first, then conv_transpose) fixes parameter-init order.
        self.conv = torch.nn.Conv2d(in_channel, out_channels, kernel_size, **shared_kwargs)
        self.conv_transpose = torch.nn.ConvTranspose2d(in_channel, out_channels, kernel_size, **shared_kwargs)

    def forward(self, x, select=0):
        """Run ``self.conv`` when ``select == 0``, otherwise ``self.conv_transpose``."""
        layer = self.conv if select == 0 else self.conv_transpose
        return layer(x)
class SiuModel(torch.nn.Module):
    """A ``LinearModel(3, 3)`` followed by an activation-checkpointed ``ConvModel(3, 6, 3)``.

    A truthy ``select`` routes through ``ConvModel.conv`` (mode 0); a falsy one routes
    through ``ConvModel.conv_transpose`` (mode 1). The default ``torch.Tensor([0])`` is
    falsy and therefore selects the transposed convolution.

    Args:
        bias: whether the linear and conv layers carry a bias.
    """

    def __init__(self, bias) -> None:
        super().__init__()
        self.linear = LinearModel(3, 3, bias)
        self.conv = ConvModel(3, 6, 3, bias)

    # NOTE: the default tensor is created once at definition time; it is only read here.
    def forward(self, x, select=torch.Tensor([0])):
        """Project ``x`` linearly, then run the selected conv branch under ``checkpoint``."""
        hidden = self.linear(x)
        mode = 0 if select else 1
        return checkpoint(self.conv, hidden, mode)
class AddmmModel(torch.nn.Module):
    """Wrap ``torch.addmm(x, x, x)`` with fixed ``alpha``/``beta`` coefficients.

    ``forward`` returns ``beta * x + alpha * (x @ x)``, so ``x`` must be a square
    2-D matrix.

    Args:
        alpha: scale applied to the matrix product.
        beta: scale applied to the added input.
    """

    def __init__(self, alpha, beta) -> None:
        super().__init__()
        self.alpha, self.beta = alpha, beta

    def forward(self, x):
        """Compute ``torch.addmm`` using ``x`` as the input and both matrix operands."""
        return torch.addmm(x, x, x, beta=self.beta, alpha=self.alpha)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12.0")
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
def test_siu_model(bias, bias_addition_split, shape, select):
    """Trace ``SiuModel`` and check outputs and bias-addition splitting.

    ``select`` is passed as a concrete argument, so the traced module is called with
    ``x`` only. When the model has a bias and ``bias_addition_split`` is enabled, the
    generated code is expected to contain an explicit ``+``; otherwise it must not.
    """
    model = SiuModel(bias=bias)
    x = torch.rand(shape)
    gm = symbolic_trace(
        model,
        meta_args={"x": x},
        concrete_args={"select": select},
        trace_act_ckpt=True,
        bias_addition_split=bias_addition_split,
    )
    assert torch.allclose(model(x, select), gm(x)), "original model and traced model should be the same!"
    # Only a biased model traced with splitting enabled should emit a separate add.
    if bias and bias_addition_split:
        assert "+" in gm.code, "bias addition should be split!"
    else:
        assert "+" not in gm.code, "bias addition should not be split!"
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12.0")
@parameterize("alpha", [1, 2])
@parameterize("beta", [1, 2])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
    """Trace ``AddmmModel`` and check outputs and how ``torch.addmm`` is lowered.

    Expectations on the generated code:
      * no splitting: no explicit ``*`` appears;
      * splitting with ``alpha == beta == 1``: no scaling ``*`` appears;
      * splitting with a non-unit coefficient: an explicit ``+`` appears.
    """
    model = AddmmModel(alpha=alpha, beta=beta)
    x = torch.rand(shape)
    gm = symbolic_trace(model, meta_args={"x": x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split)
    assert torch.allclose(model(x), gm(x)), "original model and traced model should be the same!"
    if not bias_addition_split:
        assert "*" not in gm.code, "bias addition should not be split!"
    elif alpha == 1 and beta == 1:
        # Split, but with unit coefficients there is nothing to scale.
        assert "*" not in gm.code, "unit alpha/beta should not produce a scaling multiply!"
    else:
        assert "+" in gm.code, "bias addition should be split!"
if __name__ == "__main__":
    # Run both tests directly, in file order; the decorators supply their arguments.
    for _test_fn in (test_siu_model, test_addmm_model):
        _test_fn()