import torch

from colossalai.fx.tracer.meta_patch import patched_module
from colossalai.testing import clear_cache_before_run


def _run(data, module, patch_fn):
    """Invoke ``patch_fn`` on ``module`` with ``data``, unpacking containers.

    ``dict`` data is passed as keyword arguments, ``tuple``/``list`` data as
    positional arguments, anything else as a single positional argument.

    Returns:
        The patch function's output on success, or the raised exception
        object so callers can assert on expected failures.
    """
    try:
        if isinstance(data, dict):
            output = patch_fn(module, **data)
        # BUG FIX: this was a separate `if`, so for dict inputs the `else`
        # below re-ran patch_fn with the dict as a positional argument,
        # overwriting the correct keyword-expanded result.
        elif isinstance(data, (tuple, list)):
            output = patch_fn(module, *data)
        else:
            output = patch_fn(module, data)
        return output
    except Exception as e:
        return e


def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape):
    """Run the meta patch and validate its output.

    If ``expect_exception`` is True, the patch must raise an
    ``AssertionError``. Otherwise every returned tensor must live on the
    meta device and match the expected shape(s): ``output_shape`` is a
    single shape for a tensor result, or a sequence of shapes zipped
    against a tuple result.
    """
    output = _run(data, module, patch_fn)

    if expect_exception:
        assert isinstance(output, AssertionError)
    else:
        assert not isinstance(output, Exception)
        if isinstance(output, tuple):
            for item, shape in zip(output, output_shape):
                assert item.is_meta
                assert item.shape == shape
        else:
            assert output.is_meta
            assert output.shape == output_shape


@clear_cache_before_run()
def test_linear():
    # the linear patch should produce a meta output with the correct shape
    data = torch.rand(2, 4, device='meta')
    module = torch.nn.Linear(4, 2)
    _assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2]))

    # the linear patch should raise when the feature dimension does not match
    data = torch.rand(2, 2, device='meta')
    _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None)


@clear_cache_before_run()
def test_rnn():
    # the rnn patch should produce meta outputs matching the materialized run
    data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20))
    module = torch.nn.RNN(10, 20, 2)
    output, hn = module(*data)
    meta_data = (torch.randn(5, 3, 10).to('meta'), torch.randn(2, 3, 20).to('meta'))
    _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape))

    # the rnn patch should raise when the input feature dimension does not match
    data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20))
    module = torch.nn.RNN(10, 20, 2)
    output, hn = module(*data)
    meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta'))
    _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None)


@clear_cache_before_run()
def test_embedding():
    # NOTE(review): despite its name, this test exercises the normalization
    # patch (LayerNorm/GroupNorm/BatchNorm), not nn.Embedding — kept as-is
    # so the public test name does not change.
    data = torch.rand(2, 4, device='meta')

    # test layernorm
    ln = torch.nn.LayerNorm(4)
    _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape)

    # test group norm
    gn = torch.nn.GroupNorm(4, num_channels=8)
    _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape)

    # test batch norm 1d: (N, C) and (N, C, L) inputs are valid
    bn1d = torch.nn.BatchNorm1d(4)
    data = torch.rand(2, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn1d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(2, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn1d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn1d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    # a 4-d input must be rejected by batch norm 1d
    data = torch.rand(1, 2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn1d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=True,
                         output_shape=None)

    # test batch norm 2d: only (N, C, H, W) inputs are valid
    bn2d = torch.nn.BatchNorm2d(4)
    data = torch.rand(1, 2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn2d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn2d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=True,
                         output_shape=None)

    # test batch norm 3d: only (N, C, D, H, W) inputs are valid
    bn3d = torch.nn.BatchNorm3d(4)
    data = torch.rand(1, 1, 2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn3d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=False,
                         output_shape=data.shape)

    data = torch.rand(1, 2, 3, 4, device='meta')
    _assert_output_shape(data=data,
                         module=bn3d,
                         patch_fn=patched_module.torch_nn_normalize,
                         expect_exception=True,
                         output_shape=None)
@clear_cache_before_run()
def test_conv1d():
    # compare the patched meta output shape against a materialized forward pass
    data = torch.rand(2, 3, 4)
    conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = conv1d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=conv1d,
                         patch_fn=patched_module.torch_nn_conv1d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = conv1d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=conv1d,
                         patch_fn=patched_module.torch_nn_conv1d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv1d = torch.nn.Conv1d(in_channels=3,
                             out_channels=4,
                             kernel_size=2,
                             padding=1,
                             dilation=2,
                             padding_mode='reflect')
    materialized_output = conv1d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=conv1d,
                         patch_fn=patched_module.torch_nn_conv1d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)


# CONSISTENCY FIX: this was the only test in the file missing the
# @clear_cache_before_run() decorator that every sibling test carries.
@clear_cache_before_run()
def test_conv2d():
    # compare the patched meta output shape against a materialized forward pass
    data = torch.rand(2, 3, 4, 4)
    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = conv2d(data)
    _assert_output_shape(data=data,
                         module=conv2d,
                         patch_fn=patched_module.torch_nn_conv2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = conv2d(data)
    _assert_output_shape(data=data,
                         module=conv2d,
                         patch_fn=patched_module.torch_nn_conv2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2)
    materialized_output = conv2d(data)
    _assert_output_shape(data=data,
                         module=conv2d,
                         patch_fn=patched_module.torch_nn_conv2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv2d = torch.nn.Conv2d(in_channels=3,
                             out_channels=4,
                             kernel_size=2,
                             padding=1,
                             dilation=2,
                             padding_mode='reflect')
    materialized_output = conv2d(data)
    _assert_output_shape(data=data,
                         module=conv2d,
                         patch_fn=patched_module.torch_nn_conv2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)


@clear_cache_before_run()
def test_conv3d():
    # compare the patched meta output shape against a materialized forward pass
    data = torch.rand(2, 3, 4, 4, 4)
    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = conv3d(data)
    _assert_output_shape(data=data,
                         module=conv3d,
                         patch_fn=patched_module.torch_nn_conv3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = conv3d(data)
    _assert_output_shape(data=data,
                         module=conv3d,
                         patch_fn=patched_module.torch_nn_conv3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2)
    materialized_output = conv3d(data)
    _assert_output_shape(data=data,
                         module=conv3d,
                         patch_fn=patched_module.torch_nn_conv3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    conv3d = torch.nn.Conv3d(in_channels=3,
                             out_channels=4,
                             kernel_size=2,
                             padding=1,
                             dilation=2,
                             padding_mode='reflect')
    materialized_output = conv3d(data)
    _assert_output_shape(data=data,
                         module=conv3d,
                         patch_fn=patched_module.torch_nn_conv3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)


@clear_cache_before_run()
def test_conv_transpose1d():
    # test conv transpose 1d
    data = torch.rand(2, 3, 4)
    convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = convtrans1d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=convtrans1d,
                         patch_fn=patched_module.torch_nn_convtranspose1d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = convtrans1d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=convtrans1d,
                         patch_fn=patched_module.torch_nn_convtranspose1d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)


@clear_cache_before_run()
def test_conv_transpose2d():
    # test conv transpose 2d
    data = torch.rand(2, 3, 4, 4)
    convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = convtrans2d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=convtrans2d,
                         patch_fn=patched_module.torch_nn_convtranspose2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = convtrans2d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=convtrans2d,
                         patch_fn=patched_module.torch_nn_convtranspose2d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)


@clear_cache_before_run()
def test_conv_transpose3d():
    # test conv transpose 3d
    data = torch.rand(2, 3, 4, 4, 4)
    convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2)
    materialized_output = convtrans3d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=convtrans3d,
                         patch_fn=patched_module.torch_nn_convtranspose3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)

    convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
    materialized_output = convtrans3d(data)
    meta_data = data.to('meta')
    _assert_output_shape(data=meta_data,
                         module=convtrans3d,
                         patch_fn=patched_module.torch_nn_convtranspose3d,
                         expect_exception=False,
                         output_shape=materialized_output.shape)


@clear_cache_before_run()
def test_pool1d():
    combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d],
                    [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]]

    for (layer_cls, patch_func) in combinations:
        pooler = layer_cls(kernel_size=3)

        # batched (N, C, L) input
        data = torch.rand(2, 3, 4)
        materialized_output = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=materialized_output.shape)

        # unbatched (C, L) input
        data = torch.rand(2, 4)
        materialized_output = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=materialized_output.shape)

        # a 4-d input must be rejected by 1d pooling
        data = torch.rand(2, 3, 4, 4)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=True,
                             output_shape=None)


