From 4b3d6caeb3c3f619f66491a701deed9aba9513e6 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 1 Sep 2022 19:05:07 +0800
Subject: [PATCH] [fx]patch nn.functional convolution (#1528)

---
 .../meta_patch/patched_function/__init__.py   |   3 +-
 .../patched_function/convolution.py           | 178 ++++++++++++++++++
 .../test_tracer/test_functional_conv.py       |  48 +++++
 3 files changed, 228 insertions(+), 1 deletion(-)
 create mode 100644 colossalai/fx/tracer/meta_patch/patched_function/convolution.py
 create mode 100644 tests/test_fx/test_tracer/test_functional_conv.py

diff --git a/colossalai/fx/tracer/meta_patch/patched_function/__init__.py b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py
index ca20ac0a9..a40ca4c39 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/__init__.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py
@@ -3,4 +3,5 @@ from .arithmetic import *
 from .embedding import *
 from .normalization import *
 from .python_ops import *
-from .torch_ops import *
\ No newline at end of file
+from .torch_ops import *
+from .convolution import *
\ No newline at end of file
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
new file mode 100644
index 000000000..eb88f2451
--- /dev/null
+++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
@@ -0,0 +1,178 @@
+import torch
+import collections
+from itertools import repeat
+from ..registry import meta_patched_function
+import math
+
+
+def _ntuple(n, name="parse"):
+
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return tuple(x)
+        return tuple(repeat(x, n))
+
+    parse.__name__ = name
+    return parse
+
+
+_single = _ntuple(1, "_single")
+_pair = _ntuple(2, "_pair")
+_triple = _ntuple(3, "_triple")
+
+
+def _extract_kwargs(kwargs):
+    if 'stride' in kwargs:
+        stride = kwargs['stride']
+    else:
+        stride = 1
+    # TODO: process str type padding
+    if 'padding' in kwargs:
+        padding = kwargs['padding']
+    else:
+        padding = 0
+    if 'dilation' in kwargs:
+        dilation = kwargs['dilation']
+    else:
+        dilation = 1
+    if 'output_padding' in kwargs:
+        output_padding = kwargs['output_padding']
+    else:
+        output_padding = 0
+
+    return stride, padding, dilation, output_padding
+
+
+@meta_patched_function.register(torch.nn.functional.conv1d)
+def torch_nn_functional_conv1d(input, weight, **kwargs):
+    stride, padding, dilation, _ = _extract_kwargs(kwargs)
+
+    stride = _single(stride)
+    padding = _single(padding)
+    dilation = _single(dilation)
+
+    kernel_size = weight.shape[2:]
+    l_in = input.shape[-1]
+    c_out = weight.shape[0]
+    l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
+    result_shape = input.shape[:-2] + (
+        c_out,
+        l_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.conv2d)
+def torch_nn_functional_conv2d(input, weight, **kwargs):
+    stride, padding, dilation, _ = _extract_kwargs(kwargs)
+
+    stride = _pair(stride)
+    padding = _pair(padding)
+    dilation = _pair(dilation)
+
+    kernel_size = weight.shape[2:]
+    h_in, w_in = input.shape[-2:]
+    c_out = weight.shape[0]
+    h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
+    w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
+    result_shape = input.shape[:-3] + (
+        c_out,
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.conv3d)
+def torch_nn_functional_conv3d(input, weight, **kwargs):
+    stride, padding, dilation, _ = _extract_kwargs(kwargs)
+
+    stride = _triple(stride)
+    padding = _triple(padding)
+    dilation = _triple(dilation)
+
+    kernel_size = weight.shape[2:]
+    d_in, h_in, w_in = input.shape[-3:]
+    c_out = weight.shape[0]
+    d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
+    h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
+    w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
+    result_shape = input.shape[:-4] + (
+        c_out,
+        d_out,
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.conv_transpose1d)
+def torch_nn_functional_convtranspose1d(input, weight, **kwargs):
+    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)
+
+    stride = _single(stride)
+    padding = _single(padding)
+    dilation = _single(dilation)
+    output_padding = _single(output_padding)
+
+    kernel_size = weight.shape[2:]
+    l_in = input.shape[-1]
+    c_out = weight.shape[1]
+    l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
+                       output_padding[0] + 1)
+    result_shape = input.shape[:-2] + (
+        c_out,
+        l_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.conv_transpose2d)
+def torch_nn_functional_convtranspose2d(input, weight, **kwargs):
+    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)
+
+    stride = _pair(stride)
+    padding = _pair(padding)
+    dilation = _pair(dilation)
+    output_padding = _pair(output_padding)
+
+    kernel_size = weight.shape[2:]
+    h_in, w_in = input.shape[-2:]
+    c_out = weight.shape[1]
+    h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
+                       output_padding[0] + 1)
+    w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
+                       output_padding[1] + 1)
+    result_shape = input.shape[:-3] + (
+        c_out,
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.conv_transpose3d)
+def torch_nn_functional_convtranspose3d(input, weight, **kwargs):
+    stride, padding, dilation, output_padding = _extract_kwargs(kwargs)
+
+    stride = _triple(stride)
+    padding = _triple(padding)
+    dilation = _triple(dilation)
+    output_padding = _triple(output_padding)
+
+    kernel_size = weight.shape[2:]
+    d_in, h_in, w_in = input.shape[-3:]
+    c_out = weight.shape[1]
+    d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
+                       output_padding[0] + 1)
+    h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
+                       output_padding[1] + 1)
+    w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) +
+                       output_padding[2] + 1)
+    result_shape = input.shape[:-4] + (
+        c_out,
+        d_out,
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py
new file mode 100644
index 000000000..95670b85f
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_functional_conv.py
@@ -0,0 +1,48 @@
+import torch
+from torch.nn import functional as F
+from colossalai.fx.tracer.meta_patch import patched_function
+
+
+def test_conv():
+    # test F.conv1d
+    data_1d = torch.rand(3, 16, 10)
+    weight_1d = torch.rand(3, 16, 3)
+    out_1d = F.conv1d(data_1d, weight_1d)
+    patched_out_1d = patched_function.torch_nn_functional_conv1d(data_1d, weight_1d)
+    assert out_1d.shape == patched_out_1d.shape
+
+    # test F.conv_transpose1d
+    weight_1d = torch.transpose(weight_1d, 0, 1)
+    out_transpose_1d = F.conv_transpose1d(data_1d, weight_1d)
+    patched_out_transpose_1d = patched_function.torch_nn_functional_convtranspose1d(data_1d, weight_1d)
+    assert out_transpose_1d.shape == patched_out_transpose_1d.shape
+
+    # test F.conv2d
+    data_2d = torch.rand(3, 16, 10, 10)
+    weight_2d = torch.rand(3, 16, 3, 3)
+    out_2d = F.conv2d(data_2d, weight_2d)
+    patched_out_2d = patched_function.torch_nn_functional_conv2d(data_2d, weight_2d)
+    assert out_2d.shape == patched_out_2d.shape
+
+    # test F.conv_transpose2d
+    weight_2d = torch.transpose(weight_2d, 0, 1)
+    out_transpose_2d = F.conv_transpose2d(data_2d, weight_2d)
+    patched_out_transpose_2d = patched_function.torch_nn_functional_convtranspose2d(data_2d, weight_2d)
+    assert out_transpose_2d.shape == patched_out_transpose_2d.shape
+
+    # test F.conv3d
+    data_3d = torch.rand(3, 16, 10, 10, 10)
+    weight_3d = torch.rand(3, 16, 3, 3, 3)
+    out_3d = F.conv3d(data_3d, weight_3d)
+    patched_out_3d = patched_function.torch_nn_functional_conv3d(data_3d, weight_3d)
+    assert out_3d.shape == patched_out_3d.shape
+
+    # test F.conv_transpose3d
+    weight_3d = torch.transpose(weight_3d, 0, 1)
+    out_transpose_3d = F.conv_transpose3d(data_3d, weight_3d)
+    patched_out_transpose_3d = patched_function.torch_nn_functional_convtranspose3d(data_3d, weight_3d)
+    assert out_transpose_3d.shape == patched_out_transpose_3d.shape
+
+
+if __name__ == '__main__':
+    test_conv()
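
The patched functions above compute output shapes only and return meta tensors,
so no real storage is ever allocated. For the forward convolutions they encode
the standard rule from the torch.nn.Conv2d documentation, applied per spatial
dimension:

    out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

Note that test_conv() calls everything with default stride, padding and
dilation, so the kwargs paths in _extract_kwargs go unexercised there. A quick
sanity check with non-default hyperparameters, assuming ColossalAI with this
patch applied is importable:

import torch
from torch.nn import functional as F
from colossalai.fx.tracer.meta_patch import patched_function

data = torch.rand(2, 16, 10, 10)
weight = torch.rand(8, 16, 3, 3)    # (out_channels, in_channels, kH, kW)
kwargs = dict(stride=2, padding=1, dilation=2)

out = F.conv2d(data, weight, **kwargs)
patched_out = patched_function.torch_nn_functional_conv2d(data, weight, **kwargs)

# h_out = w_out = floor((10 + 2*1 - 2*(3 - 1) - 1) / 2 + 1) = floor(4.5) = 4
assert out.shape == patched_out.shape == torch.Size([2, 8, 4, 4])
assert patched_out.is_meta    # shape information only, no data allocated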
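
The transposed variants read the output channel count from weight.shape[1],
because transposed-conv weights are laid out as (in_channels, out_channels,
*kernel_size), and they additionally honour output_padding, following the
ConvTranspose* shape rule:

    out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

A minimal check under the same importability assumption (as in eager mode,
output_padding must stay smaller than stride or dilation):

import torch
from torch.nn import functional as F
from colossalai.fx.tracer.meta_patch import patched_function

data = torch.rand(2, 16, 10)
weight = torch.rand(16, 8, 3)    # (in_channels, out_channels, kernel_size)
kwargs = dict(stride=2, padding=1, output_padding=1)

out = F.conv_transpose1d(data, weight, **kwargs)
patched_out = patched_function.torch_nn_functional_convtranspose1d(data, weight, **kwargs)

# l_out = (10 - 1)*2 - 2*1 + 1*(3 - 1) + 1 + 1 = 20
assert out.shape == patched_out.shape == torch.Size([2, 8, 20])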
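
For readers unfamiliar with the meta_patched_function object imported from
..registry: register is a decorator that maps an original torch callable to a
shape-only stand-in, which the tracer can then look up whenever that callable
appears as a call target during symbolic tracing. The sketch below illustrates
that decorator-registry pattern in a self-contained way; it is not ColossalAI's
actual implementation, and PatchRegistry and dispatch are hypothetical names:

import torch

class PatchRegistry:
    """Maps an original callable to its shape-only stand-in."""

    def __init__(self):
        self.store = {}

    def register(self, source):
        # Used as @registry.register(torch.nn.functional.conv2d).
        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

registry = PatchRegistry()

@registry.register(torch.flatten)
def patched_flatten(input, start_dim=0, end_dim=-1):
    # A pure view op runs directly on a meta tensor of the same shape.
    return torch.flatten(torch.empty(input.shape, device='meta'), start_dim, end_dim)

def dispatch(target, *args, **kwargs):
    # A tracer would route a call through the patch when one is registered,
    # and fall back to the original callable otherwise.
    return registry.store.get(target, target)(*args, **kwargs)

x = torch.empty(4, 3, 8, 8, device='meta')
assert dispatch(torch.flatten, x, 1).shape == torch.Size([4, 192])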