diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index ae3ff2fe9..7fd805c14 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -1,3 +1,4 @@
+diffusers
 pytest
 torchvision
 transformers
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
index cf809e13a..5837340fa 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output
 
 BATCH_SIZE = 2
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
 
 
 def test_single_sentence_albert():
@@ -23,9 +23,9 @@ def test_single_sentence_albert():
                                             intermediate_size=256)
 
     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return meta_args
 
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
index 63ad4badc..1a66b1151 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output
 
 BATCH_SIZE = 2
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
 
 
 def test_single_sentence_bert():
@@ -20,9 +20,9 @@ def test_single_sentence_bert():
     config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
 
     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return meta_args
 
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
new file mode 100644
index 000000000..ab6e08694
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
@@ -0,0 +1,116 @@
+import diffusers
+import pytest
+import torch
+import transformers
+from torch.fx import GraphModule
+from utils import trace_model_and_compare_output
+
+from colossalai.fx import ColoTracer
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 5
+HEIGHT = 224
+WIDTH = 224
+IN_CHANNELS = 3
+LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8)
+TIME_STEP = 2
+
+
+def test_vae():
+    MODEL_LIST = [
+        diffusers.AutoencoderKL,
+        diffusers.VQModel,
+    ]
+
+    for model_cls in MODEL_LIST:
+        model = model_cls()
+        sample = torch.zeros(LATENTS_SHAPE)
+
+        tracer = ColoTracer()
+        graph = tracer.trace(root=model)
+
+        gm = GraphModule(model, graph, model.__class__.__name__)
+        gm.recompile()
+
+        model.eval()
+        gm.eval()
+
+        with torch.no_grad():
+            fx_out = gm(sample)
+            non_fx_out = model(sample)
+        assert torch.allclose(
+            fx_out['sample'],
+            non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+def test_clip():
+    MODEL_LIST = [
+        transformers.CLIPModel,
+        transformers.CLIPTextModel,
+        transformers.CLIPVisionModel,
+    ]
+
+    CONFIG_LIST = [
+        transformers.CLIPConfig,
+        transformers.CLIPTextConfig,
+        transformers.CLIPVisionConfig,
+    ]
+
+    def data_gen():
+        if isinstance(model, transformers.CLIPModel):
+            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
+            kwargs = dict(input_ids=input_ids,
+                          attention_mask=attention_mask,
+                          position_ids=position_ids,
+                          pixel_values=pixel_values)
+        elif isinstance(model, transformers.CLIPTextModel):
+            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+            kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+        elif isinstance(model, transformers.CLIPVisionModel):
+            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
+            kwargs = dict(pixel_values=pixel_values)
+        return kwargs
+
+    for model_cls, config in zip(MODEL_LIST, CONFIG_LIST):
+        model = model_cls(config=config())
+        trace_model_and_compare_output(model, data_gen)
+
+
+@pytest.mark.skip(reason='cannot pass the test yet')
+def test_unet():
+    MODEL_LIST = [
+        diffusers.UNet2DModel,
+        diffusers.UNet2DConditionModel,
+    ]
+
+    for model_cls in MODEL_LIST:
+        model = model_cls()
+        sample = torch.zeros(LATENTS_SHAPE)
+
+        tracer = ColoTracer()
+        graph = tracer.trace(root=model)
+
+        gm = GraphModule(model, graph, model.__class__.__name__)
+        gm.recompile()
+
+        model.eval()
+        gm.eval()
+
+        with torch.no_grad():
+            fx_out = gm(sample, TIME_STEP)
+            non_fx_out = model(sample, TIME_STEP)
+        assert torch.allclose(
+            fx_out['sample'],
+            non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+if __name__ == "__main__":
+    test_vae()
+    test_clip()
+
+    # skip because of failure
+    # test_unet()
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
index 1c20e9bfd..ae2e752f9 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
 
 
 def test_gpt():
@@ -19,9 +19,9 @@ def test_gpt():
     config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
 
     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return kwargs
 
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
index 5ac051887..c39e97a16 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -1,10 +1,10 @@
 import pytest
-import transformers
 import torch
+import transformers
 from utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
 
 
 def test_opt():
@@ -16,8 +16,8 @@ def test_opt():
     config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
 
     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
         return kwargs
 
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
index 645951de9..b6749c828 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
@@ -1,10 +1,10 @@
 import pytest
-import transformers
 import torch
+import transformers
 from utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16
 
 
 def test_t5():
@@ -17,13 +17,13 @@ def test_t5():
     config = transformers.T5Config(d_model=128, num_layers=2)
 
     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
         return kwargs
 
     def data_gen_for_encoder_only():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids)
         return kwargs
 
diff --git a/tests/test_fx/test_tracer/test_hf_model/utils.py b/tests/test_fx/test_tracer/test_hf_model/utils.py
index 038548209..fb0702455 100644
--- a/tests/test_fx/test_tracer/test_hf_model/utils.py
+++ b/tests/test_fx/test_tracer/test_hf_model/utils.py
@@ -1,9 +1,10 @@
-from numpy import isin
 import torch
-from colossalai.fx import ColoTracer
+from numpy import isin
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten
 
+from colossalai.fx import ColoTracer
+
 
 def trace_model_and_compare_output(model, data_gen):
     # must turn on eval mode to ensure the output is consistent
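Note: every test above funnels through trace_model_and_compare_output, but the diff to utils.py only touches its imports, so the helper's body is not shown. Below is a minimal sketch of its presumed shape, inferred from the inline tracing pattern in test_hf_diffuser.py; the meta_args keyword on ColoTracer.trace and the key-by-key output comparison are assumptions, and the actual implementation lives in tests/test_fx/test_tracer/test_hf_model/utils.py.

import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer


def trace_model_and_compare_output(model, data_gen):
    # must turn on eval mode to ensure the output is consistent
    model.eval()

    # generate concrete inputs, then trace on meta tensors so tracing
    # allocates no real memory
    # (assumption: ColoTracer.trace accepts a meta_args mapping)
    kwargs = data_gen()
    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args={k: v.to('meta') for k, v in kwargs.items()})

    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    gm.eval()

    # run the traced module and the original model on identical inputs
    with torch.no_grad():
        fx_out = gm(**kwargs)
        non_fx_out = model(**kwargs)

    # compare every tensor field of the two output objects
    for key in non_fx_out.keys():
        if torch.is_tensor(non_fx_out[key]):
            assert torch.allclose(fx_out[key], non_fx_out[key]), \
                f'{model.__class__.__name__} has inconsistent output for {key}'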