mirror of https://github.com/hpcaitech/ColossalAI
[fx] test tracer on diffuser modules. (#1750)
* [fx] test tracer on diffuser modules.
* [fx] shorter seq_len.
* Update requirements-test.txt

pull/1753/head
parent b80b6eaa88
commit b893342f95
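
This commit adds diffusers to requirements-test.txt and introduces a new tracer test module that traces diffusers VAE models (AutoencoderKL, VQModel), CLIP models (CLIPModel, CLIPTextModel, CLIPVisionModel) and UNet models (UNet2DModel, UNet2DConditionModel) with ColoTracer, then checks that the traced GraphModule produces the same outputs as the eager model. The existing HuggingFace tracer tests (ALBERT, BERT, GPT-2, OPT, T5) get their imports reordered and the misspelled SEQ_LENGHT constant renamed to SEQ_LENGTH; the UNet test is skipped for now because it cannot pass yet.
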
requirements-test.txt
@@ -1,3 +1,4 @@
+diffusers
 pytest
 torchvision
 transformers

@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 2
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_single_sentence_albert():
@@ -23,9 +23,9 @@ def test_single_sentence_albert():
                                        intermediate_size=256)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return meta_args

@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 2
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_single_sentence_bert():
@@ -20,9 +20,9 @@ def test_single_sentence_bert():
     config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return meta_args

@@ -0,0 +1,116 @@
+import diffusers
+import pytest
+import torch
+import transformers
+from torch.fx import GraphModule
+from utils import trace_model_and_compare_output
+
+from colossalai.fx import ColoTracer
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 5
+HEIGHT = 224
+WIDTH = 224
+IN_CHANNELS = 3
+LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8)
+TIME_STEP = 2
+
+
+def test_vae():
+    MODEL_LIST = [
+        diffusers.AutoencoderKL,
+        diffusers.VQModel,
+    ]
+
+    for model_cls in MODEL_LIST:
+        model = model_cls()
+        sample = torch.zeros(LATENTS_SHAPE)
+
+        tracer = ColoTracer()
+        graph = tracer.trace(root=model)
+
+        gm = GraphModule(model, graph, model.__class__.__name__)
+        gm.recompile()
+
+        model.eval()
+        gm.eval()
+
+        with torch.no_grad():
+            fx_out = gm(sample)
+            non_fx_out = model(sample)
+        assert torch.allclose(
+            fx_out['sample'],
+            non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+def test_clip():
+    MODEL_LIST = [
+        transformers.CLIPModel,
+        transformers.CLIPTextModel,
+        transformers.CLIPVisionModel,
+    ]
+
+    CONFIG_LIST = [
+        transformers.CLIPConfig,
+        transformers.CLIPTextConfig,
+        transformers.CLIPVisionConfig,
+    ]
+
+    def data_gen():
+        if isinstance(model, transformers.CLIPModel):
+            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
+            kwargs = dict(input_ids=input_ids,
+                          attention_mask=attention_mask,
+                          position_ids=position_ids,
+                          pixel_values=pixel_values)
+        elif isinstance(model, transformers.CLIPTextModel):
+            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+        elif isinstance(model, transformers.CLIPVisionModel):
+            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
+            kwargs = dict(pixel_values=pixel_values)
+        return kwargs
+
+    for model_cls, config in zip(MODEL_LIST, CONFIG_LIST):
+        model = model_cls(config=config())
+        trace_model_and_compare_output(model, data_gen)
+
+
+@pytest.mark.skip(reason='cannot pass the test yet')
+def test_unet():
+    MODEL_LIST = [
+        diffusers.UNet2DModel,
+        diffusers.UNet2DConditionModel,
+    ]
+
+    for model_cls in MODEL_LIST:
+        model = model_cls()
+        sample = torch.zeros(LATENTS_SHAPE)
+
+        tracer = ColoTracer()
+        graph = tracer.trace(root=model)
+
+        gm = GraphModule(model, graph, model.__class__.__name__)
+        gm.recompile()
+
+        model.eval()
+        gm.eval()
+
+        with torch.no_grad():
+            fx_out = gm(sample, TIME_STEP)
+            non_fx_out = model(sample, TIME_STEP)
+        assert torch.allclose(
+            fx_out['sample'],
+            non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+if __name__ == "__main__":
+    test_vae()
+    test_clip()
+
+    # skip because of failure
+    # test_unet()

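A note on the data_gen closure in test_clip above: data_gen reads the name model, which is only assigned later in the for loop. Because Python resolves closed-over names at call time, trace_model_and_compare_output invokes data_gen after the loop has bound model to the current CLIP variant, so the isinstance branches select the matching input set (text, vision, or both) for each model.
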
@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_gpt():
@@ -19,9 +19,9 @@ def test_gpt():
     config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return kwargs

@@ -1,10 +1,10 @@
 import pytest
-import transformers
 import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_opt():
@@ -16,8 +16,8 @@ def test_opt():
     config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
         return kwargs

@@ -1,10 +1,10 @@
 import pytest
-import transformers
 import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_t5():
@@ -17,13 +17,13 @@ def test_t5():
     config = transformers.T5Config(d_model=128, num_layers=2)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
         return kwargs

     def data_gen_for_encoder_only():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids)
         return kwargs

@@ -1,9 +1,10 @@
-from numpy import isin
 import torch
-from colossalai.fx import ColoTracer
+from numpy import isin
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten

+from colossalai.fx import ColoTracer
+

 def trace_model_and_compare_output(model, data_gen):
     # must turn on eval mode to ensure the output is consistent
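
The body of trace_model_and_compare_output is cut off in this view. Purely as an illustrative sketch — not the actual contents of utils.py — a helper with this signature could trace the model with ColoTracer and compare traced and eager outputs roughly as follows (the meta_args usage mirrors the dictionaries the ALBERT/BERT tests build; everything past the first comment is an assumption):

# Illustrative sketch of a tracing/comparison helper; NOT the real utils.py body.
import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer


def trace_model_and_compare_output(model, data_gen):
    # must turn on eval mode to ensure the output is consistent
    model.eval()

    # build concrete inputs, derive meta tensors, and trace the model into a GraphModule
    inputs = data_gen()
    meta_args = {k: v.to('meta') for k, v in inputs.items()}    # assumed meta-tensor tracing
    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    gm.eval()

    # run both versions on the same inputs and compare every tensor in the dict-like outputs
    with torch.no_grad():
        fx_out = gm(**inputs)
        non_fx_out = model(**inputs)
    for key, eager_value in non_fx_out.items():
        if torch.is_tensor(eager_value):
            assert torch.allclose(fx_out[key], eager_value), \
                f'{model.__class__.__name__} has inconsistent output for {key}'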