diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 221c82ef7..8e4e5268d 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -3,3 +3,4 @@ torchvision transformers timm titans +torchrec diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py new file mode 100644 index 000000000..eb4761af8 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -0,0 +1,91 @@ +from curses import meta +from math import dist +from xml.dom import HierarchyRequestErr +from colossalai.fx.tracer import meta_patch +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.fx.tracer.meta_patch.patched_function import python_ops +import torch +from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.models import deepfm, dlrm +import colossalai.fx as fx +import pdb +from torch.fx import GraphModule + +BATCH = 2 +SHAPE = 10 + + +def test_torchrec_deepfm_models(): + MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch] + + # Data Preparation + # EmbeddingBagCollection + eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) + eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + keys = ["f1", "f2"] + + # KeyedTensor + KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE))) + + # KeyedJaggedTensor + KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys, + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 4, 6, 8])) + + # Dense Features + features = torch.rand((BATCH, SHAPE)) + + # Tracer + tracer = ColoTracer() + + for model_cls in MODEL_LIST: + # Initializing model + if model_cls == deepfm.DenseArch: + model = model_cls(SHAPE, SHAPE, SHAPE) + elif model_cls == deepfm.FMInteractionArch: + model = model_cls(SHAPE * 3, keys, SHAPE) + elif model_cls == deepfm.OverArch: + model = model_cls(SHAPE) + elif model_cls == deepfm.SimpleDeepFMNN: + model = model_cls(SHAPE, ebc, SHAPE, SHAPE) + elif model_cls == deepfm.SparseArch: + model = model_cls(ebc) + + # Setup GraphModule + graph = tracer.trace(model) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + model.eval() + gm.eval() + + # Aligned Test + with torch.no_grad(): + if model_cls == deepfm.DenseArch or model_cls == deepfm.OverArch: + fx_out = gm(features) + non_fx_out = model(features) + elif model_cls == deepfm.FMInteractionArch: + fx_out = gm(features, KT) + non_fx_out = model(features, KT) + elif model_cls == deepfm.SimpleDeepFMNN: + fx_out = gm(features, KJT) + non_fx_out = model(features, KJT) + elif model_cls == deepfm.SparseArch: + fx_out = gm(KJT) + non_fx_out = model(KJT) + + if torch.is_tensor(fx_out): + assert torch.allclose( + fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + else: + assert torch.allclose( + fx_out.values(), + non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +if __name__ == "__main__": + test_torchrec_deepfm_models() diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py new file mode 100644 index 000000000..fdf880866 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -0,0 +1,116 @@ +from curses import meta +from math import dist +from xml.dom import HierarchyRequestErr +from colossalai.fx.tracer import meta_patch +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.fx.tracer.meta_patch.patched_function import python_ops +import torch +from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.models import deepfm, dlrm +import colossalai.fx as fx +import pdb +from torch.fx import GraphModule + +BATCH = 2 +SHAPE = 10 + + +def test_torchrec_dlrm_models(): + MODEL_LIST = [ + dlrm.DLRM, + dlrm.DenseArch, + dlrm.InteractionArch, + dlrm.InteractionV2Arch, + dlrm.OverArch, + dlrm.SparseArch, + # dlrm.DLRMV2 + ] + + # Data Preparation + # EmbeddingBagCollection + eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) + eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + keys = ["f1", "f2"] + + # KeyedTensor + KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE))) + + # KeyedJaggedTensor + KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys, + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 4, 6, 8])) + + # Dense Features + dense_features = torch.rand((BATCH, SHAPE)) + + # Sparse Features + sparse_features = torch.rand((BATCH, len(keys), SHAPE)) + # Tracer + tracer = ColoTracer() + + for model_cls in MODEL_LIST: + # Initializing model + if model_cls == dlrm.DLRM: + model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1]) + elif model_cls == dlrm.DenseArch: + model = model_cls(SHAPE, [SHAPE, SHAPE]) + elif model_cls == dlrm.InteractionArch: + model = model_cls(len(keys)) + elif model_cls == dlrm.InteractionV2Arch: + I1 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE]) + I2 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE]) + model = model_cls(len(keys), I1, I2) + elif model_cls == dlrm.OverArch: + model = model_cls(SHAPE, [5, 1]) + elif model_cls == dlrm.SparseArch: + model = model_cls(ebc) + elif model_cls == dlrm.DLRMV2: + # Currently DLRMV2 cannot be traced + model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE]) + + # Setup GraphModule + if model_cls == dlrm.InteractionV2Arch: + concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features} + graph = tracer.trace(model, concrete_args=concrete_args) + else: + graph = tracer.trace(model) + + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + model.eval() + gm.eval() + + # Aligned Test + with torch.no_grad(): + if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2: + fx_out = gm(dense_features, KJT) + non_fx_out = model(dense_features, KJT) + elif model_cls == dlrm.DenseArch: + fx_out = gm(dense_features) + non_fx_out = model(dense_features) + elif model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch: + fx_out = gm(dense_features, sparse_features) + non_fx_out = model(dense_features, sparse_features) + elif model_cls == dlrm.OverArch: + fx_out = gm(dense_features) + non_fx_out = model(dense_features) + elif model_cls == dlrm.SparseArch: + fx_out = gm(KJT) + non_fx_out = model(KJT) + + if torch.is_tensor(fx_out): + assert torch.allclose( + fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + else: + assert torch.allclose( + fx_out.values(), + non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + + +if __name__ == "__main__": + test_torchrec_dlrm_models()