diff --git a/colossalai/fx/tracer/meta_patch/patched_module.py b/colossalai/fx/tracer/meta_patch/patched_module.py
index f91436fb2..2eff50882 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module.py
@@ -1,7 +1,89 @@
+import math
 import torch
 from .registry import meta_patched_module
 
 
 @meta_patched_module.register(torch.nn.Linear)
 def torch_nn_linear(self, input):
+    last_dim = input.shape[-1]
+    assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
     return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
+
+
+@meta_patched_module.register(torch.nn.LayerNorm)
+@meta_patched_module.register(torch.nn.GroupNorm)
+@meta_patched_module.register(torch.nn.BatchNorm1d)
+@meta_patched_module.register(torch.nn.BatchNorm2d)
+@meta_patched_module.register(torch.nn.BatchNorm3d)
+def torch_nn_normalize(self, input):
+    # check shape
+    if isinstance(self, torch.nn.BatchNorm1d):
+        assert input.dim() in [2, 3]
+    elif isinstance(self, torch.nn.BatchNorm2d):
+        assert input.dim() == 4
+    elif isinstance(self, torch.nn.BatchNorm3d):
+        assert input.dim() == 5
+
+    # normalization maintains the same shape as the input
+    return input.clone()
+
+
+@meta_patched_module.register(torch.nn.Embedding)
+def torch_nn_embedding(self, input):
+    # indices of shape (*) map to embeddings of shape (*, embedding_dim)
+    result_shape = input.shape + (self.embedding_dim,)
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.Conv1d)
+def torch_nn_conv1d(self, input):
+    # the output shape is calculated using the formula stated
+    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
+    l_in = input.shape[-1]
+    c_out = self.out_channels
+    l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
+                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+    result_shape = input.shape[:-2] + (
+        c_out,
+        l_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.Conv2d)
+def torch_nn_conv2d(self, input):
+    # the output shape is calculated using the formula stated
+    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
+    h_in, w_in = input.shape[-2:]
+    c_out = self.out_channels
+    h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
+                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+    w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
+                        (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+    result_shape = input.shape[:-3] + (
+        c_out,
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.Conv3d)
+def torch_nn_conv3d(self, input):
+    # the output shape is calculated using the formula stated
+    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d
+    d_in, h_in, w_in = input.shape[-3:]
+    c_out = self.out_channels
+    d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
+                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+    h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
+                        (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+    w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
+                        (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
+    result_shape = input.shape[:-4] + (
+        c_out,
+        d_out,
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
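The convolution patches above reproduce PyTorch's documented output-shape formula instead of running a real forward pass. As a standalone illustration (a sketch for the reader, not part of this patch; the layer hyperparameters are arbitrary), the hand-computed shape can be checked against a materialized convolution:

    import math
    import torch

    conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=2, padding=1)
    h_in, w_in = 32, 32
    # same formula as in torch_nn_conv2d above
    h_out = math.floor((h_in + 2 * conv.padding[0] - conv.dilation[0] *
                        (conv.kernel_size[0] - 1) - 1) / conv.stride[0] + 1)
    w_out = math.floor((w_in + 2 * conv.padding[1] - conv.dilation[1] *
                        (conv.kernel_size[1] - 1) - 1) / conv.stride[1] + 1)
    output = conv(torch.rand(1, 3, h_in, w_in))
    assert output.shape == torch.Size([1, 8, h_out, w_out])    # (1, 8, 16, 16)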
diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py
new file mode 100644
index 000000000..0cb38f436
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_patched_module.py
@@ -0,0 +1,220 @@
+import torch
+from colossalai.fx.tracer.meta_patch import patched_module
+
+
+def _run(data, module, patch_fn):
+    try:
+        output = patch_fn(module, data)
+        return output
+    except Exception as e:
+        return e
+
+
+def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape):
+    output = _run(data, module, patch_fn)
+
+    if expect_exception:
+        assert isinstance(output, AssertionError)
+    else:
+        assert not isinstance(output, Exception)
+        assert output.is_meta
+        assert output.shape == output_shape
+
+
+def test_linear():
+    # test if the linear patch can produce the meta output with correct shape
+    data = torch.rand(2, 4, device='meta')
+    module = torch.nn.Linear(4, 2)
+    _assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2]))
+
+    # test if the linear patch can catch the exception when dimensions do not match
+    data = torch.rand(2, 2, device='meta')
+    _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None)
+
+
+def test_normalize():
+    data = torch.rand(2, 4, device='meta')
+
+    # test layer norm
+    ln = torch.nn.LayerNorm(4)
+    _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape)
+
+    # test group norm
+    gn = torch.nn.GroupNorm(2, num_channels=4)
+    _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape)
+
+    # test batch norm 1d
+    bn1d = torch.nn.BatchNorm1d(4)
+    data = torch.rand(2, 4, device='meta')
+    _assert_output_shape(data=data,
+                         module=bn1d,
+                         patch_fn=patched_module.torch_nn_normalize,
+                         expect_exception=False,
+                         output_shape=data.shape)
+
+    data = torch.rand(2, 3, 4, device='meta')
+    _assert_output_shape(data=data,
+                         module=bn1d,
+                         patch_fn=patched_module.torch_nn_normalize,
+                         expect_exception=False,
+                         output_shape=data.shape)
+
+    data = torch.rand(1, 2, 3, 4, device='meta')
+    _assert_output_shape(data=data,
+                         module=bn1d,
+                         patch_fn=patched_module.torch_nn_normalize,
+                         expect_exception=True,
+                         output_shape=None)
+
+    # test batch norm 2d
+    bn2d = torch.nn.BatchNorm2d(4)
+
+    data = torch.rand(1, 2, 3, 4, device='meta')
+    _assert_output_shape(data=data,
+                         module=bn2d,
+                         patch_fn=patched_module.torch_nn_normalize,
+                         expect_exception=False,
+                         output_shape=data.shape)
+
+    data = torch.rand(2, 3, 4, device='meta')
+    _assert_output_shape(data=data,
+                         module=bn2d,
+                         patch_fn=patched_module.torch_nn_normalize,
+                         expect_exception=True,
+                         output_shape=None)
+
+    # test batch norm 3d
+    bn3d = torch.nn.BatchNorm3d(4)
+
+    data = torch.rand(1, 1, 2, 3, 4, device='meta')
+    _assert_output_shape(data=data,
+                         module=bn3d,
+                         patch_fn=patched_module.torch_nn_normalize,
+                         expect_exception=False,
+                         output_shape=data.shape)
+
+    data = torch.rand(1, 2, 3, 4, device='meta')
+    _assert_output_shape(data=data,
+                         module=bn3d,
+                         patch_fn=patched_module.torch_nn_normalize,
+                         expect_exception=True,
+                         output_shape=None)
+
+
+def test_conv1d():
+    # test conv 1d
+    data = torch.rand(2, 3, 4)
+
+    conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2)
+    materialized_output = conv1d(data)
+    meta_data = data.to('meta')
+    _assert_output_shape(data=meta_data,
+                         module=conv1d,
+                         patch_fn=patched_module.torch_nn_conv1d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+    conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
+    materialized_output = conv1d(data)
+    meta_data = data.to('meta')
+    _assert_output_shape(data=meta_data,
+                         module=conv1d,
+                         patch_fn=patched_module.torch_nn_conv1d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+    conv1d = torch.nn.Conv1d(in_channels=3,
+                             out_channels=4,
+                             kernel_size=2,
+                             padding=1,
+                             dilation=2,
+                             padding_mode='reflect')
+    materialized_output = conv1d(data)
+    meta_data = data.to('meta')
+    _assert_output_shape(data=meta_data,
+                         module=conv1d,
+                         patch_fn=patched_module.torch_nn_conv1d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+
+def test_conv2d():
+    # test conv 2d
+    data = torch.rand(2, 3, 4, 4)
+    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2)
+    materialized_output = conv2d(data)
+    _assert_output_shape(data=data,
+                         module=conv2d,
+                         patch_fn=patched_module.torch_nn_conv2d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
+    materialized_output = conv2d(data)
+    _assert_output_shape(data=data,
+                         module=conv2d,
+                         patch_fn=patched_module.torch_nn_conv2d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+    conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2)
+    materialized_output = conv2d(data)
+    _assert_output_shape(data=data,
+                         module=conv2d,
+                         patch_fn=patched_module.torch_nn_conv2d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+    conv2d = torch.nn.Conv2d(in_channels=3,
+                             out_channels=4,
+                             kernel_size=2,
+                             padding=1,
+                             dilation=2,
+                             padding_mode='reflect')
+    materialized_output = conv2d(data)
+    _assert_output_shape(data=data,
+                         module=conv2d,
+                         patch_fn=patched_module.torch_nn_conv2d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+
+def test_conv3d():
+    # test conv 3d
+    data = torch.rand(2, 3, 4, 4, 4)
+    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2)
+    materialized_output = conv3d(data)
+    _assert_output_shape(data=data,
+                         module=conv3d,
+                         patch_fn=patched_module.torch_nn_conv3d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
+    materialized_output = conv3d(data)
+    _assert_output_shape(data=data,
+                         module=conv3d,
+                         patch_fn=patched_module.torch_nn_conv3d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+    conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2)
+    materialized_output = conv3d(data)
+    _assert_output_shape(data=data,
+                         module=conv3d,
+                         patch_fn=patched_module.torch_nn_conv3d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
+
+    conv3d = torch.nn.Conv3d(in_channels=3,
+                             out_channels=4,
+                             kernel_size=2,
+                             padding=1,
+                             dilation=2,
+                             padding_mode='reflect')
+    materialized_output = conv3d(data)
+    _assert_output_shape(data=data,
+                         module=conv3d,
+                         patch_fn=patched_module.torch_nn_conv3d,
+                         expect_exception=False,
+                         output_shape=materialized_output.shape)
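The patch for torch.nn.Embedding registered above is the one module this test file does not exercise. A minimal test in the same style could look like the following (a hypothetical sketch, not part of this diff; it assumes the patch appends embedding_dim to the input shape, matching nn.Embedding semantics):

    import torch
    from colossalai.fx.tracer.meta_patch import patched_module

    def test_embedding():
        # token-id input of shape (N, S) should map to embeddings of shape (N, S, H)
        module = torch.nn.Embedding(num_embeddings=10, embedding_dim=3)
        data = torch.empty(2, 4, dtype=torch.long, device='meta')
        output = patched_module.torch_nn_embedding(module, data)
        assert output.is_meta
        assert output.shape == torch.Size([2, 4, 3])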