[fx] added patches for tracing swin transformer (#1228)

pull/1243/head
Frank Lee 2022-07-07 15:20:13 +08:00 committed by GitHub
parent 37fcf96b7f
commit 84f2298a96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 8 deletions

View File

@ -218,3 +218,8 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
concatenated_dim = sum(shape[dim] for shape in shapes)
final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
return torch.empty(final_shape, device="meta")
@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
return torch.empty(input.shape, device='meta')

View File

@ -249,6 +249,34 @@ def torch_nn_maxpool3d(self, input):
return torch.empty(result_shape, device='meta')
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
def torch_nn_adapative_pooling_1d(self, input):
result_shape = input.shape[:-1] + (self.output_size,)
return torch.empty(result_shape, device='meta')
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
def torch_nn_adapative_pooling_2d(self, input):
result_shape = input.shape[:-2] + (
self.output_size,
self.output_size,
)
return torch.empty(result_shape, device='meta')
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_adapative_pooling_3d(self, input):
result_shape = input.shape[:-3] + (
self.output_size,
self.output_size,
self.output_size,
)
return torch.empty(result_shape, device='meta')
@meta_patched_module.register(torch.nn.ReLU)
@meta_patched_module.register(torch.nn.ReLU6)
def torch_nn_func_relu(self, input):

View File

@ -63,14 +63,8 @@ def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True
MODEL_LIST_WITH_CONTROL_FLOW = [
tm.convnext.convnext_base,
tm.vgg.vgg11,
tm.dpn.dpn68,
tm.densenet.densenet121,
tm.rexnet.rexnet_100,
# not traceable
# tm.swin_transformer.swin_base_patch4_window7_224
tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
tm.swin_transformer.swin_base_patch4_window7_224
]
tracer = ColoTracer()