@clear_cache_before_run()
def test_pool2d():
    combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d],
                    [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]]

    for (layer_cls, patch_func) in combinations:
        pooler = layer_cls(kernel_size=3)

        # batched (N, C, H, W) input
        data = torch.rand(2, 3, 4, 4)
        materialized_output = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=materialized_output.shape)

        # unbatched (C, H, W) input
        data = torch.rand(2, 4, 4)
        materialized_output = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=materialized_output.shape)

        # a 5-d input must be rejected by 2d pooling
        data = torch.rand(2, 3, 4, 4, 4)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=True,
                             output_shape=None)


@clear_cache_before_run()
def test_pool3d():
    combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d],
                    [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]]

    for (layer_cls, patch_func) in combinations:
        pooler = layer_cls(kernel_size=3)

        # batched (N, C, D, H, W) input
        data = torch.rand(2, 3, 4, 4, 4)
        materialized_output = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=materialized_output.shape)

        # unbatched (C, D, H, W) input
        data = torch.rand(2, 4, 4, 4)
        materialized_output = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=materialized_output.shape)

        # a 3-d input must be rejected by 3d pooling
        data = torch.rand(2, 3, 4)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=True,
                             output_shape=None)


# adaptive pooling is different from other pooling, so test it individually
@clear_cache_before_run()
def test_adaptive_pooling_1d():
    pooler = torch.nn.AdaptiveAvgPool1d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_1d

    data = torch.rand(3, 4)
    output = pooler(data)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=False,
                         output_shape=output.shape)

    data = torch.rand(2, 3, 4)
    output = pooler(data)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=False,
                         output_shape=output.shape)

    # a 4-d input must be rejected by adaptive 1d pooling
    data = torch.rand(2, 3, 4, 5)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=True,
                         output_shape=None)


@clear_cache_before_run()
def test_adaptive_pooling_2d():
    pooler = torch.nn.AdaptiveAvgPool2d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_2d

    # a 2-d input must be rejected by adaptive 2d pooling
    data = torch.rand(3, 4)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=True,
                         output_shape=None)

    data = torch.rand(2, 3, 4)
    output = pooler(data)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=False,
                         output_shape=output.shape)

    data = torch.rand(2, 3, 4, 5)
    output = pooler(data)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=False,
                         output_shape=output.shape)


@clear_cache_before_run()
def test_adaptive_pooling_3d():
    pooler = torch.nn.AdaptiveAvgPool3d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_3d

    # a 3-d input must be rejected by adaptive 3d pooling
    data = torch.rand(3, 4, 5)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=True,
                         output_shape=None)

    data = torch.rand(2, 3, 4, 5)
    output = pooler(data)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=False,
                         output_shape=output.shape)

    data = torch.rand(2, 3, 4, 5, 6)
    output = pooler(data)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=False,
                         output_shape=output.shape)