From ecd643f1e4e0100c08bd0765337fe5d2287f07dd Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 15 Mar 2023 13:46:04 +0800
Subject: [PATCH] [test] add torchrec models to test model zoo (#3139)

---
 tests/kit/model_zoo/__init__.py               |   3 +-
 tests/kit/model_zoo/torchrec/torchrec.py      |  97 ++++++++++++
 .../test_torchrec_model/test_deepfm_model.py  | 120 +++++++--------
 .../test_torchrec_model/test_dlrm_model.py    | 140 ++++++------------
 4 files changed, 200 insertions(+), 160 deletions(-)
 create mode 100644 tests/kit/model_zoo/torchrec/torchrec.py

diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py
index 82a61626b..710038ffa 100644
--- a/tests/kit/model_zoo/__init__.py
+++ b/tests/kit/model_zoo/__init__.py
@@ -1,4 +1,5 @@
-from . import diffusers, timm, torchaudio, torchvision, transformers
+from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
+
 from .registry import model_zoo
 
 __all__ = ['model_zoo']
diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py
new file mode 100644
index 000000000..014e9218b
--- /dev/null
+++ b/tests/kit/model_zoo/torchrec/torchrec.py
@@ -0,0 +1,97 @@
+from collections import namedtuple
+from functools import partial
+
+import torch
+
+try:
+    from torchrec.models import deepfm, dlrm
+    from torchrec.modules.embedding_configs import EmbeddingBagConfig
+    from torchrec.modules.embedding_modules import EmbeddingBagCollection
+    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+    NO_TORCHREC = False
+except ImportError:
+    NO_TORCHREC = True
+
+from ..registry import ModelAttribute, model_zoo
+
+
+def register_torchrec_models():
+    BATCH = 2
+    SHAPE = 10
+    # KeyedTensor
+    KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
+
+    # KeyedJaggedTensor
+    KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
+                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
+                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
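+    # reading the offsets (key-major layout): "f1" holds ids [1, 2] and [3, 4]
+    # for the two batch samples, "f2" holds ids [5, 6] and [7, 8]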
+
+    data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE)))
+
+    interaction_arch_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT)
+
+    simple_dfm_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT)
+
+    sparse_arch_data_gen_fn = lambda: dict(features=KJT)
+
+    output_transform_fn = lambda x: dict(output=x)
+
+    def get_ebc():
+        # EmbeddingBagCollection
+        eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
+        eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
+        return EmbeddingBagCollection(tables=[eb1_config, eb2_config])
+
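+    # each entry registered below is unpacked by the tracer tests as
+    # (model_fn, data_gen_fn, output_transform_fn, attribute)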
+    model_zoo.register(name='deepfm_densearch',
+                       model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
+                       data_gen_fn=data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='deepfm_interactionarch',
+                       model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
+                       data_gen_fn=interaction_arch_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='deepfm_overarch',
+                       model_fn=partial(deepfm.OverArch, SHAPE),
+                       data_gen_fn=data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='deepfm_simpledeepfmnn',
+                       model_fn=partial(deepfm.SimpleDeepFMNN, SHAPE, get_ebc(), SHAPE, SHAPE),
+                       data_gen_fn=simple_dfm_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='deepfm_sparsearch',
+                       model_fn=partial(deepfm.SparseArch, get_ebc()),
+                       data_gen_fn=sparse_arch_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm',
+                       model_fn=partial(dlrm.DLRM, get_ebc(), SHAPE, [SHAPE, SHAPE], [5, 1]),
+                       data_gen_fn=simple_dfm_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm_densearch',
+                       model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]),
+                       data_gen_fn=data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm_interactionarch',
+                       model_fn=partial(dlrm.InteractionArch, 2),
+                       data_gen_fn=interaction_arch_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm_overarch',
+                       model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]),
+                       data_gen_fn=data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm_sparsearch',
+                       model_fn=partial(dlrm.SparseArch, get_ebc()),
+                       data_gen_fn=sparse_arch_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+
+if not NO_TORCHREC:
+    register_torchrec_models()
diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
index dbe8a62e7..6cbca343d 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
@@ -2,85 +2,69 @@ import pytest
 import torch
 
 from colossalai.fx import symbolic_trace
-
-try:
-    from torchrec.models import deepfm
-    from torchrec.modules.embedding_configs import EmbeddingBagConfig
-    from torchrec.modules.embedding_modules import EmbeddingBagCollection
-    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
-    NOT_TORCHREC = False
-except ImportError:
-    NOT_TORCHREC = True
+from tests.kit.model_zoo import model_zoo
 
 BATCH = 2
 SHAPE = 10
+deepfm_models = model_zoo.get_sub_registry('deepfm')
+NOT_DFM = False
+if not deepfm_models:
+    NOT_DFM = True
 
 
-@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
-def test_torchrec_deepfm_models():
-    MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
-    # Data Preparation
-    # EmbeddingBagCollection
-    eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
-    eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+    # trace
+    model = model_cls()
 
-    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
-    keys = ["f1", "f2"]
+    # convert to eval for inference
+    # it is important to set it to eval mode before tracing
+    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+    model.eval()
 
-    # KeyedTensor
-    KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
+    gm = symbolic_trace(model, meta_args=meta_args)
+    gm.eval()
+
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(**data)
+        non_fx_out = model(**data)
 
-    # KeyedJaggedTensor
-    KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
-                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
-                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
 
-    # Dense Features
-    features = torch.rand((BATCH, SHAPE))
-
-    for model_cls in MODEL_LIST:
-        # Initializing model
-        if model_cls == deepfm.DenseArch:
-            model = model_cls(SHAPE, SHAPE, SHAPE)
-        elif model_cls == deepfm.FMInteractionArch:
-            model = model_cls(SHAPE * 3, keys, SHAPE)
-        elif model_cls == deepfm.OverArch:
-            model = model_cls(SHAPE)
-        elif model_cls == deepfm.SimpleDeepFMNN:
-            model = model_cls(SHAPE, ebc, SHAPE, SHAPE)
-        elif model_cls == deepfm.SparseArch:
-            model = model_cls(ebc)
-
-        # Setup GraphModule
-        gm = symbolic_trace(model)
-
-        model.eval()
-        gm.eval()
-
-        # Aligned Test
-        with torch.no_grad():
-            if model_cls == deepfm.DenseArch or model_cls == deepfm.OverArch:
-                fx_out = gm(features)
-                non_fx_out = model(features)
-            elif model_cls == deepfm.FMInteractionArch:
-                fx_out = gm(features, KT)
-                non_fx_out = model(features, KT)
-            elif model_cls == deepfm.SimpleDeepFMNN:
-                fx_out = gm(features, KJT)
-                non_fx_out = model(features, KJT)
-            elif model_cls == deepfm.SparseArch:
-                fx_out = gm(KJT)
-                non_fx_out = model(KJT)
-
-        if torch.is_tensor(fx_out):
-            assert torch.allclose(
-                fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
+    if torch.is_tensor(fx_out):
+        assert torch.allclose(
+            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    else:
+        assert torch.allclose(
+            fx_out.values(),
+            non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    for key in transformed_fx_out.keys():
+        fx_output_val = transformed_fx_out[key]
+        non_fx_output_val = transformed_non_fx_out[key]
+        if torch.is_tensor(fx_output_val):
+            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
         else:
-            assert torch.allclose(
-                fx_out.values(),
-                non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
+                                  ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+@pytest.mark.skipif(NOT_DFM, reason='torchrec is not installed')
+def test_torchrec_deepfm_models():
+    torch.backends.cudnn.deterministic = True
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
+        data = data_gen_fn()
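+        # 'meta' tensors carry shape and dtype but no storage; they let the
+        # tracer handle entries whose forward has data-dependent control flow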
+        if attribute is not None and attribute.has_control_flow:
+            meta_args = {k: v.to('meta') for k, v in data.items()}
+        else:
+            meta_args = None
+
+        trace_and_compare(model_fn, data, output_transform_fn, meta_args)
 
 
 if __name__ == "__main__":
     test_torchrec_deepfm_models()
diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
index 2f9fd8fe5..7aa868265 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
@@ -1,112 +1,70 @@
+import pytest
 import torch
 
 from colossalai.fx import symbolic_trace
-
-try:
-    from torchrec.models import dlrm
-    from torchrec.modules.embedding_configs import EmbeddingBagConfig
-    from torchrec.modules.embedding_modules import EmbeddingBagCollection
-    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
-    NOT_TORCHREC = False
-except ImportError:
-    NOT_TORCHREC = True
-
-import pytest
+from tests.kit.model_zoo import model_zoo
 
 BATCH = 2
 SHAPE = 10
+dlrm_models = model_zoo.get_sub_registry('dlrm')
+NOT_DLRM = False
+if not dlrm_models:
+    NOT_DLRM = True
 
 
-@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
-def test_torchrec_dlrm_models():
-    MODEL_LIST = [
-        dlrm.DLRM,
-        dlrm.DenseArch,
-        dlrm.InteractionArch,
-        dlrm.InteractionV2Arch,
-        dlrm.OverArch,
-        dlrm.SparseArch,
-        # dlrm.DLRMV2
-    ]
-    # Data Preparation
-    # EmbeddingBagCollection
-    eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
-    eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+    # trace
+    model = model_cls()
 
-    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
-    keys = ["f1", "f2"]
+    # convert to eval for inference
+    # it is important to set it to eval mode before tracing
+    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+    model.eval()
 
-    # KeyedTensor
-    KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
+    gm = symbolic_trace(model, meta_args=meta_args)
+    gm.eval()
+
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(**data)
+        non_fx_out = model(**data)
 
-    # KeyedJaggedTensor
-    KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
-                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
-                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
 
-    # Dense Features
-    dense_features = torch.rand((BATCH, SHAPE))
-
-    # Sparse Features
-    sparse_features = torch.rand((BATCH, len(keys), SHAPE))
-
-    for model_cls in MODEL_LIST:
-        # Initializing model
-        if model_cls == dlrm.DLRM:
-            model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
-        elif model_cls == dlrm.DenseArch:
-            model = model_cls(SHAPE, [SHAPE, SHAPE])
-        elif model_cls == dlrm.InteractionArch:
-            model = model_cls(len(keys))
-        elif model_cls == dlrm.InteractionV2Arch:
-            I1 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
-            I2 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
-            model = model_cls(len(keys), I1, I2)
-        elif model_cls == dlrm.OverArch:
-            model = model_cls(SHAPE, [5, 1])
-        elif model_cls == dlrm.SparseArch:
-            model = model_cls(ebc)
-        elif model_cls == dlrm.DLRMV2:
-            # Currently DLRMV2 cannot be traced
-            model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE])
-
-        # Setup GraphModule
-        if model_cls == dlrm.InteractionV2Arch:
-            concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
-            gm = symbolic_trace(model, concrete_args=concrete_args)
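+    # compare whole outputs first, then per key; non-tensor outputs such as
+    # KeyedTensor are compared through their flattened .values()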
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
+    if torch.is_tensor(fx_out):
+        assert torch.allclose(
+            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    else:
+        assert torch.allclose(
+            fx_out.values(),
+            non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    for key in transformed_fx_out.keys():
+        fx_output_val = transformed_fx_out[key]
+        non_fx_output_val = transformed_non_fx_out[key]
+        if torch.is_tensor(fx_output_val):
+            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
         else:
-            gm = symbolic_trace(model)
+            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
+                                  ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
 
-        model.eval()
-        gm.eval()
-
-        # Aligned Test
-        with torch.no_grad():
-            if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2:
-                fx_out = gm(dense_features, KJT)
-                non_fx_out = model(dense_features, KJT)
-            elif model_cls == dlrm.DenseArch:
-                fx_out = gm(dense_features)
-                non_fx_out = model(dense_features)
-            elif model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch:
-                fx_out = gm(dense_features, sparse_features)
-                non_fx_out = model(dense_features, sparse_features)
-            elif model_cls == dlrm.OverArch:
-                fx_out = gm(dense_features)
-                non_fx_out = model(dense_features)
-            elif model_cls == dlrm.SparseArch:
-                fx_out = gm(KJT)
-                non_fx_out = model(KJT)
+
+
+@pytest.mark.skipif(NOT_DLRM, reason='torchrec is not installed')
+def test_torchrec_dlrm_models():
+    torch.backends.cudnn.deterministic = True
 
-        if torch.is_tensor(fx_out):
-            assert torch.allclose(
-                fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
+        data = data_gen_fn()
+        if attribute is not None and attribute.has_control_flow:
+            meta_args = {k: v.to('meta') for k, v in data.items()}
         else:
-            assert torch.allclose(
-                fx_out.values(),
-                non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            meta_args = None
+
+        trace_and_compare(model_fn, data, output_transform_fn, meta_args)
 
 
 if __name__ == "__main__":
     test_torchrec_dlrm_models()
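
Usage note: the sketch below is not part of the patch; it shows how a single registered entry can be exercised outside pytest, assuming torchrec is installed. It relies only on the registry contract used by the tests above (each entry unpacks to model_fn, data_gen_fn, output_transform_fn, attribute); the shape comments are illustrative.

    import torch

    from colossalai.fx import symbolic_trace
    from tests.kit.model_zoo import model_zoo

    # pick one entry registered by tests/kit/model_zoo/torchrec/torchrec.py
    model_fn, data_gen_fn, output_transform_fn, attribute = \
        model_zoo.get_sub_registry('deepfm')['deepfm_densearch']

    model = model_fn()    # deepfm.DenseArch(10, 10, 10) via functools.partial
    model.eval()          # eval mode before tracing, as in trace_and_compare

    gm = symbolic_trace(model)
    with torch.no_grad():
        data = data_gen_fn()                       # {'features': torch.rand((2, 10))}
        out = output_transform_fn(gm(**data))     # {'output': <tensor>}
    print({k: v.shape for k, v in out.items()})   # expected: {'output': torch.Size([2, 10])}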