mirror of https://github.com/hpcaitech/ColossalAI
[fx] fixed adaptive pooling size concatenation error (#1489)
parent
cde7b8a5b8
commit
3da68d6b1b
|
@ -22,7 +22,7 @@ def torch_nn_avgpool1d(self, input):
|
|||
|
||||
l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
|
||||
|
||||
result_shape = input.shape[:-1] + (l_out,)
|
||||
result_shape = tuple(input.shape[:-1]) + (l_out,)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
|
@ -46,7 +46,7 @@ def torch_nn_avgpool2d(self, input):
|
|||
h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
|
||||
w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
|
||||
|
||||
result_shape = input.shape[:-2] + (
|
||||
result_shape = tuple(input.shape[:-2]) + (
|
||||
h_out,
|
||||
w_out,
|
||||
)
|
||||
|
@ -74,7 +74,7 @@ def torch_nn_avgpool3d(self, input):
|
|||
h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
|
||||
w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1)
|
||||
|
||||
result_shape = input.shape[:-3] + (
|
||||
result_shape = tuple(input.shape[:-3]) + (
|
||||
d_out,
|
||||
h_out,
|
||||
w_out,
|
||||
|
@ -102,7 +102,7 @@ def torch_nn_maxpool1d(self, input):
|
|||
|
||||
l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
|
||||
|
||||
result_shape = input.shape[:-1] + (l_out,)
|
||||
result_shape = tuple(input.shape[:-1]) + (l_out,)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
|
@ -127,7 +127,7 @@ def torch_nn_maxpool2d(self, input):
|
|||
h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
|
||||
w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
|
||||
|
||||
result_shape = input.shape[:-2] + (
|
||||
result_shape = tuple(input.shape[:-2]) + (
|
||||
h_out,
|
||||
w_out,
|
||||
)
|
||||
|
@ -156,7 +156,7 @@ def torch_nn_maxpool3d(self, input):
|
|||
h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
|
||||
w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
|
||||
|
||||
result_shape = input.shape[:-3] + (
|
||||
result_shape = tuple(input.shape[:-3]) + (
|
||||
d_out,
|
||||
h_out,
|
||||
w_out,
|
||||
|
@ -167,26 +167,34 @@ def torch_nn_maxpool3d(self, input):
|
|||
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
def torch_nn_adapative_pooling_1d(self, input):
    """Shape-only (meta) forward for 1D adaptive avg/max pooling.

    Computes the output shape without running the real kernel and returns an
    empty tensor on the meta device with that shape.

    Args:
        self: the pooling module; only ``self.output_size`` is read.
        input: a 2D ``(C, L)`` or 3D ``(N, C, L)`` tensor.

    Returns:
        torch.Tensor: empty meta tensor of shape ``input.shape[:-1] + output_size``.
    """
    assert input.dim() in [2, 3]
    # Normalize an int output size to a 1-tuple so it concatenates with the shape.
    out_size = self.output_size
    if isinstance(out_size, int):
        out_size = (out_size,)
    result_shape = tuple(input.shape[:-1]) + out_size
    return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
def torch_nn_adapative_pooling_2d(self, input):
    """Shape-only (meta) forward for 2D adaptive avg/max pooling.

    Computes the output shape without running the real kernel and returns an
    empty tensor on the meta device with that shape.

    Args:
        self: the pooling module; only ``self.output_size`` is read.
        input: a 3D ``(C, H, W)`` or 4D ``(N, C, H, W)`` tensor.

    Returns:
        torch.Tensor: empty meta tensor of shape ``input.shape[:-2] + output_size``.
    """
    assert input.dim() in [3, 4]
    if isinstance(self.output_size, int):
        output_size = (self.output_size,) * 2
    else:
        # torch allows a list output size and None entries meaning "keep the
        # corresponding input dimension" -- resolve both into a plain int tuple
        # so the shape concatenation below never sees None.
        output_size = tuple(
            in_dim if target is None else target
            for in_dim, target in zip(input.shape[-2:], self.output_size)
        )
    result_shape = tuple(input.shape[:-2]) + output_size
    return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_adapative_pooling_3d(self, input):
    """Shape-only (meta) forward for 3D adaptive avg/max pooling.

    Computes the output shape without running the real kernel and returns an
    empty tensor on the meta device with that shape.

    Args:
        self: the pooling module; only ``self.output_size`` is read.
        input: a 4D ``(C, D, H, W)`` or 5D ``(N, C, D, H, W)`` tensor.

    Returns:
        torch.Tensor: empty meta tensor of shape ``input.shape[:-3] + output_size``.
    """
    assert input.dim() in [4, 5]
    if isinstance(self.output_size, int):
        output_size = (self.output_size,) * 3
    else:
        # torch allows a list output size and None entries meaning "keep the
        # corresponding input dimension" -- resolve both into a plain int tuple
        # so the shape concatenation below never sees None.
        output_size = tuple(
            in_dim if target is None else target
            for in_dim, target in zip(input.shape[-3:], self.output_size)
        )
    result_shape = tuple(input.shape[:-3]) + output_size
    return torch.empty(result_shape, device='meta')
|
||||
|
|
|
@ -407,3 +407,76 @@ def test_pool3d():
|
|||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4)
|
||||
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
||||
|
||||
# adaptive pooling is different from other pooling, so test it individually
|
||||
def test_adaptive_pooling_1d():
    """Check the patched meta forward for AdaptiveAvgPool1d against eager output shapes."""
    pooler = torch.nn.AdaptiveAvgPool1d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_1d

    # 2D and 3D inputs are the valid ranks; the patched shape must match eager.
    for shape in ((3, 4), (2, 3, 4)):
        data = torch.rand(*shape)
        expected = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=expected.shape)

    # A 4D input is out of range for 1d pooling and must raise.
    data = torch.rand(2, 3, 4, 5)
    _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
||||
|
||||
def test_adaptive_pooling_2d():
    """Check the patched meta forward for AdaptiveAvgPool2d against eager output shapes."""
    pooler = torch.nn.AdaptiveAvgPool2d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_2d

    # A 2D input is below the minimum rank for 2d pooling and must raise.
    data = torch.rand(3, 4)
    _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)

    # 3D and 4D inputs are the valid ranks; the patched shape must match eager.
    for shape in ((2, 3, 4), (2, 3, 4, 5)):
        data = torch.rand(*shape)
        expected = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=expected.shape)
|
||||
|
||||
|
||||
def test_adaptive_pooling_3d():
    """Check the patched meta forward for AdaptiveAvgPool3d against eager output shapes."""
    pooler = torch.nn.AdaptiveAvgPool3d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_3d

    # A 3D input is below the minimum rank for 3d pooling and must raise.
    data = torch.rand(3, 4, 5)
    _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)

    # 4D and 5D inputs are the valid ranks; the patched shape must match eager.
    for shape in ((2, 3, 4, 5), (2, 3, 4, 5, 6)):
        data = torch.rand(*shape)
        expected = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=expected.shape)
|
||||
|
|
Loading…
Reference in New Issue