mirror of https://github.com/hpcaitech/ColossalAI
[fx] added patches for tracing swin transformer (#1228)
parent
37fcf96b7f
commit
84f2298a96
|
@ -218,3 +218,8 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
|||
concatenated_dim = sum(shape[dim] for shape in shapes)
|
||||
final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
|
||||
return torch.empty(final_shape, device="meta")
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.roll)
|
||||
def torch_roll(input, shifts, dims=None):
|
||||
return torch.empty(input.shape, device='meta')
|
||||
|
|
|
@ -249,6 +249,34 @@ def torch_nn_maxpool3d(self, input):
|
|||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
|
||||
@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
|
||||
def torch_nn_adapative_pooling_1d(self, input):
|
||||
result_shape = input.shape[:-1] + (self.output_size,)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
|
||||
@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
|
||||
def torch_nn_adapative_pooling_2d(self, input):
|
||||
result_shape = input.shape[:-2] + (
|
||||
self.output_size,
|
||||
self.output_size,
|
||||
)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
|
||||
@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
|
||||
def torch_nn_adapative_pooling_3d(self, input):
|
||||
result_shape = input.shape[:-3] + (
|
||||
self.output_size,
|
||||
self.output_size,
|
||||
self.output_size,
|
||||
)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.ReLU)
|
||||
@meta_patched_module.register(torch.nn.ReLU6)
|
||||
def torch_nn_func_relu(self, input):
|
||||
|
|
|
@ -63,14 +63,8 @@ def test_timm_models_with_control_flow():
|
|||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
MODEL_LIST_WITH_CONTROL_FLOW = [
|
||||
tm.convnext.convnext_base,
|
||||
tm.vgg.vgg11,
|
||||
tm.dpn.dpn68,
|
||||
tm.densenet.densenet121,
|
||||
tm.rexnet.rexnet_100,
|
||||
|
||||
# not traceable
|
||||
# tm.swin_transformer.swin_base_patch4_window7_224
|
||||
tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
|
||||
tm.swin_transformer.swin_base_patch4_window7_224
|
||||
]
|
||||
|
||||
tracer = ColoTracer()
|
||||
|
|
Loading…
Reference in New Issue