From 085e7f4eff832f2510d8023a9821206ab1894b2e Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 20 Mar 2023 16:19:06 +0800 Subject: [PATCH] [test] fixed torchrec registration in model zoo (#3177) * [test] fixed torchrec registration in model zoo * polish code * polish code * polish code --- tests/kit/model_zoo/torchrec/torchrec.py | 72 +++++++++++++++---- .../test_torchrec_model/test_deepfm_model.py | 1 - .../test_torchrec_model/test_dlrm_model.py | 1 - 3 files changed, 59 insertions(+), 15 deletions(-) diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py index 03d95a06a..dda563155 100644 --- a/tests/kit/model_zoo/torchrec/torchrec.py +++ b/tests/kit/model_zoo/torchrec/torchrec.py @@ -11,21 +11,47 @@ from ..registry import ModelAttribute, model_zoo BATCH = 2 SHAPE = 10 -# KeyedTensor -KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE))) + + +def gen_kt(): + KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE))) + return KT + # KeyedJaggedTensor -KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"], - values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), - offsets=torch.tensor([0, 2, 4, 6, 8])) +def gen_kjt(): + KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 4, 6, 8])) + return KJT + data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE))) -interaction_arch_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT) -simple_dfm_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT) +def interaction_arch_data_gen_fn(): + KT = gen_kt() + return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT) + + +def simple_dfm_data_gen_fn(): + KJT = gen_kjt() + return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT) + -sparse_arch_data_gen_fn = lambda: dict(features=KJT) +def sparse_arch_data_gen_fn(): + KJT = gen_kjt() + return dict(features=KJT) + + +def output_transform_fn(x): + if isinstance(x, KeyedTensor): + output = dict() + for key in x.keys(): + output[key] = x[key] + return output + else: + return dict(output=x) def output_transform_fn(x): @@ -42,7 +68,27 @@ def get_ebc(): # EmbeddingBagCollection eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) - return EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu')) + + +def sparse_arch_model_fn(): + ebc = get_ebc() + return deepfm.SparseArch(ebc) + + +def simple_deep_fmnn_model_fn(): + ebc = get_ebc() + return deepfm.SimpleDeepFMNN(SHAPE, ebc, SHAPE, SHAPE) + + +def dlrm_model_fn(): + ebc = get_ebc() + return dlrm.DLRM(ebc, SHAPE, [SHAPE, SHAPE], [5, 1]) + + +def dlrm_sparsearch_model_fn(): + ebc = get_ebc() + return dlrm.SparseArch(ebc) model_zoo.register(name='deepfm_densearch', @@ -61,17 +107,17 @@ model_zoo.register(name='deepfm_overarch', output_transform_fn=output_transform_fn) model_zoo.register(name='deepfm_simpledeepfmnn', - model_fn=partial(deepfm.SimpleDeepFMNN, SHAPE, get_ebc(), SHAPE, SHAPE), + model_fn=simple_deep_fmnn_model_fn, data_gen_fn=simple_dfm_data_gen_fn, output_transform_fn=output_transform_fn) model_zoo.register(name='deepfm_sparsearch', - model_fn=partial(deepfm.SparseArch, get_ebc()), + model_fn=sparse_arch_model_fn, data_gen_fn=sparse_arch_data_gen_fn, output_transform_fn=output_transform_fn) model_zoo.register(name='dlrm', - model_fn=partial(dlrm.DLRM, get_ebc(), SHAPE, [SHAPE, SHAPE], [5, 1]), + model_fn=dlrm_model_fn, data_gen_fn=simple_dfm_data_gen_fn, output_transform_fn=output_transform_fn) @@ -91,6 +137,6 @@ model_zoo.register(name='dlrm_overarch', output_transform_fn=output_transform_fn) model_zoo.register(name='dlrm_sparsearch', - model_fn=partial(dlrm.SparseArch, get_ebc()), + model_fn=dlrm_sparsearch_model_fn, data_gen_fn=sparse_arch_data_gen_fn, output_transform_fn=output_transform_fn) diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index a30139f26..a4e847dbc 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -47,7 +47,6 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' -@pytest.mark.skip('unknown error') def test_torchrec_deepfm_models(): deepfm_models = model_zoo.get_sub_registry('deepfm') torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 71ecf7fca..ac377ff1d 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -47,7 +47,6 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' -@pytest.mark.skip('unknown error') def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True dlrm_models = model_zoo.get_sub_registry('dlrm')