"""Check that ``colossalai._analyzer.fx.symbolic_trace`` reproduces eager outputs.

Each test traces a small model, compares the traced ``GraphModule`` against the
original module on the same input, and inspects the generated code to see
whether bias addition was split into a separate operation.
"""
import pytest
import torch
from packaging import version
from torch.utils.checkpoint import checkpoint

from colossalai.testing.utils import parameterize

try:
    from colossalai._analyzer.fx import symbolic_trace
except Exception:
    # Deliberate best-effort import. Presumably the analyzer cannot be imported on
    # older torch releases, where every test below is skipped anyway (TODO confirm).
    # ``except Exception`` rather than a bare ``except`` keeps
    # KeyboardInterrupt / SystemExit propagating.
    pass

# Shared skip condition for every test in this module.
_TORCH_TOO_OLD = version.parse(torch.__version__) < version.parse('1.12.0')


class LinearModel(torch.nn.Module):
    """Wrap a single ``torch.nn.Linear`` layer (applied to the input's last dimension)."""

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        x = self.linear(x)
        return x


class ConvModel(torch.nn.Module):
    """Hold a grouped ``Conv2d`` and a grouped ``ConvTranspose2d``; ``select`` picks one.

    Both layers use padding=1, stride=2, dilation=2 and groups=3, so torch requires
    ``in_channel`` and ``out_channels`` to be divisible by 3.
    """

    def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channel,
                                    out_channels,
                                    kernel_size,
                                    bias=bias,
                                    padding=1,
                                    stride=2,
                                    dilation=2,
                                    groups=3)
        self.conv_transpose = torch.nn.ConvTranspose2d(in_channel,
                                                       out_channels,
                                                       kernel_size,
                                                       bias=bias,
                                                       padding=1,
                                                       stride=2,
                                                       dilation=2,
                                                       groups=3)

    def forward(self, x, select=0):
        # select == 0 -> Conv2d; any other value -> ConvTranspose2d.
        if select == 0:
            x = self.conv(x)
        else:
            x = self.conv_transpose(x)
        return x


class SiuModel(torch.nn.Module):
    """Apply a linear layer followed by an activation-checkpointed ``ConvModel``.

    This exercises tracing through nested modules, ``torch.utils.checkpoint`` and a
    data-dependent branch on ``select``. The test resolves that branch by passing
    ``select`` through ``concrete_args``.
    """

    def __init__(self, bias) -> None:
        super().__init__()
        self.linear = LinearModel(3, 3, bias)
        self.conv = ConvModel(3, 6, 3, bias)

    def forward(self, x, select=torch.Tensor([0])):
        x = self.linear(x)
        # A truthy ``select`` runs ``self.conv`` with select=0 (Conv2d branch).
        # A falsy one runs it with select=1 (ConvTranspose2d branch).
        # The default tensor is only read, never mutated, so sharing it across
        # calls is harmless.
        if select:
            x = checkpoint(self.conv, x, 0)
        else:
            x = checkpoint(self.conv, x, 1)
        return x


class AddmmModel(torch.nn.Module):
    """Compute ``torch.addmm(x, x, x, alpha=alpha, beta=beta)``.

    This equals ``beta * x + alpha * (x @ x)``, so ``x`` must be a square matrix.
    """

    def __init__(self, alpha, beta) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        x = torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)
        return x


@pytest.mark.skipif(_TORCH_TOO_OLD, reason='torch version < 1.12')
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
def test_siu_model(bias, bias_addition_split, shape, select):
    """Traced ``SiuModel`` must match eager output.

    The generated code must contain ``+`` exactly when the layers carry a bias
    and bias-addition splitting is requested.
    """
    model = SiuModel(bias=bias)
    x = torch.rand(shape)
    gm = symbolic_trace(model,
                        meta_args={'x': x},
                        concrete_args={'select': select},
                        trace_act_ckpt=True,
                        bias_addition_split=bias_addition_split)
    # ``select`` was baked in as a concrete arg, so the traced module only takes ``x``.
    assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
    if bias and bias_addition_split:
        assert '+' in gm.code, 'bias addition should be split!'
    else:
        assert '+' not in gm.code, 'bias addition should not be split!'


@pytest.mark.skipif(_TORCH_TOO_OLD, reason='torch version < 1.12')
@parameterize("alpha", [1, 2])
@parameterize("beta", [1, 2])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
    """Traced ``AddmmModel`` must match eager output.

    The generated code must contain no ``*`` when no split is requested or when
    alpha == beta == 1. Otherwise the split must produce a ``+``.
    """
    model = AddmmModel(alpha=alpha, beta=beta)
    x = torch.rand(shape)
    gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split)
    assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!'
    if (alpha == 1 and beta == 1) or not bias_addition_split:
        assert '*' not in gm.code, 'bias addition should not be split!'
    else:
        # Only reached when bias_addition_split is True and alpha or beta differs from 1.
        assert '+' in gm.code, 'bias addition should be split!'


if __name__ == '__main__':
    test_siu_model()
    test_addmm_model()