[fx] Add colotracer compatibility test on torchrec (#1370)

pull/1372/head
Boyuan Yao 2022-07-26 17:54:39 +08:00 committed by GitHub
parent c415240db6
commit bb640ec728
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 208 additions and 0 deletions

View File

@ -3,3 +3,4 @@ torchvision
transformers
timm
titans
torchrec

View File

@ -0,0 +1,91 @@
from curses import meta
from math import dist
from xml.dom import HierarchyRequestErr
from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.models import deepfm, dlrm
import colossalai.fx as fx
import pdb
from torch.fx import GraphModule
BATCH = 2
SHAPE = 10
def test_torchrec_deepfm_models():
MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
# Data Preparation
# EmbeddingBagCollection
eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
keys = ["f1", "f2"]
# KeyedTensor
KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
# KeyedJaggedTensor
KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
offsets=torch.tensor([0, 2, 4, 6, 8]))
# Dense Features
features = torch.rand((BATCH, SHAPE))
# Tracer
tracer = ColoTracer()
for model_cls in MODEL_LIST:
# Initializing model
if model_cls == deepfm.DenseArch:
model = model_cls(SHAPE, SHAPE, SHAPE)
elif model_cls == deepfm.FMInteractionArch:
model = model_cls(SHAPE * 3, keys, SHAPE)
elif model_cls == deepfm.OverArch:
model = model_cls(SHAPE)
elif model_cls == deepfm.SimpleDeepFMNN:
model = model_cls(SHAPE, ebc, SHAPE, SHAPE)
elif model_cls == deepfm.SparseArch:
model = model_cls(ebc)
# Setup GraphModule
graph = tracer.trace(model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
model.eval()
gm.eval()
# Aligned Test
with torch.no_grad():
if model_cls == deepfm.DenseArch or model_cls == deepfm.OverArch:
fx_out = gm(features)
non_fx_out = model(features)
elif model_cls == deepfm.FMInteractionArch:
fx_out = gm(features, KT)
non_fx_out = model(features, KT)
elif model_cls == deepfm.SimpleDeepFMNN:
fx_out = gm(features, KJT)
non_fx_out = model(features, KJT)
elif model_cls == deepfm.SparseArch:
fx_out = gm(KJT)
non_fx_out = model(KJT)
if torch.is_tensor(fx_out):
assert torch.allclose(
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
else:
assert torch.allclose(
fx_out.values(),
non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
if __name__ == "__main__":
test_torchrec_deepfm_models()

View File

@ -0,0 +1,116 @@
from curses import meta
from math import dist
from xml.dom import HierarchyRequestErr
from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.models import deepfm, dlrm
import colossalai.fx as fx
import pdb
from torch.fx import GraphModule
BATCH = 2
SHAPE = 10
def test_torchrec_dlrm_models():
MODEL_LIST = [
dlrm.DLRM,
dlrm.DenseArch,
dlrm.InteractionArch,
dlrm.InteractionV2Arch,
dlrm.OverArch,
dlrm.SparseArch,
# dlrm.DLRMV2
]
# Data Preparation
# EmbeddingBagCollection
eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
keys = ["f1", "f2"]
# KeyedTensor
KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
# KeyedJaggedTensor
KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
offsets=torch.tensor([0, 2, 4, 6, 8]))
# Dense Features
dense_features = torch.rand((BATCH, SHAPE))
# Sparse Features
sparse_features = torch.rand((BATCH, len(keys), SHAPE))
# Tracer
tracer = ColoTracer()
for model_cls in MODEL_LIST:
# Initializing model
if model_cls == dlrm.DLRM:
model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
elif model_cls == dlrm.DenseArch:
model = model_cls(SHAPE, [SHAPE, SHAPE])
elif model_cls == dlrm.InteractionArch:
model = model_cls(len(keys))
elif model_cls == dlrm.InteractionV2Arch:
I1 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
I2 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
model = model_cls(len(keys), I1, I2)
elif model_cls == dlrm.OverArch:
model = model_cls(SHAPE, [5, 1])
elif model_cls == dlrm.SparseArch:
model = model_cls(ebc)
elif model_cls == dlrm.DLRMV2:
# Currently DLRMV2 cannot be traced
model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE])
# Setup GraphModule
if model_cls == dlrm.InteractionV2Arch:
concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
graph = tracer.trace(model, concrete_args=concrete_args)
else:
graph = tracer.trace(model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
model.eval()
gm.eval()
# Aligned Test
with torch.no_grad():
if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2:
fx_out = gm(dense_features, KJT)
non_fx_out = model(dense_features, KJT)
elif model_cls == dlrm.DenseArch:
fx_out = gm(dense_features)
non_fx_out = model(dense_features)
elif model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch:
fx_out = gm(dense_features, sparse_features)
non_fx_out = model(dense_features, sparse_features)
elif model_cls == dlrm.OverArch:
fx_out = gm(dense_features)
non_fx_out = model(dense_features)
elif model_cls == dlrm.SparseArch:
fx_out = gm(KJT)
non_fx_out = model(KJT)
if torch.is_tensor(fx_out):
assert torch.allclose(
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
else:
assert torch.allclose(
fx_out.values(),
non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
if __name__ == "__main__":
test_torchrec_dlrm_models()