mirror of https://github.com/hpcaitech/ColossalAI
[test] fixed torchrec model test (#3167)
* [test] fixed torchrec model test * polish code * polish code * polish code * polish code * polish code * polish codepull/3159/head
parent
20d1c99444
commit
1ad3a636b1
|
@ -2,20 +2,13 @@ from collections import namedtuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
|
||||||
from torchrec.models import deepfm, dlrm
|
from torchrec.models import deepfm, dlrm
|
||||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||||
NO_TORCHREC = False
|
|
||||||
except ImportError:
|
|
||||||
NO_TORCHREC = True
|
|
||||||
|
|
||||||
from ..registry import ModelAttribute, model_zoo
|
from ..registry import ModelAttribute, model_zoo
|
||||||
|
|
||||||
|
|
||||||
def register_torchrec_models():
|
|
||||||
BATCH = 2
|
BATCH = 2
|
||||||
SHAPE = 10
|
SHAPE = 10
|
||||||
# KeyedTensor
|
# KeyedTensor
|
||||||
|
@ -34,7 +27,16 @@ def register_torchrec_models():
|
||||||
|
|
||||||
sparse_arch_data_gen_fn = lambda: dict(features=KJT)
|
sparse_arch_data_gen_fn = lambda: dict(features=KJT)
|
||||||
|
|
||||||
output_transform_fn = lambda x: dict(output=x)
|
|
||||||
|
def output_transform_fn(x):
|
||||||
|
if isinstance(x, KeyedTensor):
|
||||||
|
output = dict()
|
||||||
|
for key in x.keys():
|
||||||
|
output[key] = x[key]
|
||||||
|
return output
|
||||||
|
else:
|
||||||
|
return dict(output=x)
|
||||||
|
|
||||||
|
|
||||||
def get_ebc():
|
def get_ebc():
|
||||||
# EmbeddingBagCollection
|
# EmbeddingBagCollection
|
||||||
|
@ -42,6 +44,7 @@ def register_torchrec_models():
|
||||||
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
|
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
|
||||||
return EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
return EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||||
|
|
||||||
|
|
||||||
model_zoo.register(name='deepfm_densearch',
|
model_zoo.register(name='deepfm_densearch',
|
||||||
model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
|
model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
|
||||||
data_gen_fn=data_gen_fn,
|
data_gen_fn=data_gen_fn,
|
||||||
|
@ -91,7 +94,3 @@ def register_torchrec_models():
|
||||||
model_fn=partial(dlrm.SparseArch, get_ebc()),
|
model_fn=partial(dlrm.SparseArch, get_ebc()),
|
||||||
data_gen_fn=sparse_arch_data_gen_fn,
|
data_gen_fn=sparse_arch_data_gen_fn,
|
||||||
output_transform_fn=output_transform_fn)
|
output_transform_fn=output_transform_fn)
|
||||||
|
|
||||||
|
|
||||||
if not NO_TORCHREC:
|
|
||||||
register_torchrec_models()
|
|
||||||
|
|
|
@ -7,11 +7,17 @@ from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
def test_torch_amp():
|
def test_torch_amp():
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
||||||
|
# dlrm_interactionarch has not parameters, so skip
|
||||||
|
if name == 'dlrm_interactionarch':
|
||||||
|
continue
|
||||||
|
|
||||||
model = model_fn().cuda()
|
model = model_fn().cuda()
|
||||||
optimizer = Adam(model.parameters(), lr=1e-3)
|
optimizer = Adam(model.parameters(), lr=1e-3)
|
||||||
criterion = lambda x: x.mean()
|
criterion = lambda x: x.mean()
|
||||||
data = data_gen_fn()
|
data = data_gen_fn()
|
||||||
data = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data.items()}
|
data = {
|
||||||
|
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
|
||||||
|
}
|
||||||
mixed_precision = FP16TorchMixedPrecision()
|
mixed_precision = FP16TorchMixedPrecision()
|
||||||
model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
|
model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
|
||||||
output = model(**data)
|
output = model(**data)
|
||||||
|
|
|
@ -7,11 +7,6 @@ from tests.kit.model_zoo import model_zoo
|
||||||
BATCH = 2
|
BATCH = 2
|
||||||
SHAPE = 10
|
SHAPE = 10
|
||||||
|
|
||||||
deepfm_models = model_zoo.get_sub_registry('deepfm')
|
|
||||||
NOT_DFM = False
|
|
||||||
if not deepfm_models:
|
|
||||||
NOT_DFM = True
|
|
||||||
|
|
||||||
|
|
||||||
def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||||
# trace
|
# trace
|
||||||
|
@ -52,8 +47,9 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||||
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(NOT_DFM, reason='torchrec is not installed')
|
@pytest.mark.skip('unknown error')
|
||||||
def test_torchrec_deepfm_models(deepfm_models):
|
def test_torchrec_deepfm_models():
|
||||||
|
deepfm_models = model_zoo.get_sub_registry('deepfm')
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
|
||||||
|
@ -67,4 +63,4 @@ def test_torchrec_deepfm_models(deepfm_models):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_torchrec_deepfm_models(deepfm_models)
|
test_torchrec_deepfm_models()
|
||||||
|
|
|
@ -7,11 +7,6 @@ from tests.kit.model_zoo import model_zoo
|
||||||
BATCH = 2
|
BATCH = 2
|
||||||
SHAPE = 10
|
SHAPE = 10
|
||||||
|
|
||||||
dlrm_models = model_zoo.get_sub_registry('dlrm')
|
|
||||||
NOT_DLRM = False
|
|
||||||
if not dlrm_models:
|
|
||||||
NOT_DLRM = True
|
|
||||||
|
|
||||||
|
|
||||||
def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||||
# trace
|
# trace
|
||||||
|
@ -52,12 +47,18 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||||
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(NOT_DLRM, reason='torchrec is not installed')
|
@pytest.mark.skip('unknown error')
|
||||||
def test_torchrec_dlrm_models(dlrm_models):
|
def test_torchrec_dlrm_models():
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
dlrm_models = model_zoo.get_sub_registry('dlrm')
|
||||||
|
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
|
||||||
data = data_gen_fn()
|
data = data_gen_fn()
|
||||||
|
|
||||||
|
# dlrm_interactionarch is not supported
|
||||||
|
if name == 'dlrm_interactionarch':
|
||||||
|
continue
|
||||||
|
|
||||||
if attribute is not None and attribute.has_control_flow:
|
if attribute is not None and attribute.has_control_flow:
|
||||||
meta_args = {k: v.to('meta') for k, v in data.items()}
|
meta_args = {k: v.to('meta') for k, v in data.items()}
|
||||||
else:
|
else:
|
||||||
|
@ -67,4 +68,4 @@ def test_torchrec_dlrm_models(dlrm_models):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_torchrec_dlrm_models(dlrm_models)
|
test_torchrec_dlrm_models()
|
||||||
|
|
|
@ -34,17 +34,17 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||||
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
|
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
|
||||||
|
|
||||||
|
|
||||||
@parameterize('init_device', [get_current_device()])
|
|
||||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||||
@parameterize('keep_gather', [False, True])
|
@parameterize('keep_gather', [False, True])
|
||||||
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
|
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
|
||||||
@parameterize('use_grad_checkpoint', [False, True])
|
@parameterize('use_grad_checkpoint', [False, True])
|
||||||
def exam_gpt_fwd_bwd(placement_policy,
|
def exam_gpt_fwd_bwd(
|
||||||
|
placement_policy,
|
||||||
keep_gather,
|
keep_gather,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
use_grad_checkpoint: bool = False,
|
use_grad_checkpoint: bool = False,
|
||||||
init_device=get_current_device()):
|
):
|
||||||
|
init_device = get_current_device()
|
||||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue