mirror of https://github.com/hpcaitech/ColossalAI
[test] fixed torchrec registration in model zoo (#3177)
* [test] fixed torchrec registration in model zoo * polish code * polish code * polish codepull/3181/head
parent
4e921cfbd6
commit
085e7f4eff
|
@ -11,21 +11,47 @@ from ..registry import ModelAttribute, model_zoo
|
||||||
|
|
||||||
BATCH = 2
|
BATCH = 2
|
||||||
SHAPE = 10
|
SHAPE = 10
|
||||||
# KeyedTensor
|
|
||||||
KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
|
|
||||||
|
def gen_kt():
|
||||||
|
KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
|
||||||
|
return KT
|
||||||
|
|
||||||
|
|
||||||
# KeyedJaggedTensor
|
# KeyedJaggedTensor
|
||||||
KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
|
def gen_kjt():
|
||||||
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
|
KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
|
||||||
offsets=torch.tensor([0, 2, 4, 6, 8]))
|
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
|
||||||
|
offsets=torch.tensor([0, 2, 4, 6, 8]))
|
||||||
|
return KJT
|
||||||
|
|
||||||
|
|
||||||
data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE)))
|
data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE)))
|
||||||
|
|
||||||
interaction_arch_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT)
|
|
||||||
|
|
||||||
simple_dfm_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT)
|
def interaction_arch_data_gen_fn():
|
||||||
|
KT = gen_kt()
|
||||||
|
return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT)
|
||||||
|
|
||||||
sparse_arch_data_gen_fn = lambda: dict(features=KJT)
|
|
||||||
|
def simple_dfm_data_gen_fn():
|
||||||
|
KJT = gen_kjt()
|
||||||
|
return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT)
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_arch_data_gen_fn():
|
||||||
|
KJT = gen_kjt()
|
||||||
|
return dict(features=KJT)
|
||||||
|
|
||||||
|
|
||||||
|
def output_transform_fn(x):
|
||||||
|
if isinstance(x, KeyedTensor):
|
||||||
|
output = dict()
|
||||||
|
for key in x.keys():
|
||||||
|
output[key] = x[key]
|
||||||
|
return output
|
||||||
|
else:
|
||||||
|
return dict(output=x)
|
||||||
|
|
||||||
|
|
||||||
def output_transform_fn(x):
|
def output_transform_fn(x):
|
||||||
|
@ -42,7 +68,27 @@ def get_ebc():
|
||||||
# EmbeddingBagCollection
|
# EmbeddingBagCollection
|
||||||
eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
|
eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
|
||||||
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
|
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
|
||||||
return EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu'))
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_arch_model_fn():
|
||||||
|
ebc = get_ebc()
|
||||||
|
return deepfm.SparseArch(ebc)
|
||||||
|
|
||||||
|
|
||||||
|
def simple_deep_fmnn_model_fn():
|
||||||
|
ebc = get_ebc()
|
||||||
|
return deepfm.SimpleDeepFMNN(SHAPE, ebc, SHAPE, SHAPE)
|
||||||
|
|
||||||
|
|
||||||
|
def dlrm_model_fn():
|
||||||
|
ebc = get_ebc()
|
||||||
|
return dlrm.DLRM(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
|
||||||
|
|
||||||
|
|
||||||
|
def dlrm_sparsearch_model_fn():
|
||||||
|
ebc = get_ebc()
|
||||||
|
return dlrm.SparseArch(ebc)
|
||||||
|
|
||||||
|
|
||||||
model_zoo.register(name='deepfm_densearch',
|
model_zoo.register(name='deepfm_densearch',
|
||||||
|
@ -61,17 +107,17 @@ model_zoo.register(name='deepfm_overarch',
|
||||||
output_transform_fn=output_transform_fn)
|
output_transform_fn=output_transform_fn)
|
||||||
|
|
||||||
model_zoo.register(name='deepfm_simpledeepfmnn',
|
model_zoo.register(name='deepfm_simpledeepfmnn',
|
||||||
model_fn=partial(deepfm.SimpleDeepFMNN, SHAPE, get_ebc(), SHAPE, SHAPE),
|
model_fn=simple_deep_fmnn_model_fn,
|
||||||
data_gen_fn=simple_dfm_data_gen_fn,
|
data_gen_fn=simple_dfm_data_gen_fn,
|
||||||
output_transform_fn=output_transform_fn)
|
output_transform_fn=output_transform_fn)
|
||||||
|
|
||||||
model_zoo.register(name='deepfm_sparsearch',
|
model_zoo.register(name='deepfm_sparsearch',
|
||||||
model_fn=partial(deepfm.SparseArch, get_ebc()),
|
model_fn=sparse_arch_model_fn,
|
||||||
data_gen_fn=sparse_arch_data_gen_fn,
|
data_gen_fn=sparse_arch_data_gen_fn,
|
||||||
output_transform_fn=output_transform_fn)
|
output_transform_fn=output_transform_fn)
|
||||||
|
|
||||||
model_zoo.register(name='dlrm',
|
model_zoo.register(name='dlrm',
|
||||||
model_fn=partial(dlrm.DLRM, get_ebc(), SHAPE, [SHAPE, SHAPE], [5, 1]),
|
model_fn=dlrm_model_fn,
|
||||||
data_gen_fn=simple_dfm_data_gen_fn,
|
data_gen_fn=simple_dfm_data_gen_fn,
|
||||||
output_transform_fn=output_transform_fn)
|
output_transform_fn=output_transform_fn)
|
||||||
|
|
||||||
|
@ -91,6 +137,6 @@ model_zoo.register(name='dlrm_overarch',
|
||||||
output_transform_fn=output_transform_fn)
|
output_transform_fn=output_transform_fn)
|
||||||
|
|
||||||
model_zoo.register(name='dlrm_sparsearch',
|
model_zoo.register(name='dlrm_sparsearch',
|
||||||
model_fn=partial(dlrm.SparseArch, get_ebc()),
|
model_fn=dlrm_sparsearch_model_fn,
|
||||||
data_gen_fn=sparse_arch_data_gen_fn,
|
data_gen_fn=sparse_arch_data_gen_fn,
|
||||||
output_transform_fn=output_transform_fn)
|
output_transform_fn=output_transform_fn)
|
||||||
|
|
|
@ -47,7 +47,6 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||||
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('unknown error')
|
|
||||||
def test_torchrec_deepfm_models():
|
def test_torchrec_deepfm_models():
|
||||||
deepfm_models = model_zoo.get_sub_registry('deepfm')
|
deepfm_models = model_zoo.get_sub_registry('deepfm')
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
|
@ -47,7 +47,6 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||||
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('unknown error')
|
|
||||||
def test_torchrec_dlrm_models():
|
def test_torchrec_dlrm_models():
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
dlrm_models = model_zoo.get_sub_registry('dlrm')
|
dlrm_models = model_zoo.get_sub_registry('dlrm')
|
||||||
|
|
Loading…
Reference in New Issue