[test] add torchrec models to test model zoo (#3139)

pull/3153/head
YuliangLiu0306 2023-03-15 13:46:04 +08:00 committed by GitHub
parent 14a115000b
commit ecd643f1e4
4 changed files with 200 additions and 160 deletions


@@ -1,4 +1,5 @@
-from . import diffusers, timm, torchaudio, torchvision, transformers
+from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
 from .registry import model_zoo
 __all__ = ['model_zoo']


@@ -0,0 +1,97 @@
from collections import namedtuple
from functools import partial

import torch

try:
    from torchrec.models import deepfm, dlrm
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
    NO_TORCHREC = False
except ImportError:
    NO_TORCHREC = True

from ..registry import ModelAttribute, model_zoo


def register_torchrec_models():
    BATCH = 2
    SHAPE = 10

    # KeyedTensor
    KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))

    # KeyedJaggedTensor
    KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
                                              offsets=torch.tensor([0, 2, 4, 6, 8]))

    data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE)))
    interaction_arch_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT)
    simple_dfm_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT)
    sparse_arch_data_gen_fn = lambda: dict(features=KJT)

    output_transform_fn = lambda x: dict(output=x)

    def get_ebc():
        # EmbeddingBagCollection
        eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
        eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
        return EmbeddingBagCollection(tables=[eb1_config, eb2_config])

    model_zoo.register(name='deepfm_densearch',
                       model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='deepfm_interactionarch',
                       model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
                       data_gen_fn=interaction_arch_data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='deepfm_overarch',
                       model_fn=partial(deepfm.OverArch, SHAPE),
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='deepfm_simpledeepfmnn',
                       model_fn=partial(deepfm.SimpleDeepFMNN, SHAPE, get_ebc(), SHAPE, SHAPE),
                       data_gen_fn=simple_dfm_data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='deepfm_sparsearch',
                       model_fn=partial(deepfm.SparseArch, get_ebc()),
                       data_gen_fn=sparse_arch_data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='dlrm',
                       model_fn=partial(dlrm.DLRM, get_ebc(), SHAPE, [SHAPE, SHAPE], [5, 1]),
                       data_gen_fn=simple_dfm_data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='dlrm_densearch',
                       model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]),
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='dlrm_interactionarch',
                       model_fn=partial(dlrm.InteractionArch, 2),
                       data_gen_fn=interaction_arch_data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='dlrm_overarch',
                       model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]),
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=output_transform_fn)
    model_zoo.register(name='dlrm_sparsearch',
                       model_fn=partial(dlrm.SparseArch, get_ebc()),
                       data_gen_fn=sparse_arch_data_gen_fn,
                       output_transform_fn=output_transform_fn)


if not NO_TORCHREC:
    register_torchrec_models()
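
Editor's aside, not part of the commit: a small sketch of what the shared sample KeyedJaggedTensor above encodes, assuming a working torchrec install. With two keys and a batch of two, the offsets [0, 2, 4, 6, 8] split the flat values into four bags in key-major order, so "f1" holds ids [1, 2] and [3, 4] while "f2" holds [5, 6] and [7, 8].

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Same sample batch as in the registry above: 2 keys x 2 samples = 4 bags, delimited by `offsets`.
kjt = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
                                          values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
                                          offsets=torch.tensor([0, 2, 4, 6, 8]))
per_key = kjt.to_dict()           # one JaggedTensor view per key
print(per_key["f1"].lengths())    # tensor([2, 2]) -> two ids per sample for "f1"
print(per_key["f2"].values())     # tensor([5, 6, 7, 8]) -> bags [5, 6] and [7, 8]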


@@ -2,85 +2,69 @@ import pytest
 import torch

 from colossalai.fx import symbolic_trace

-try:
-    from torchrec.models import deepfm
-    from torchrec.modules.embedding_configs import EmbeddingBagConfig
-    from torchrec.modules.embedding_modules import EmbeddingBagCollection
-    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
-    NOT_TORCHREC = False
-except ImportError:
-    NOT_TORCHREC = True
+from tests.kit.model_zoo import model_zoo

 BATCH = 2
 SHAPE = 10

+deepfm_models = model_zoo.get_sub_registry('deepfm')
+NOT_DFM = False
+if not deepfm_models:
+    NOT_DFM = True

-@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
-def test_torchrec_deepfm_models():
-    MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
-
-    # Data Preparation
-    # EmbeddingBagCollection
-    eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
-    eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
-    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
-    keys = ["f1", "f2"]
-
-    # KeyedTensor
-    KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
-
-    # KeyedJaggedTensor
-    KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
-                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
-                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
-
-    # Dense Features
-    features = torch.rand((BATCH, SHAPE))
-
-    for model_cls in MODEL_LIST:
-        # Initializing model
-        if model_cls == deepfm.DenseArch:
-            model = model_cls(SHAPE, SHAPE, SHAPE)
-        elif model_cls == deepfm.FMInteractionArch:
-            model = model_cls(SHAPE * 3, keys, SHAPE)
-        elif model_cls == deepfm.OverArch:
-            model = model_cls(SHAPE)
-        elif model_cls == deepfm.SimpleDeepFMNN:
-            model = model_cls(SHAPE, ebc, SHAPE, SHAPE)
-        elif model_cls == deepfm.SparseArch:
-            model = model_cls(ebc)
-
-        # Setup GraphModule
-        gm = symbolic_trace(model)
-        model.eval()
-        gm.eval()
-
-        # Aligned Test
-        with torch.no_grad():
-            if model_cls == deepfm.DenseArch or model_cls == deepfm.OverArch:
-                fx_out = gm(features)
-                non_fx_out = model(features)
-            elif model_cls == deepfm.FMInteractionArch:
-                fx_out = gm(features, KT)
-                non_fx_out = model(features, KT)
-            elif model_cls == deepfm.SimpleDeepFMNN:
-                fx_out = gm(features, KJT)
-                non_fx_out = model(features, KJT)
-            elif model_cls == deepfm.SparseArch:
-                fx_out = gm(KJT)
-                non_fx_out = model(KJT)
-
-            if torch.is_tensor(fx_out):
-                assert torch.allclose(
-                    fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-            else:
-                assert torch.allclose(
-                    fx_out.values(),
-                    non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+    # trace
+    model = model_cls()
+
+    # convert to eval for inference
+    # it is important to set it to eval mode before tracing
+    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+    model.eval()
+
+    gm = symbolic_trace(model, meta_args=meta_args)
+    gm.eval()
+
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(**data)
+        non_fx_out = model(**data)
+
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
+
+    for key in transformed_fx_out.keys():
+        fx_output_val = transformed_fx_out[key]
+        non_fx_output_val = transformed_non_fx_out[key]
+        if torch.is_tensor(fx_output_val):
+            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+        else:
+            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+@pytest.mark.skipif(NOT_DFM, reason='torchrec is not installed')
+def test_torchrec_deepfm_models(deepfm_models):
+    torch.backends.cudnn.deterministic = True
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
+        data = data_gen_fn()
+        if attribute is not None and attribute.has_control_flow:
+            meta_args = {k: v.to('meta') for k, v in data.items()}
+        else:
+            meta_args = None
+        trace_and_compare(model_fn, data, output_transform_fn, meta_args)


 if __name__ == "__main__":
-    test_torchrec_deepfm_models()
+    test_torchrec_deepfm_models(deepfm_models)
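
Editor's aside, not part of the commit: a minimal manual run of the same check for a single registered entry, mirroring trace_and_compare. It assumes torchrec is installed and that the sub-registry returned by get_sub_registry behaves like a dict keyed by the names registered above; the entry name 'deepfm_densearch' is taken from the registry file in this commit.

import torch
from colossalai.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo

# Each registry value unpacks the same way as in the test loop above.
model_fn, data_gen_fn, output_transform_fn, _ = model_zoo.get_sub_registry('deepfm')['deepfm_densearch']
model = model_fn()
model.eval()                      # trace in eval mode, as trace_and_compare does
gm = symbolic_trace(model)
gm.eval()
data = data_gen_fn()              # generate the sample batch once and reuse it for both forwards
with torch.no_grad():
    fx_out = output_transform_fn(gm(**data))['output']
    ref_out = output_transform_fn(model(**data))['output']
assert torch.allclose(fx_out, ref_out, atol=1e-5)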


@@ -1,112 +1,70 @@
-import pytest
 import torch

 from colossalai.fx import symbolic_trace

-try:
-    from torchrec.models import dlrm
-    from torchrec.modules.embedding_configs import EmbeddingBagConfig
-    from torchrec.modules.embedding_modules import EmbeddingBagCollection
-    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
-    NOT_TORCHREC = False
-except ImportError:
-    NOT_TORCHREC = True
+import pytest
+from tests.kit.model_zoo import model_zoo

 BATCH = 2
 SHAPE = 10

+dlrm_models = model_zoo.get_sub_registry('dlrm')
+NOT_DLRM = False
+if not dlrm_models:
+    NOT_DLRM = True

-@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
-def test_torchrec_dlrm_models():
-    MODEL_LIST = [
-        dlrm.DLRM,
-        dlrm.DenseArch,
-        dlrm.InteractionArch,
-        dlrm.InteractionV2Arch,
-        dlrm.OverArch,
-        dlrm.SparseArch,
-        # dlrm.DLRMV2
-    ]
-
-    # Data Preparation
-    # EmbeddingBagCollection
-    eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
-    eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
-    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
-    keys = ["f1", "f2"]
-
-    # KeyedTensor
-    KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
-
-    # KeyedJaggedTensor
-    KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
-                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
-                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
-
-    # Dense Features
-    dense_features = torch.rand((BATCH, SHAPE))
-
-    # Sparse Features
-    sparse_features = torch.rand((BATCH, len(keys), SHAPE))
-
-    for model_cls in MODEL_LIST:
-        # Initializing model
-        if model_cls == dlrm.DLRM:
-            model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
-        elif model_cls == dlrm.DenseArch:
-            model = model_cls(SHAPE, [SHAPE, SHAPE])
-        elif model_cls == dlrm.InteractionArch:
-            model = model_cls(len(keys))
-        elif model_cls == dlrm.InteractionV2Arch:
-            I1 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
-            I2 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
-            model = model_cls(len(keys), I1, I2)
-        elif model_cls == dlrm.OverArch:
-            model = model_cls(SHAPE, [5, 1])
-        elif model_cls == dlrm.SparseArch:
-            model = model_cls(ebc)
-        elif model_cls == dlrm.DLRMV2:
-            # Currently DLRMV2 cannot be traced
-            model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE])
-
-        # Setup GraphModule
-        if model_cls == dlrm.InteractionV2Arch:
-            concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
-            gm = symbolic_trace(model, concrete_args=concrete_args)
-        else:
-            gm = symbolic_trace(model)
-        model.eval()
-        gm.eval()
-
-        # Aligned Test
-        with torch.no_grad():
-            if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2:
-                fx_out = gm(dense_features, KJT)
-                non_fx_out = model(dense_features, KJT)
-            elif model_cls == dlrm.DenseArch:
-                fx_out = gm(dense_features)
-                non_fx_out = model(dense_features)
-            elif model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch:
-                fx_out = gm(dense_features, sparse_features)
-                non_fx_out = model(dense_features, sparse_features)
-            elif model_cls == dlrm.OverArch:
-                fx_out = gm(dense_features)
-                non_fx_out = model(dense_features)
-            elif model_cls == dlrm.SparseArch:
-                fx_out = gm(KJT)
-                non_fx_out = model(KJT)
-
-            if torch.is_tensor(fx_out):
-                assert torch.allclose(
-                    fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-            else:
-                assert torch.allclose(
-                    fx_out.values(),
-                    non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+    # trace
+    model = model_cls()
+
+    # convert to eval for inference
+    # it is important to set it to eval mode before tracing
+    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+    model.eval()
+
+    gm = symbolic_trace(model, meta_args=meta_args)
+    gm.eval()
+
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(**data)
+        non_fx_out = model(**data)
+
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
+
+    for key in transformed_fx_out.keys():
+        fx_output_val = transformed_fx_out[key]
+        non_fx_output_val = transformed_non_fx_out[key]
+        if torch.is_tensor(fx_output_val):
+            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+        else:
+            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+@pytest.mark.skipif(NOT_DLRM, reason='torchrec is not installed')
+def test_torchrec_dlrm_models(dlrm_models):
+    torch.backends.cudnn.deterministic = True
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
+        data = data_gen_fn()
+        if attribute is not None and attribute.has_control_flow:
+            meta_args = {k: v.to('meta') for k, v in data.items()}
+        else:
+            meta_args = None
+        trace_and_compare(model_fn, data, output_transform_fn, meta_args)


 if __name__ == "__main__":
-    test_torchrec_dlrm_models()
+    test_torchrec_dlrm_models(dlrm_models)
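
Editor's aside, not part of the commit: both test files build meta_args by moving the sample inputs to the 'meta' device when a model is flagged with has_control_flow. A small standalone illustration of why, independent of colossalai: 'meta' tensors keep only shape and dtype and allocate no storage, which is what a tracer needs to propagate shapes through shape-dependent branches.

import torch

# A regular sample input versus its 'meta' counterpart.
x = torch.rand(2, 10)
m = x.to('meta')
print(m.shape, m.dtype, m.device)   # torch.Size([2, 10]) torch.float32 meta (no data stored)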