diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 6eba3984d..f9e8960d2 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,4 +1,3 @@ -diffusers fbgemm-gpu==0.2.0 pytest torchvision diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index ab6e08694..e02885e38 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -1,12 +1,17 @@ -import diffusers import pytest import torch -import transformers from torch.fx import GraphModule from utils import trace_model_and_compare_output +import transformers from colossalai.fx import ColoTracer +try: + import diffusers + HAS_DIFFUSERS = True +except ImportError: + HAS_DIFFUSERS = False + BATCH_SIZE = 2 SEQ_LENGTH = 5 HEIGHT = 224 @@ -16,6 +21,7 @@ LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8) TIME_STEP = 2 +@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed") def test_vae(): MODEL_LIST = [ diffusers.AutoencoderKL, @@ -80,6 +86,7 @@ def test_clip(): trace_model_and_compare_output(model, data_gen) +@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed") @pytest.mark.skip(reason='cannot pass the test yet') def test_unet(): MODEL_LIST = [