[fx] fixed adaptive pooling size concatenation error (#1489)

pull/1496/head
Frank Lee 2022-08-25 09:05:07 +08:00 committed by GitHub
parent cde7b8a5b8
commit 3da68d6b1b
2 changed files with 98 additions and 17 deletions
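
Summary of the fix: the meta-patched shape-inference helpers built the output shape by concatenating a slice of `input.shape` with `self.output_size` wrapped in a fresh tuple. For the adaptive pooling modules, `output_size` may itself already be a tuple, so the old code produced a nested tuple that `torch.empty` rejects. The patch normalizes `output_size` to a flat tuple (repeating an int once per spatial dimension), casts the shape slice with `tuple(...)`, and adds dimension asserts. A minimal standalone sketch of the failure and the fix (the module and sizes here are illustrative):

    import torch

    pool = torch.nn.AdaptiveAvgPool2d(output_size=(5, 7))
    shape = torch.rand(2, 3, 8, 8).shape            # torch.Size([2, 3, 8, 8])

    # Old behaviour: re-wrapping a tuple-valued output_size nests it,
    # and torch.empty() rejects the non-integer elements.
    bad = tuple(shape[:-2]) + (pool.output_size, pool.output_size)    # (2, 3, (5, 7), (5, 7))
    # torch.empty(bad, device='meta')   # TypeError: size must be a tuple of ints

    # Fixed behaviour: normalize output_size to a flat tuple first.
    output_size = ((pool.output_size,) * 2 if isinstance(pool.output_size, int)
                   else tuple(pool.output_size))
    good = tuple(shape[:-2]) + output_size                            # (2, 3, 5, 7)
    assert torch.empty(good, device='meta').shape == torch.Size([2, 3, 5, 7])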

View File

@@ -22,7 +22,7 @@ def torch_nn_avgpool1d(self, input):
     l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
-    result_shape = input.shape[:-1] + (l_out,)
+    result_shape = tuple(input.shape[:-1]) + (l_out,)
     return torch.empty(result_shape, device='meta')
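
For reference, the closed form these helpers use is the standard average-pooling output-length arithmetic; a quick sanity check against the real module (example sizes chosen arbitrarily):

    import math
    import torch

    # l_out = floor((l_in + 2*padding - kernel_size) / stride + 1)
    pool = torch.nn.AvgPool1d(kernel_size=3, stride=2, padding=1)
    l_in = 10
    l_out = math.floor((l_in + 2 * 1 - 3) / 2 + 1)    # -> 5
    assert pool(torch.rand(2, 4, l_in)).shape == torch.Size([2, 4, l_out])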
@@ -46,7 +46,7 @@ def torch_nn_avgpool2d(self, input):
     h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
     w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
-    result_shape = input.shape[:-2] + (
+    result_shape = tuple(input.shape[:-2]) + (
         h_out,
         w_out,
     )
@@ -74,7 +74,7 @@ def torch_nn_avgpool3d(self, input):
     h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
     w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1)
-    result_shape = input.shape[:-3] + (
+    result_shape = tuple(input.shape[:-3]) + (
         d_out,
         h_out,
         w_out,
@@ -102,7 +102,7 @@ def torch_nn_maxpool1d(self, input):
     l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
-    result_shape = input.shape[:-1] + (l_out,)
+    result_shape = tuple(input.shape[:-1]) + (l_out,)
     return torch.empty(result_shape, device='meta')
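
The max-pool variants extend the same formula with dilation; again a quick check against the real module (arbitrary example sizes):

    import math
    import torch

    # l_out = floor((l_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
    pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=2)
    l_in = 12
    l_out = math.floor((l_in + 2 * 1 - 2 * (3 - 1) - 1) / 2 + 1)    # -> 5
    assert pool(torch.rand(2, 4, l_in)).shape == torch.Size([2, 4, l_out])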
@@ -127,7 +127,7 @@ def torch_nn_maxpool2d(self, input):
     h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
     w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
-    result_shape = input.shape[:-2] + (
+    result_shape = tuple(input.shape[:-2]) + (
         h_out,
         w_out,
     )
@@ -156,7 +156,7 @@ def torch_nn_maxpool3d(self, input):
     h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
     w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
-    result_shape = input.shape[:-3] + (
+    result_shape = tuple(input.shape[:-3]) + (
         d_out,
         h_out,
         w_out,
@@ -167,26 +167,34 @@ def torch_nn_maxpool3d(self, input):
 @meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
 @meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
 def torch_nn_adapative_pooling_1d(self, input):
-    result_shape = input.shape[:-1] + (self.output_size,)
+    assert input.dim() in [2, 3]
+    if isinstance(self.output_size, int):
+        output_size = (self.output_size,)
+    else:
+        output_size = self.output_size
+    result_shape = tuple(input.shape[:-1]) + output_size
     return torch.empty(result_shape, device='meta')
 
 
 @meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
 @meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
 def torch_nn_adapative_pooling_2d(self, input):
-    result_shape = input.shape[:-2] + (
-        self.output_size,
-        self.output_size,
-    )
+    assert input.dim() in [3, 4]
+    if isinstance(self.output_size, int):
+        output_size = (self.output_size,) * 2
+    else:
+        output_size = self.output_size
+    result_shape = tuple(input.shape[:-2]) + output_size
     return torch.empty(result_shape, device='meta')
 
 
 @meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
 @meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
 def torch_nn_adapative_pooling_3d(self, input):
-    result_shape = input.shape[:-3] + (
-        self.output_size,
-        self.output_size,
-        self.output_size,
-    )
-    return torch.empty(result_shape, device='meta')
+    assert input.dim() in [4, 5]
+    if isinstance(self.output_size, int):
+        output_size = (self.output_size,) * 3
+    else:
+        output_size = self.output_size
+    result_shape = tuple(input.shape[:-3]) + output_size
+    return torch.empty(result_shape, device='meta')
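
After the patch, the adaptive helpers accept both int and tuple `output_size`. A standalone sketch that mirrors the 2d logic on a meta tensor (the function name here is illustrative, not part of the patched module):

    import torch

    def adaptive_pool_2d_shape(input, output_size):
        # Mirrors the patched logic: normalize an int to a pair, then
        # concatenate plain tuples so no nested tuple can appear.
        assert input.dim() in [3, 4]
        if isinstance(output_size, int):
            output_size = (output_size,) * 2
        return torch.empty(tuple(input.shape[:-2]) + tuple(output_size), device='meta')

    x = torch.empty(2, 3, 32, 32, device='meta')
    assert adaptive_pool_2d_shape(x, 7).shape == torch.Size([2, 3, 7, 7])
    assert adaptive_pool_2d_shape(x, (5, 7)).shape == torch.Size([2, 3, 5, 7])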


@@ -407,3 +407,76 @@ def test_pool3d():
     # test max pool 3d
     data = torch.rand(2, 3, 4)
     _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
+
+
+# adaptive pooling is different from other pooling, so test it individually
+def test_adaptive_pooling_1d():
+    pooler = torch.nn.AdaptiveAvgPool1d(output_size=3)
+    patch_func = patched_module.torch_nn_adapative_pooling_1d
+
+    data = torch.rand(3, 4)
+    output = pooler(data)
+    _assert_output_shape(data=data,
+                         module=pooler,
+                         patch_fn=patch_func,
+                         expect_exception=False,
+                         output_shape=output.shape)
+
+    data = torch.rand(2, 3, 4)
+    output = pooler(data)
+    _assert_output_shape(data=data,
+                         module=pooler,
+                         patch_fn=patch_func,
+                         expect_exception=False,
+                         output_shape=output.shape)
+
+    data = torch.rand(2, 3, 4, 5)
+    _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
+
+
+def test_adaptive_pooling_2d():
+    pooler = torch.nn.AdaptiveAvgPool2d(output_size=3)
+    patch_func = patched_module.torch_nn_adapative_pooling_2d
+
+    data = torch.rand(3, 4)
+    _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
+
+    data = torch.rand(2, 3, 4)
+    output = pooler(data)
+    _assert_output_shape(data=data,
+                         module=pooler,
+                         patch_fn=patch_func,
+                         expect_exception=False,
+                         output_shape=output.shape)
+
+    data = torch.rand(2, 3, 4, 5)
+    output = pooler(data)
+    _assert_output_shape(data=data,
+                         module=pooler,
+                         patch_fn=patch_func,
+                         expect_exception=False,
+                         output_shape=output.shape)
+
+
+def test_adaptive_pooling_3d():
+    pooler = torch.nn.AdaptiveAvgPool3d(output_size=3)
+    patch_func = patched_module.torch_nn_adapative_pooling_3d
+
+    data = torch.rand(3, 4, 5)
+    _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
+
+    data = torch.rand(2, 3, 4, 5)
+    output = pooler(data)
+    _assert_output_shape(data=data,
+                         module=pooler,
+                         patch_fn=patch_func,
+                         expect_exception=False,
+                         output_shape=output.shape)
+
+    data = torch.rand(2, 3, 4, 5, 6)
+    output = pooler(data)
+    _assert_output_shape(data=data,
+                         module=pooler,
+                         patch_fn=patch_func,
+                         expect_exception=False,
+                         output_shape=output.shape)
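
`_assert_output_shape` is defined elsewhere in the test file and is not part of this diff; based on the call sites above, a plausible minimal version looks like the following (hypothetical, the real helper may differ):

    import pytest
    import torch

    def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape):
        # patch_fn has the (self, input) signature of the patched methods above.
        if expect_exception:
            with pytest.raises(AssertionError):
                patch_fn(module, data)
        else:
            meta_output = patch_fn(module, data)
            assert meta_output.is_meta
            assert meta_output.shape == output_shape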