mirror of https://github.com/hpcaitech/ColossalAI
[fx] added timm model tracing testing (#1221)
parent 280a81243d
commit b6cb5a47ad
@@ -1,3 +1,4 @@
+from curses import meta
 import operator
 import torch
 from .registry import meta_patched_function
@@ -99,7 +100,6 @@ def torch_abs(input, *, out=None):
 
 @meta_patched_function.register(torch.nn.functional.relu)
 def torch_nn_func_relu(input, inplace=False):
-    assert not inplace, 'inplace is not supported yet'
     return torch.empty(input.shape, device='meta')
 
 
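For context: these patched functions return tensors on PyTorch's `meta` device, which carry shape and dtype but no storage, so output shapes can be propagated without running any real computation. A minimal standalone illustration (not part of the commit):

```python
import torch

# A tensor on the 'meta' device records shape and dtype but allocates no storage.
x = torch.empty(2, 3, 224, 224, device='meta')
print(x.is_meta)   # True
print(x.shape)     # torch.Size([2, 3, 224, 224])

# The patched relu above simply mirrors the input's shape, which is safe
# because relu is elementwise and never changes it.
out = torch.empty(x.shape, device='meta')
print(out.shape == x.shape)   # True
```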
@@ -178,3 +178,43 @@ def torch_unsqueeze(input, dim):
 @meta_patched_function.register(torch.Tensor.unsqueeze)
 def torch_tensor_unsqueeze(self, dim):
     return torch_unsqueeze(self, dim)
+
+
+@meta_patched_function.register(torch.nn.functional.layer_norm)
+def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.batch_norm)
+def torch_nn_func_batchnorm(input,
+                            running_mean,
+                            running_var,
+                            weight=None,
+                            bias=None,
+                            training=False,
+                            momentum=0.1,
+                            eps=1e-05):
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.var_mean)
+def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
+    assert out is None, 'saving to out is not supported yet'
+    var = torch.empty(1).squeeze(0).to('meta')
+    mean = torch.empty(1).squeeze(0).to('meta')
+    return var, mean
+
+
+@meta_patched_function.register(torch.cat)
+def torch_cat(tensors, dim=None, axis=None, *, out=None):
+    if dim is None and axis is None:
+        dim = 0
+    if dim is None and axis is not None:
+        dim = axis
+    if dim < 0:
+        dim = tensors[0].dim() + dim
+    shapes = [t.shape for t in tensors]
+    shape = list(shapes[0])
+    concatenated_dim = sum(shape[dim] for shape in shapes)
+    final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
+    return torch.empty(final_shape, device="meta")
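All of the new patches follow the same pattern: compute the output shape from the input shapes alone and return an empty meta tensor. A minimal sketch of the dispatch idea behind the registry, assuming a plain dict (hypothetical; ColossalAI's actual `meta_patched_function` lives in `.registry` and is not shown in this diff):

```python
import torch

# Hypothetical dict-based registry standing in for meta_patched_function.
_patches = {}

def register(target):
    def wrapper(fn):
        _patches[target] = fn
        return fn
    return wrapper

@register(torch.cat)
def meta_cat(tensors, dim=0):
    # Same shape arithmetic as the torch_cat patch above: the output shape
    # is the input shape with the sizes along `dim` summed.
    shapes = [list(t.shape) for t in tensors]
    out = shapes[0][:dim] + [sum(s[dim] for s in shapes)] + shapes[0][dim + 1:]
    return torch.empty(out, device='meta')

a = torch.empty(2, 3, device='meta')
b = torch.empty(4, 3, device='meta')
print(_patches[torch.cat]([a, b], dim=0).shape)   # torch.Size([6, 3])
```

A tracer can then intercept a call to `torch.cat` during symbolic tracing and route it to the registered stand-in whenever the arguments are meta tensors.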
@@ -250,6 +250,6 @@ def torch_nn_maxpool3d(self, input):
 
 
 @meta_patched_module.register(torch.nn.ReLU)
+@meta_patched_module.register(torch.nn.ReLU6)
 def torch_nn_func_relu(self, input):
-    assert not self.inplace, 'inplace is not supported yet'
-    return torch.empty(input.shape, device='meta')
+    return input.clone()
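Note the behavior change in this module patch: returning `input.clone()` keeps the output on the meta device while also preserving the input's dtype, which a fresh `torch.empty(input.shape, device='meta')` would not (it would use the default dtype). A quick illustration (not part of the commit):

```python
import torch

x = torch.empty(2, 8, dtype=torch.float16, device='meta')
y = x.clone()
print(y.is_meta)            # True
print(y.dtype)              # torch.float16, preserved by clone()
print(y.shape == x.shape)   # True
```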
@@ -3,7 +3,6 @@ from colossalai.fx.proxy import ColoProxy
 import pytest
 
 
-@pytest.mark.skip
 def test_coloproxy():
     # create a dummy node only for testing purpose
     model = torch.nn.Linear(10, 10)
@@ -0,0 +1,82 @@
+import torch
+import pytest
+try:
+    import timm.models as tm
+except:
+    pass
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+
+
+def trace_and_compare(model_cls, tracer, data, meta_args=None):
+    # trace
+    model = model_cls()
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    # convert to eval for inference
+    model.eval()
+    gm.eval()
+
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(data)
+        non_fx_out = model(data)
+
+    # compare output
+    if isinstance(fx_out, tuple):
+        # some models produce tuple as output
+        for v1, v2 in zip(fx_out, non_fx_out):
+            assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}'
+    else:
+        assert torch.allclose(
+            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+@pytest.mark.skip('skip as timm is required')
+def test_timm_models_without_control_flow():
+    torch.backends.cudnn.deterministic = True
+
+    MODEL_LIST = [
+        tm.resnest.resnest50d, tm.beit.beit_base_patch16_224, tm.cait.cait_s24_224, tm.convmixer.convmixer_768_32,
+        tm.efficientnet.efficientnetv2_m, tm.resmlp_12_224, tm.vision_transformer.vit_base_patch16_224
+
+        # results not aligned
+        # tm.deit_base_distilled_patch16_224,
+    ]
+
+    tracer = ColoTracer()
+    data = torch.rand(2, 3, 224, 224)
+
+    for model_cls in MODEL_LIST:
+        trace_and_compare(model_cls, tracer, data)
+
+
+@pytest.mark.skip('skip as timm is required')
+def test_timm_models_with_control_flow():
+    torch.backends.cudnn.deterministic = True
+
+    MODEL_LIST_WITH_CONTROL_FLOW = [
+        tm.convnext.convnext_base,
+        tm.vgg.vgg11,
+
+        # results not aligned
+        # tm.dpn.dpn68,
+        # tm.densenet.densenet121,
+        # tm.rexnet.rexnet_100,
+        # tm.swin_transformer.swin_base_patch4_window7_224
+    ]
+
+    tracer = ColoTracer()
+    data = torch.rand(2, 3, 224, 224)
+
+    meta_args = {'x': data.to('meta')}
+
+    for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
+        trace_and_compare(model_cls, tracer, data, meta_args)
+
+
+if __name__ == '__main__':
+    test_timm_models_with_control_flow()
+    test_timm_models_without_control_flow()
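Both tests are skipped by default because `timm` is an optional dependency. To exercise the tracer by hand, one can replicate what `trace_and_compare` does for a single model; a minimal usage sketch, assuming `timm` is installed:

```python
import torch
import timm.models as tm
from torch.fx import GraphModule
from colossalai.fx import ColoTracer

model = tm.vgg.vgg11()
data = torch.rand(2, 3, 224, 224)

tracer = ColoTracer()
# meta_args maps forward() argument names to meta tensors so the tracer can
# resolve shape-dependent control flow without real computation.
graph = tracer.trace(root=model, meta_args={'x': data.to('meta')})
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# the traced GraphModule should match the eager model's output
model.eval()
gm.eval()
with torch.no_grad():
    assert torch.allclose(gm(data), model(data))
```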