[fx] add test for meta tensor. (#1527)

* [fx] add test for meta tensor.

* [fx] add test for meta tensor.

* [fx] add test for meta tensor.

* [fx] add test for meta tensor.

* [fx] fix error.
pull/1535/head
Super Daniel 2022-09-01 19:30:05 +08:00 committed by GitHub
parent 4b3d6caeb3
commit 7dc53237c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 151 additions and 0 deletions

View File

@ -0,0 +1,88 @@
from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.fx.profiler import MetaTensor
import pytest
# Probe for the `torch.library` registration API. If the Meta-dispatch
# library cannot be created, the tests in this file are skipped via the
# INCOMPATIBLE flag.
try:
    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
    INCOMPATIBLE = False    # version > 1.12.0
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit raised during import are not silently turned into a skip.
    INCOMPATIBLE = True

# Short alias for the aten operator namespace.
aten = torch.ops.aten
# Sample (callable, input tensor) pairs used to exercise meta tensors.
# Each key is a pair (aten op name, requires_backward); each value is the list
# of (module, example input) cases filed under that key.
registered_meta = {
    # Every convolution flavour shares the same hyper-parameters; only the
    # number of spatial dimensions of the input changes.
    ('aten.convolution.default', True): [
        (
            conv_cls(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2),
            torch.rand((2, 3) + (4,) * spatial_dims),
        )
        for conv_cls, spatial_dims in (
            (nn.Conv1d, 1),
            (nn.Conv2d, 2),
            (nn.Conv3d, 3),
            (nn.ConvTranspose1d, 1),
            (nn.ConvTranspose2d, 2),
            (nn.ConvTranspose3d, 3),
        )
    ],
    ('aten.native_batch_norm.default', True): [
        (nn.BatchNorm1d(4), torch.rand(2, 4)),
        (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)),
        (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)),
    ],
    ('aten.native_layer_norm.default', True): [
        (nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),
    ],
    ('aten.avg_pool1d.default', True): [
        (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)),
        (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)),
        (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)),
        (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)),
    ],
    ('aten.avg_pool2d.default', True): [
        (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
        (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
        (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
        (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
    ],
    # Element-wise activations all receive an input of the same shape.
    ('aten.relu.default', True): [
        (activation_cls(), torch.rand(4, 3, 1, 2))
        for activation_cls in (
            nn.ReLU,
            nn.LeakyReLU,
            nn.SiLU,
            nn.GELU,
            nn.ELU,
            nn.Sigmoid,
            nn.Tanh,
            nn.Hardswish,
        )
    ],
}
def compare_all(tensor: torch.Tensor, meta_tensor: MetaTensor) -> Any:
    """Assert that ``tensor`` and ``meta_tensor`` agree on shape, dtype and stride.

    The properties are checked in that order and the first mismatch raises an
    ``AssertionError`` naming the property and both values.
    """
    properties = (
        ('shape', lambda t: t.shape),
        ('dtype', lambda t: t.dtype),
        ('stride', lambda t: t.stride()),
    )
    for label, read in properties:
        real_value, meta_value = read(tensor), read(meta_tensor)
        assert real_value == meta_value, \
            f'the {label} of tensor ({real_value}) and meta tensor ({meta_value}) does not match.'
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
    """Run ``f`` on a real tensor and on its meta counterpart and compare the results.

    Args:
        f: Module or plain callable under test.
        x: Real input tensor. Its ``requires_grad`` flag is overwritten with
            ``requires_backward``.
        requires_backward: When true, both outputs are summed and
            back-propagated, and the input gradients are compared as well.

    Raises:
        AssertionError: If ``compare_all`` finds a mismatch between the real
            and the meta result (or between their input gradients).
    """
    import copy    # local import: only needed for the module copy below

    x.requires_grad = requires_backward
    meta_x = MetaTensor(x.to('meta'))

    x_out = f(x)
    if isinstance(f, nn.Module):
        # nn.Module.to() converts the module in place. Convert a copy so the
        # caller's module keeps its real parameters and can be reused.
        meta_f = copy.deepcopy(f).to('meta')
    else:
        meta_f = f
    meta_out = meta_f(meta_x)
    compare_all(x_out, meta_out)

    if requires_backward:
        x_out.sum().backward()
        meta_out.sum().backward()
        compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0')
def test_meta_aten():
    """Run every registered (callable, input) case through ``run_and_compare``."""
    # The op name in each key only labels the group; the flag drives the check.
    for (_, needs_backward), cases in registered_meta.items():
        for fn, sample in cases:
            run_and_compare(fn, sample, needs_backward)
if __name__ == '__main__':
    # The skipif marker is only honoured by the pytest runner, so a direct
    # script run has to check the compatibility flag itself.
    if INCOMPATIBLE:
        print('skipped: torch version is lower than 1.12.0')
    else:
        test_meta_aten()

View File

@ -0,0 +1,63 @@
import torchvision.models as tm
import timm.models as tmm
import torch
from colossalai.fx.profiler import MetaTensor
import pytest
# Probe for the `torch.library` registration API. If the Meta-dispatch
# library cannot be created, the tests in this file are skipped via the
# `incompatible` flag.
try:
    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
    incompatible = False    # version > 1.12.0
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit raised during import are not silently turned into a skip.
    incompatible = True
# torchvision model constructors exercised by test_torchvision_models.
# Entries are factories: each one is called with no arguments to build a model.
tm_models = [
    # plain / residual / dense CNNs
    tm.vgg11,
    tm.resnet18,
    tm.densenet121,
    # mobile-oriented and wider/grouped variants
    tm.mobilenet_v3_small,
    tm.resnext50_32x4d,
    tm.wide_resnet50_2,
    tm.regnet_x_16gf,
    tm.mnasnet0_5,
    tm.efficientnet_b0,
]
# timm model constructors exercised by test_timm_models.
# Entries are factories: each one is called with no arguments to build a model.
tmm_models = [
    tmm.resnest.resnest50d,
    tmm.beit.beit_base_patch16_224,
    tmm.cait.cait_s24_224,
    tmm.efficientnet.efficientnetv2_m,
    tmm.resmlp_12_224,
    tmm.vision_transformer.vit_base_patch16_224,
    tmm.deit_base_distilled_patch16_224,
    tmm.convnext.convnext_base,
    tmm.vgg.vgg11,
    tmm.dpn.dpn68,
    tmm.densenet.densenet121,
    tmm.rexnet.rexnet_100,
    tmm.swin_transformer.swin_base_patch4_window7_224,
]
@pytest.mark.skipif(incompatible, reason='torch version is lower than 1.12.0')
def test_torchvision_models():
    """Run a forward and backward pass of each torchvision model on the meta device."""
    for build_model in tm_models:
        meta_model = build_model().to('meta')
        batch = torch.rand(1000, 3, 224, 224, device='meta')
        output = meta_model(MetaTensor(batch))
        # Reduce to a scalar so backward() needs no explicit gradient argument.
        output.sum().backward()
@pytest.mark.skipif(incompatible, reason='torch version is lower than 1.12.0')
def test_timm_models():
    """Run a forward and backward pass of each timm model on the meta device."""
    for build_model in tmm_models:
        meta_model = build_model().to('meta')
        batch = torch.rand(1000, 3, 224, 224, device='meta')
        output = meta_model(MetaTensor(batch))
        # Reduce to a scalar so backward() needs no explicit gradient argument.
        output.sum().backward()
if __name__ == '__main__':
    # The skipif markers are only honoured by the pytest runner, so a direct
    # script run has to check the compatibility flag itself.
    if incompatible:
        print('skipped: torch version is lower than 1.12.0')
    else:
        test_torchvision_models()
        test_timm_models()