From 11973d892d0273cc4719e30997cfaeafe4bc506c Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Wed, 6 Jul 2022 21:37:56 +0800
Subject: [PATCH] [fx] added torchvision model tracing testing (#1216)

* [fx] added torchvision model tracing testing

* remove unused imports
---
 .../fx/tracer/meta_patch/patched_function.py  |  38 ++++-
 .../fx/tracer/meta_patch/patched_module.py    | 131 ++++++++++++++++++
 colossalai/fx/tracer/tracer.py                |  64 +++++++--
 .../test_tracer/test_non_patched_module.py    |  31 -----
 .../test_tracer/test_patched_module.py        | 107 ++++++++++----
 .../test_torchvision_model.py                 |  46 ++++++
 6 files changed, 346 insertions(+), 71 deletions(-)
 delete mode 100644 tests/test_fx/test_tracer/test_non_patched_module.py
 create mode 100644 tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py

diff --git a/colossalai/fx/tracer/meta_patch/patched_function.py b/colossalai/fx/tracer/meta_patch/patched_function.py
index fa1c7fd12..d1457d89e 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function.py
@@ -1,4 +1,3 @@
-from curses import meta
 import operator
 import torch
 from .registry import meta_patched_function
@@ -142,3 +141,40 @@ def torch_bmm(input, mat2, *, out=None):
     batch_size, n, m = input.shape
     _, _, p = mat2.shape
     return torch.empty(batch_size, n, p, device="meta")
+
+
+@meta_patched_function.register(torch.squeeze)
+def torch_squeeze(input, dim=None):
+    shape = list(input.shape)
+    if dim is not None:
+        if dim < 0:
+            dim = input.dim() + dim
+        if shape[dim] == 1:
+            shape.pop(dim)
+    else:
+        new_shape = []
+        for dim_value in shape:
+            if dim_value == 1:
+                continue
+            new_shape.append(dim_value)
+        shape = new_shape
+    return torch.empty(shape, device="meta")
+
+
+@meta_patched_function.register(torch.Tensor.squeeze)
+def torch_tensor_squeeze(self, dim=None):
+    return torch_squeeze(self, dim)
+
+
+@meta_patched_function.register(torch.unsqueeze)
+def torch_unsqueeze(input, dim):
+    shape = list(input.shape)
+    if dim < 0:
+        dim = input.dim() + 1 + dim
+    shape.insert(dim, 1)
+    return torch.empty(shape, device="meta")
+
+
+@meta_patched_function.register(torch.Tensor.unsqueeze)
+def torch_tensor_unsqueeze(self, dim):
+    return torch_unsqueeze(self, dim)
diff --git a/colossalai/fx/tracer/meta_patch/patched_module.py b/colossalai/fx/tracer/meta_patch/patched_module.py
index 61fc341be..f895e73e9 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module.py
@@ -88,6 +88,137 @@ def torch_nn_conv3d(self, input):
     return torch.empty(result_shape, device='meta')
 
 
+@meta_patched_module.register(torch.nn.AvgPool1d)
+def torch_nn_avgpool1d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+
+    l_in = input.shape[-1]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 1
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
+
+    result_shape = input.shape[:-1] + (l_out,)
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.AvgPool2d)
+def torch_nn_avgpool2d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+
+    h_in, w_in = input.shape[-2:]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 2
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
+    w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
+
+    result_shape = input.shape[:-2] + (
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.AvgPool3d)
+def torch_nn_avgpool3d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
+
+    d_in, h_in, w_in = input.shape[-3:]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 3
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    d_out = math.floor((d_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
+    h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
+    w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1)
+
+    result_shape = input.shape[:-3] + (
+        d_out,
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.MaxPool1d)
+def torch_nn_maxpool1d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+
+    l_in = input.shape[-1]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 1
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    dilation = _convert_int_to_list(self.dilation)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
+
+    result_shape = input.shape[:-1] + (l_out,)
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.MaxPool2d)
+def torch_nn_maxpool2d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+
+    h_in, w_in = input.shape[-2:]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 2
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    dilation = _convert_int_to_list(self.dilation)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
+    w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
+
+    result_shape = input.shape[:-2] + (
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
 @meta_patched_module.register(torch.nn.MaxPool3d)
 def torch_nn_maxpool3d(self, input):
     num_dim = input.dim()
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index e4191f88c..0398dc89f 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -4,9 +4,8 @@ tracer.py:
     Implemented a tracer which supports control flow and user-defined meta arguments.
     The implementation is partly inspired by HuggingFace's fx tracer
 """
-
+import enum
 import inspect
-import math
 import functools
 import torch
 import torch.nn as nn
@@ -22,6 +21,11 @@ from .meta_patch import meta_patched_function, meta_patched_module
 
 __all__ = ['ColoTracer']
 
+class TracerType(enum.Enum):
+    DEFAULT = 1
+    META = 2
+
+
 class ColoTracer(Tracer):
     """
     ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
@@ -48,6 +52,11 @@ class ColoTracer(Tracer):
         graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tracer_type = TracerType.META
+        self.proxy_cls = ColoProxy
+
     # Feature flag for proxying accesses to buffer values
     proxy_buffer_attributes: bool = True
 
@@ -58,6 +67,12 @@ class ColoTracer(Tracer):
        Create a proxy for different kinds of operations.
        """
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+
+        if self.tracer_type == TracerType.DEFAULT:
+            # since meta_args is not given
+            # we just fall back to the original torch.fx.Tracer
+            return proxy
+
        proxy: ColoProxy
 
        if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
@@ -168,11 +183,21 @@ class ColoTracer(Tracer):
         self.orig_forward = forward
         return super().call_module(m, forward, args, kwargs)
 
-    def proxy(self, node) -> ColoProxy:
+    def proxy(self, node) -> Proxy:
         """
         Returns a ColoProxy object.
         """
-        return ColoProxy(node, self)
+        return self.proxy_cls(node, self)
+
+    def _configure_tracer_type(self, tracer_type: TracerType):
+        if tracer_type == TracerType.DEFAULT:
+            self.proxy_cls = Proxy
+            self.tracer_type = TracerType.DEFAULT
+        elif tracer_type == TracerType.META:
+            self.proxy_cls = ColoProxy
+            self.tracer_type = TracerType.META
+        else:
+            raise ValueError(f"Unrecognised tracer type {tracer_type}")
 
     def trace(self,
               root: nn.Module,
@@ -193,6 +218,11 @@ class ColoTracer(Tracer):
         if concrete_args is None:
             concrete_args = {}
 
+        if len(meta_args) == 0:
+            self._configure_tracer_type(TracerType.DEFAULT)
+        else:
+            self._configure_tracer_type(TracerType.META)
+
         # check concrete and meta args have valid names
         sig = inspect.signature(root.forward)
         sig_names = set(sig.parameters.keys())
@@ -235,18 +265,21 @@ class ColoTracer(Tracer):
         self.concrete_args = concrete_args
         self.meta_args = meta_args
 
-        # wrap the torch tensor constructing methods so that they are captured in the graph
-        self.patched_torch_tensor_methods = {
-            target: wrap_tensor_constructor_method(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
-        }
+        self.patched_torch_tensor_methods = {}
+        if self.tracer_type == TracerType.META:
+            # wrap the torch tensor constructing methods so that they are captured in the graph
+            self.patched_torch_tensor_methods = {
+                target: wrap_tensor_constructor_method(getattr(torch, target))
+                for target in self._TORCH_METHODS_TO_PATCH
+            }
 
-        # patch these methods to replace their original use
-        for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
-            setattr(torch, name, wrapper)
+            # patch these methods to replace their original use
+            for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
+                setattr(torch, name, wrapper)
 
-        # cache these methods so that we can detect whether a method call
-        # should be patched during tracing
-        self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
+            # cache these methods so that we can detect whether a method call
+            # should be patched during tracing
+            self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
 
         try:
             self.graph = super().trace(root, concrete_args=concrete_args)
@@ -255,6 +288,9 @@ class ColoTracer(Tracer):
         for name, (_, orig) in self.patched_torch_tensor_methods.items():
             setattr(torch, name, orig)
 
+        if self.tracer_type == TracerType.DEFAULT:
+            return self.graph
+
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
         for node in self.graph.nodes:
diff --git a/tests/test_fx/test_tracer/test_non_patched_module.py b/tests/test_fx/test_tracer/test_non_patched_module.py
deleted file mode 100644
index 9abc964b0..000000000
--- a/tests/test_fx/test_tracer/test_non_patched_module.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import torch
-import torch.nn
-
-
-def test_maxpool():
-    layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.MaxPool1d, shape=(4, 3, 4)),
-                         maxpool_2d=dict(layer=torch.nn.MaxPool2d, shape=(4, 3, 4, 4)))
-
-    for name, info in layer_to_test.items():
-        data = torch.rand(*info['shape'])
-        meta_data = data.to('meta')
-        layer = info['layer'](kernel_size=3)
-        out = layer(data)
-        meta_out = layer(meta_data)
-        assert meta_out.is_meta
-        assert out.shape == meta_out.shape
-
-
-def test_avgpool():
-    layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.AvgPool1d, shape=(4, 3, 4)),
-                         maxpool_2d=dict(layer=torch.nn.AvgPool2d, shape=(4, 3, 4, 4)),
-                         maxpool_3d=dict(layer=torch.nn.AvgPool3d, shape=(4, 3, 4, 4, 4)))
-
-    for name, info in layer_to_test.items():
-        data = torch.rand(*info['shape'])
-        meta_data = data.to('meta')
-        layer = info['layer'](kernel_size=3)
-        out = layer(data)
-        meta_out = layer(meta_data)
-        assert meta_out.is_meta
-        assert out.shape == meta_out.shape
diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py
index d96bc04ac..d7ceba1a5 100644
--- a/tests/test_fx/test_tracer/test_patched_module.py
+++ b/tests/test_fx/test_tracer/test_patched_module.py
@@ -227,31 +227,88 @@ def test_conv3d():
                          output_shape=materialized_output.shape)
 
 
-def test_maxpool3d():
-    pooler = torch.nn.MaxPool3d(kernel_size=3)
+def test_pool1d():
+    combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d],
+                    [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]]
 
-    # test max pool 3d
-    data = torch.rand(2, 3, 4, 4, 4)
-    materialized_output = pooler(data)
-    _assert_output_shape(data=data,
-                         module=pooler,
-                         patch_fn=patched_module.torch_nn_maxpool3d,
-                         expect_exception=False,
-                         output_shape=materialized_output.shape)
+    for (layer_cls, patch_func) in combinations:
+        pooler = layer_cls(kernel_size=3)
 
-    # test max pool 3d
-    data = torch.rand(2, 3, 4, 4)
-    materialized_output = pooler(data)
-    _assert_output_shape(data=data,
-                         module=pooler,
-                         patch_fn=patched_module.torch_nn_maxpool3d,
-                         expect_exception=False,
-                         output_shape=materialized_output.shape)
+        data = torch.rand(2, 3, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
 
-    # test max pool 3d
-    data = torch.rand(2, 3, 4)
-    _assert_output_shape(data=data,
-                         module=pooler,
-                         patch_fn=patched_module.torch_nn_maxpool3d,
-                         expect_exception=True,
-                         output_shape=None)
+        data = torch.rand(2, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        data = torch.rand(2, 3, 4, 4)
+        _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
+
+
+def test_pool2d():
+    combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d],
+                    [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]]
+
+    for (layer_cls, patch_func) in combinations:
+        pooler = layer_cls(kernel_size=3)
+
+        # test 2d pooling on 4-dimensional data
+        data = torch.rand(2, 3, 4, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        # test 2d pooling on 3-dimensional data
+        data = torch.rand(2, 4, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        # 5-dimensional data should raise an exception
+        data = torch.rand(2, 3, 4, 4, 4)
+        _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
+
+
+def test_pool3d():
+    combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d],
+                    [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]]
+
+    for (layer_cls, patch_func) in combinations:
+        pooler = layer_cls(kernel_size=3)
+
+        # test 3d pooling on 5-dimensional data
+        data = torch.rand(2, 3, 4, 4, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        # test 3d pooling on 4-dimensional data
+        data = torch.rand(2, 4, 4, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        # 3-dimensional data should raise an exception
+        data = torch.rand(2, 3, 4)
+        _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
new file mode 100644
index 000000000..11c3d7ea5
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
@@ -0,0 +1,46 @@
+import torch
+import pytest
+try:
+    import torchvision.models as tm
+except ImportError:
+    pass
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+
+
+@pytest.mark.skip('skip as torchvision is required')
+def test_torchvision_models():
+    MODEL_LIST = [
+        tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
+        tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0
+    ]
+
+    torch.backends.cudnn.deterministic = True
+
+    tracer = ColoTracer()
+    data = torch.rand(2, 3, 224, 224)
+
+    for model_cls in MODEL_LIST:
+        if model_cls in [tm.convnext_small, tm.efficientnet_b0]:
+            # remove the impact of randomness
+            model = model_cls(stochastic_depth_prob=0)
+        else:
+            model = model_cls()
+
+        graph = tracer.trace(root=model)
+
+        gm = GraphModule(model, graph, model.__class__.__name__)
+        gm.recompile()
+
+        model.eval()
+        gm.eval()
+
+        with torch.no_grad():
+            fx_out = gm(data)
+            non_fx_out = model(data)
+        assert torch.allclose(
+            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+if __name__ == '__main__':
+    test_torchvision_models()
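
A minimal usage sketch of the tracing path this patch exercises, assuming only the
APIs shown above (`ColoTracer.trace` accepts `meta_args` and falls back to the stock
`torch.fx` tracing when none are given); the model, input shape, and the forward
argument name `x` are illustrative, not prescribed by the patch:

    import torch
    import torchvision.models as tm
    from torch.fx import GraphModule
    from colossalai.fx import ColoTracer

    model = tm.resnet18()
    tracer = ColoTracer()

    # passing meta_args selects TracerType.META, so shapes are propagated with
    # meta tensors during tracing; an empty meta_args selects TracerType.DEFAULT,
    # i.e. the original torch.fx.Tracer behaviour
    graph = tracer.trace(root=model, meta_args={'x': torch.rand(2, 3, 224, 224, device='meta')})

    # wrap the traced graph into a runnable GraphModule, as the new test does
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()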