import pytest
import torch
from packaging import version
from torch.utils.checkpoint import checkpoint

from colossalai.testing.utils import clear_cache_before_run, parameterize

try:
    from colossalai._analyzer.fx import symbolic_trace
except Exception:
    # Best-effort import: both tests below are skipped on torch < 1.12, so a
    # failed import is tolerated here (presumably the analyzer cannot be
    # imported on older torch -- confirm). `except Exception` is used instead
    # of a bare `except` so KeyboardInterrupt / SystemExit still propagate.
    pass


class LinearModel(torch.nn.Module):
    """Thin wrapper around a single ``torch.nn.Linear`` layer."""

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        """Return ``self.linear(x)``."""
        x = self.linear(x)
        return x


class ConvModel(torch.nn.Module):
    """Holds a grouped ``Conv2d`` and a grouped ``ConvTranspose2d`` built with the same arguments."""

    def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3
        )
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3
        )

    def forward(self, x, select=0):
        """Apply ``self.conv`` when ``select == 0``, otherwise ``self.conv_transpose``."""
        if select == 0:
            x = self.conv(x)
        else:
            x = self.conv_transpose(x)
        return x


class SiuModel(torch.nn.Module):
    """A ``LinearModel(3, 3)`` followed by a checkpointed ``ConvModel(3, 6, 3)``."""

    def __init__(self, bias) -> None:
        super().__init__()
        self.linear = LinearModel(3, 3, bias)
        self.conv = ConvModel(3, 6, 3, bias)

    # NOTE: the tensor default is created once, at definition time, and shared
    # by every call. It is only read here, never mutated.
    def forward(self, x, select=torch.Tensor([0])):
        """Run the linear layer, then one of the two convolutions under ``checkpoint``.

        A truthy ``select`` calls ``self.conv`` with ``0`` (the ``Conv2d``
        branch of ``ConvModel``); a falsy ``select`` calls it with ``1`` (the
        ``ConvTranspose2d`` branch).
        """
        x = self.linear(x)
        if select:
            x = checkpoint(self.conv, x, 0)
        else:
            x = checkpoint(self.conv, x, 1)
        return x


class AddmmModel(torch.nn.Module):
    """Wraps ``torch.addmm`` with fixed ``alpha`` / ``beta`` scaling factors."""

    def __init__(self, alpha, beta) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        """Return ``torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)``."""
        x = torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)
        return x


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12")
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
def test_siu_model(bias, bias_addition_split, shape, select):
    """Trace ``SiuModel`` and compare the traced module against the eager model.

    ``select`` is passed to the tracer as a concrete argument, so the traced
    module is called with ``x`` only. The generated code must contain ``+``
    exactly when both ``bias`` and ``bias_addition_split`` are true.
    """
    model = SiuModel(bias=bias)
    x = torch.rand(shape)
    gm = symbolic_trace(
        model,
        meta_args={"x": x},
        concrete_args={"select": select},
        trace_act_ckpt=True,
        bias_addition_split=bias_addition_split,
    )
    assert torch.allclose(model(x, select), gm(x)), "original model and traced model should be the same!"
    if bias and bias_addition_split:
        assert "+" in gm.code, "bias addition should be split!"
    else:
        assert "+" not in gm.code, "bias addition should not be split!"


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 1.12")
@parameterize("alpha", [1, 2])
@parameterize("beta", [1, 2])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
    """Trace ``AddmmModel`` and compare the traced module against the eager model.

    The generated code must contain no ``*`` when splitting is disabled or when
    both ``alpha`` and ``beta`` are 1; in every other case it must contain ``+``.
    """
    model = AddmmModel(alpha=alpha, beta=beta)
    x = torch.rand(shape)
    gm = symbolic_trace(model, meta_args={"x": x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split)
    assert torch.allclose(model(x), gm(x)), "original model and traced model should be the same!"
    if (alpha == 1 and beta == 1) or not bias_addition_split:
        # Also covers alpha == beta == 1 with splitting enabled: the test
        # expects no multiplication in the generated code in that case.
        assert "*" not in gm.code, "bias addition should not be split!"
    else:
        # Reaching here implies bias_addition_split is True, because the
        # condition above already handles `not bias_addition_split`.
        assert "+" in gm.code, "bias addition should be split!"


if __name__ == "__main__":
    test_siu_model()
    test_addmm_model()