[fx] test tracer on diffuser modules. (#1750)

* [fx] test tracer on diffuser modules.

* [fx] shorter seq_len.

* Update requirements-test.txt
pull/1753/head
Super Daniel 2022-10-20 18:25:05 +08:00 committed by GitHub
parent b80b6eaa88
commit b893342f95
8 changed files with 147 additions and 29 deletions

View File

@@ -1,3 +1,4 @@
+diffusers
pytest
torchvision
transformers

View File

@@ -1,10 +1,10 @@
-import transformers
-import torch
import pytest
+import torch
+import transformers
from utils import trace_model_and_compare_output
BATCH_SIZE = 2
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
def test_single_sentence_albert():
@@ -23,9 +23,9 @@ def test_single_sentence_albert():
intermediate_size=256)
def data_gen():
-input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return meta_args

View File

@@ -1,10 +1,10 @@
-import transformers
-import torch
import pytest
+import torch
+import transformers
from utils import trace_model_and_compare_output
BATCH_SIZE = 2
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
def test_single_sentence_bert():
@@ -20,9 +20,9 @@ def test_single_sentence_bert():
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
def data_gen():
-input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return meta_args

View File

@@ -0,0 +1,116 @@
import diffusers
import pytest
import torch
import transformers
from torch.fx import GraphModule
from utils import trace_model_and_compare_output

from colossalai.fx import ColoTracer

BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
IN_CHANNELS = 3
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8)
TIME_STEP = 2


def test_vae():
    MODEL_LIST = [
        diffusers.AutoencoderKL,
        diffusers.VQModel,
    ]

    for model_cls in MODEL_LIST:
        model = model_cls()
        sample = torch.zeros(LATENTS_SHAPE)

        tracer = ColoTracer()
        graph = tracer.trace(root=model)
        gm = GraphModule(model, graph, model.__class__.__name__)
        gm.recompile()

        model.eval()
        gm.eval()

        with torch.no_grad():
            fx_out = gm(sample)
            non_fx_out = model(sample)
        assert torch.allclose(
            fx_out['sample'],
            non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'


def test_clip():
    MODEL_LIST = [
        transformers.CLIPModel,
        transformers.CLIPTextModel,
        transformers.CLIPVisionModel,
    ]

    CONFIG_LIST = [
        transformers.CLIPConfig,
        transformers.CLIPTextConfig,
        transformers.CLIPVisionConfig,
    ]

    def data_gen():
        if isinstance(model, transformers.CLIPModel):
            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
            kwargs = dict(input_ids=input_ids,
                          attention_mask=attention_mask,
                          position_ids=position_ids,
                          pixel_values=pixel_values)
        elif isinstance(model, transformers.CLIPTextModel):
            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        elif isinstance(model, transformers.CLIPVisionModel):
            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
            kwargs = dict(pixel_values=pixel_values)
        return kwargs

    for model_cls, config in zip(MODEL_LIST, CONFIG_LIST):
        model = model_cls(config=config())
        trace_model_and_compare_output(model, data_gen)


@pytest.mark.skip(reason='cannot pass the test yet')
def test_unet():
    MODEL_LIST = [
        diffusers.UNet2DModel,
        diffusers.UNet2DConditionModel,
    ]

    for model_cls in MODEL_LIST:
        model = model_cls()
        sample = torch.zeros(LATENTS_SHAPE)

        tracer = ColoTracer()
        graph = tracer.trace(root=model)
        gm = GraphModule(model, graph, model.__class__.__name__)
        gm.recompile()

        model.eval()
        gm.eval()

        with torch.no_grad():
            fx_out = gm(sample, TIME_STEP)
            non_fx_out = model(sample, TIME_STEP)
        assert torch.allclose(
            fx_out['sample'],
            non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'


if __name__ == "__main__":
    test_vae()
    test_clip()
    # skip because of failure
    # test_unet()

View File

@@ -1,10 +1,10 @@
-import transformers
-import torch
import pytest
+import torch
+import transformers
from utils import trace_model_and_compare_output
BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
def test_gpt():
@@ -19,9 +19,9 @@ def test_gpt():
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
def data_gen():
-input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return kwargs

View File

@@ -1,10 +1,10 @@
import pytest
-import transformers
import torch
+import transformers
from utils import trace_model_and_compare_output
BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
def test_opt():
@@ -16,8 +16,8 @@ def test_opt():
config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
def data_gen():
-input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs

View File

@@ -1,10 +1,10 @@
import pytest
-import transformers
import torch
+import transformers
from utils import trace_model_and_compare_output
BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
def test_t5():
@@ -17,13 +17,13 @@ def test_t5():
config = transformers.T5Config(d_model=128, num_layers=2)
def data_gen():
-input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
return kwargs
def data_gen_for_encoder_only():
-input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids)
return kwargs

View File

@@ -1,9 +1,10 @@
-from numpy import isin
import torch
-from colossalai.fx import ColoTracer
+from numpy import isin
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten
+from colossalai.fx import ColoTracer
def trace_model_and_compare_output(model, data_gen):
# must turn on eval mode to ensure the output is consistent
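The body of trace_model_and_compare_output is truncated in the hunk above. For reference, a minimal sketch of what such a helper could look like, built only from the pieces imported in that file and the tracing pattern used in test_vae; the meta_args handling and the key-by-key output comparison are assumptions for illustration, not the exact implementation in this commit.

import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer


def trace_model_and_compare_output(model, data_gen):
    # must turn on eval mode to ensure the output is consistent
    model.eval()

    # generate inputs and trace the model symbolically
    # (assumption: ColoTracer.trace accepts a meta_args dict of meta tensors)
    kwargs = data_gen()
    tracer = ColoTracer()
    meta_args = {k: v.to('meta') for k, v in kwargs.items()}
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # run the traced and the original model on the same concrete inputs
    with torch.no_grad():
        fx_out = gm(**kwargs)
        non_fx_out = model(**kwargs)

    # compare the outputs field by field; both sides are dict-like
    # (transformers ModelOutput / diffusers output objects)
    for key, non_fx_value in non_fx_out.items():
        if isinstance(non_fx_value, torch.Tensor):
            assert torch.allclose(fx_out[key], non_fx_value), \
                f'{model.__class__.__name__} has inconsistent outputs on {key}'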