[fx] added torchvision model tracing testing (#1216)

* [fx] added torchvision model tracing testing

* remove unused imports
Frank Lee 2022-07-06 21:37:56 +08:00 committed by GitHub
parent 52736205d9
commit 11973d892d
6 changed files with 346 additions and 71 deletions


@@ -1,4 +1,3 @@
-from curses import meta
 import operator
 import torch
 from .registry import meta_patched_function
@@ -142,3 +141,40 @@ def torch_bmm(input, mat2, *, out=None):
     batch_size, n, m = input.shape
     _, _, p = mat2.shape
     return torch.empty(batch_size, n, p, device="meta")
+
+
+@meta_patched_function.register(torch.squeeze)
+def torch_squeeze(input, dim=None):
+    shape = list(input.shape)
+    if dim is not None:
+        if dim < 0:
+            dim = input.dim() + dim
+        if shape[dim] == 1:
+            shape.pop(dim)
+    else:
+        new_shape = []
+        for dim_value in shape:
+            if dim_value == 1:
+                continue
+            new_shape.append(dim_value)
+        shape = new_shape
+    return torch.empty(shape, device="meta")
+
+
+@meta_patched_function.register(torch.Tensor.squeeze)
+def torch_tensor_squeeze(self, dim=None):
+    return torch_squeeze(self, dim)
+
+
+@meta_patched_function.register(torch.unsqueeze)
+def torch_unsqueeze(input, dim):
+    shape = list(input.shape)
+    if dim < 0:
+        dim = input.dim() + 1 + dim
+    shape.insert(dim, 1)
+    return torch.empty(shape, device="meta")
+
+
+@meta_patched_function.register(torch.Tensor.unsqueeze)
+def torch_tensor_unsqueeze(self, dim):
+    return torch_unsqueeze(self, dim)
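
The shape rules these patches encode mirror eager PyTorch semantics: squeeze drops size-1 dimensions only (all of them when dim is None), and negative dims are normalized by adding input.dim() (or input.dim() + 1 for unsqueeze, since insertion creates one extra slot). A quick standalone check of those rules against eager tensors (a sketch for illustration, not part of the commit):

import torch

x = torch.rand(2, 1, 3, 1)
assert x.squeeze().shape == (2, 3)                 # dim=None: every size-1 dim is removed
assert x.squeeze(1).shape == (2, 3, 1)             # only dim 1 is removed
assert x.squeeze(0).shape == (2, 1, 3, 1)          # dim 0 is not size 1, shape unchanged
assert x.squeeze(-1).shape == (2, 1, 3)            # -1 normalizes to x.dim() + (-1) = 3
assert x.unsqueeze(-1).shape == (2, 1, 3, 1, 1)    # inserts at x.dim() + 1 + (-1) = 4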


@@ -88,6 +88,137 @@ def torch_nn_conv3d(self, input):
     return torch.empty(result_shape, device='meta')


+@meta_patched_module.register(torch.nn.AvgPool1d)
+def torch_nn_avgpool1d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+
+    l_in = input.shape[-1]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 1
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
+    result_shape = input.shape[:-1] + (l_out,)
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.AvgPool2d)
+def torch_nn_avgpool2d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+
+    h_in, w_in = input.shape[-2:]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 2
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
+    w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
+    result_shape = input.shape[:-2] + (
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.AvgPool3d)
+def torch_nn_avgpool3d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
+
+    d_in, h_in, w_in = input.shape[-3:]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 3
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    d_out = math.floor((d_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
+    h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
+    w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1)
+    result_shape = input.shape[:-3] + (
+        d_out,
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.MaxPool1d)
+def torch_nn_maxpool1d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+
+    l_in = input.shape[-1]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 1
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    dilation = _convert_int_to_list(self.dilation)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
+    result_shape = input.shape[:-1] + (l_out,)
+    return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.MaxPool2d)
+def torch_nn_maxpool2d(self, input):
+    num_dim = input.dim()
+    assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+
+    h_in, w_in = input.shape[-2:]
+
+    def _convert_int_to_list(item):
+        if isinstance(item, int):
+            return [item] * 2
+        else:
+            return item
+
+    padding = _convert_int_to_list(self.padding)
+    dilation = _convert_int_to_list(self.dilation)
+    kernel_size = _convert_int_to_list(self.kernel_size)
+    stride = _convert_int_to_list(self.stride)
+
+    h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
+    w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
+    result_shape = input.shape[:-2] + (
+        h_out,
+        w_out,
+    )
+    return torch.empty(result_shape, device='meta')
+
+
 @meta_patched_module.register(torch.nn.MaxPool3d)
 def torch_nn_maxpool3d(self, input):
     num_dim = input.dim()
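
All six patches compute output sizes with the standard pooling formula, l_out = floor((l_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1), where the average-pooling variants fix dilation at 1. A standalone cross-check of that formula against a real layer (a sketch for illustration, not part of the commit):

import math
import torch

def maxpool_l_out(l_in, kernel_size, stride, padding, dilation):
    # the same output-length formula used by the patched modules above
    return math.floor((l_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=2)
x = torch.rand(4, 3, 10)
assert pool(x).shape == (4, 3, maxpool_l_out(10, kernel_size=3, stride=2, padding=1, dilation=2))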


@@ -4,9 +4,8 @@ tracer.py:
     Implemented a tracer which supports control flow and user-defined meta arguments.
     The implementation is partly inspired by HuggingFace's fx tracer
 """
+import enum
 import inspect
-import math
 import functools
 import torch
 import torch.nn as nn
@@ -22,6 +21,11 @@ from .meta_patch import meta_patched_function, meta_patched_module

 __all__ = ['ColoTracer']

+
+class TracerType(enum.Enum):
+    DEFAULT = 1
+    META = 2
+
 class ColoTracer(Tracer):
     """
     ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
@@ -48,6 +52,11 @@ class ColoTracer(Tracer):
         graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
     """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tracer_type = TracerType.META
+        self.proxy_cls = ColoProxy
+
     # Feature flag for proxying accesses to buffer values
     proxy_buffer_attributes: bool = True
@@ -58,6 +67,12 @@ class ColoTracer(Tracer):
         Create a proxy for different kinds of operations.
         """
         proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+
+        if self.tracer_type == TracerType.DEFAULT:
+            # since meta_args is not given
+            # we just fall back to the original torch.fx.Tracer
+            return proxy
+
         proxy: ColoProxy

         if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
@@ -168,11 +183,21 @@ class ColoTracer(Tracer):
         self.orig_forward = forward
         return super().call_module(m, forward, args, kwargs)

-    def proxy(self, node) -> ColoProxy:
+    def proxy(self, node) -> Proxy:
         """
         Returns a ColoProxy object.
         """
-        return ColoProxy(node, self)
+        return self.proxy_cls(node, self)
+
+    def _configure_tracer_type(self, tracer_type: TracerType):
+        if tracer_type == TracerType.DEFAULT:
+            self.proxy_cls = Proxy
+            self.tracer_type = TracerType.DEFAULT
+        elif tracer_type == TracerType.META:
+            self.proxy_cls = ColoProxy
+            self.tracer_type = TracerType.META
+        else:
+            raise ValueError(f"Unrecognised tracer type {tracer_type}")

     def trace(self,
               root: nn.Module,
@@ -193,6 +218,11 @@ class ColoTracer(Tracer):
         if concrete_args is None:
             concrete_args = {}

+        if len(meta_args) == 0:
+            self._configure_tracer_type(TracerType.DEFAULT)
+        else:
+            self._configure_tracer_type(TracerType.META)
+
         # check concrete and meta args have valid names
         sig = inspect.signature(root.forward)
         sig_names = set(sig.parameters.keys())
@@ -235,9 +265,12 @@ class ColoTracer(Tracer):
         self.concrete_args = concrete_args
         self.meta_args = meta_args

-        # wrap the torch tensor constructing methods so that they are captured in the graph
-        self.patched_torch_tensor_methods = {
-            target: wrap_tensor_constructor_method(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
-        }
+        self.patched_torch_tensor_methods = {}
+        if self.tracer_type == TracerType.META:
+            # wrap the torch tensor constructing methods so that they are captured in the graph
+            self.patched_torch_tensor_methods = {
+                target: wrap_tensor_constructor_method(getattr(torch, target))
+                for target in self._TORCH_METHODS_TO_PATCH
+            }

         # patch these methods to replace their original use
@@ -255,6 +288,9 @@ class ColoTracer(Tracer):
         for name, (_, orig) in self.patched_torch_tensor_methods.items():
             setattr(torch, name, orig)

+        if self.tracer_type == TracerType.DEFAULT:
+            return self.graph
+
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
         for node in self.graph.nodes:
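
The net effect of the tracer changes: when trace receives no meta_args, ColoTracer configures itself as TracerType.DEFAULT and behaves like a stock torch.fx.Tracer (plain Proxy objects, early return of the graph); with meta_args it stays in META mode and propagates shapes through ColoProxy. A usage sketch of both paths (the MLP module is illustrative, and it is assumed from the len(meta_args) == 0 check that meta_args defaults to empty when omitted):

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer

class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

model = MLP()

# no meta_args: falls back to vanilla torch.fx tracing with plain Proxy objects
default_graph = ColoTracer().trace(model)

# meta_args given: meta tensors flow through the patched functions and modules
meta_graph = ColoTracer().trace(model, meta_args={'x': torch.rand(4, 10, device='meta')})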


@@ -1,31 +0,0 @@
-import torch
-import torch.nn
-
-
-def test_maxpool():
-    layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.MaxPool1d, shape=(4, 3, 4)),
-                         maxpool_2d=dict(layer=torch.nn.MaxPool2d, shape=(4, 3, 4, 4)))
-
-    for name, info in layer_to_test.items():
-        data = torch.rand(*info['shape'])
-        meta_data = data.to('meta')
-        layer = info['layer'](kernel_size=3)
-        out = layer(data)
-        meta_out = layer(meta_data)
-        assert meta_out.is_meta
-        assert out.shape == meta_out.shape
-
-
-def test_avgpool():
-    layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.AvgPool1d, shape=(4, 3, 4)),
-                         maxpool_2d=dict(layer=torch.nn.AvgPool2d, shape=(4, 3, 4, 4)),
-                         maxpool_3d=dict(layer=torch.nn.AvgPool3d, shape=(4, 3, 4, 4, 4)))
-
-    for name, info in layer_to_test.items():
-        data = torch.rand(*info['shape'])
-        meta_data = data.to('meta')
-        layer = info['layer'](kernel_size=3)
-        out = layer(data)
-        meta_out = layer(meta_data)
-        assert meta_out.is_meta
-        assert out.shape == meta_out.shape


@@ -227,31 +227,88 @@ def test_conv3d():
                          output_shape=materialized_output.shape)


-def test_maxpool3d():
-    pooler = torch.nn.MaxPool3d(kernel_size=3)
-
-    # test max pool 3d
-    data = torch.rand(2, 3, 4, 4, 4)
-    materialized_output = pooler(data)
-    _assert_output_shape(data=data,
-                         module=pooler,
-                         patch_fn=patched_module.torch_nn_maxpool3d,
-                         expect_exception=False,
-                         output_shape=materialized_output.shape)
-
-    # test max pool 3d
-    data = torch.rand(2, 3, 4, 4)
-    materialized_output = pooler(data)
-    _assert_output_shape(data=data,
-                         module=pooler,
-                         patch_fn=patched_module.torch_nn_maxpool3d,
-                         expect_exception=False,
-                         output_shape=materialized_output.shape)
-
-    # test max pool 3d
-    data = torch.rand(2, 3, 4)
-    _assert_output_shape(data=data,
-                         module=pooler,
-                         patch_fn=patched_module.torch_nn_maxpool3d,
-                         expect_exception=True,
-                         output_shape=None)
+def test_pool1d():
+    combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d],
+                    [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]]
+
+    for (layer_cls, patch_func) in combinations:
+        pooler = layer_cls(kernel_size=3)
+
+        data = torch.rand(2, 3, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        data = torch.rand(2, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        data = torch.rand(2, 3, 4, 4)
+        _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
+
+
+def test_pool2d():
+    combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d],
+                    [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]]
+
+    for (layer_cls, patch_func) in combinations:
+        pooler = layer_cls(kernel_size=3)
+
+        # test pool 2d with 4-dim input
+        data = torch.rand(2, 3, 4, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        # test pool 2d with 3-dim input
+        data = torch.rand(2, 4, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        # test pool 2d with 5-dim input, which should raise an exception
+        data = torch.rand(2, 3, 4, 4, 4)
+        _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
+
+
+def test_pool3d():
+    combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d],
+                    [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]]
+
+    for (layer_cls, patch_func) in combinations:
+        pooler = layer_cls(kernel_size=3)
+
+        # test pool 3d with 5-dim input
+        data = torch.rand(2, 3, 4, 4, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        # test pool 3d with 4-dim input
+        data = torch.rand(2, 4, 4, 4)
+        materialized_output = pooler(data)
+        _assert_output_shape(data=data,
+                             module=pooler,
+                             patch_fn=patch_func,
+                             expect_exception=False,
+                             output_shape=materialized_output.shape)
+
+        # test pool 3d with 3-dim input, which should raise an exception
+        data = torch.rand(2, 3, 4)
+        _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
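
_assert_output_shape is a helper defined earlier in this test file and is not shown in this diff. A plausible reconstruction, inferred purely from how it is called above (hypothetical, for reading the tests):

import pytest

def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape):
    # run the patched shape-inference function on a meta copy of the input
    meta_data = data.to('meta')
    if expect_exception:
        # the dimension assert inside the patch should fire for unsupported ranks
        with pytest.raises(AssertionError):
            patch_fn(module, meta_data)
    else:
        meta_output = patch_fn(module, meta_data)
        assert meta_output.is_meta
        assert meta_output.shape == output_shape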


@@ -0,0 +1,46 @@
+import torch
+import pytest
+
+try:
+    import torchvision.models as tm
+except ImportError:
+    pass
+
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+
+
+@pytest.mark.skip('skip as torchvision is required')
+def test_torchvision_models():
+    MODEL_LIST = [
+        tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
+        tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0
+    ]
+
+    torch.backends.cudnn.deterministic = True
+
+    tracer = ColoTracer()
+    data = torch.rand(2, 3, 224, 224)
+
+    for model_cls in MODEL_LIST:
+        if model_cls in [tm.convnext_small, tm.efficientnet_b0]:
+            # set stochastic depth probability to 0 to remove the impact of randomness
+            model = model_cls(stochastic_depth_prob=0)
+        else:
+            model = model_cls()
+
+        graph = tracer.trace(root=model)
+        gm = GraphModule(model, graph, model.__class__.__name__)
+        gm.recompile()
+
+        model.eval()
+        gm.eval()
+
+        with torch.no_grad():
+            fx_out = gm(data)
+            non_fx_out = model(data)
+        assert torch.allclose(
+            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+if __name__ == '__main__':
+    test_torchvision_models()
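
A design note on the test above: the single loop stops at the first failing model. The same coverage could be expressed with pytest parametrization so that each architecture passes or fails independently (a sketch under that assumption, not part of the commit):

import pytest
import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer

tm = pytest.importorskip('torchvision.models')

@pytest.mark.parametrize('model_cls', [tm.vgg11, tm.resnet18, tm.mobilenet_v3_small])
def test_torchvision_model_tracing(model_cls):
    model = model_cls().eval()
    graph = ColoTracer().trace(root=model)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    gm.eval()

    data = torch.rand(2, 3, 224, 224)
    with torch.no_grad():
        assert torch.allclose(gm(data), model(data))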