Browse Source

[fx] skip diffusers unitest if it is not installed (#1799)

pull/1807/head
Jiarui Fang 2 years ago committed by GitHub
parent
commit
6fa71d65d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      requirements/requirements-test.txt
  2. 11
      tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py

1
requirements/requirements-test.txt

@ -1,4 +1,3 @@
diffusers
fbgemm-gpu==0.2.0 fbgemm-gpu==0.2.0
pytest pytest
torchvision torchvision

11
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py

@ -1,12 +1,17 @@
import diffusers
import pytest import pytest
import torch import torch
import transformers
from torch.fx import GraphModule from torch.fx import GraphModule
from utils import trace_model_and_compare_output from utils import trace_model_and_compare_output
import transformers
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
try:
import diffusers
HAS_DIFFUSERS = True
except ImportError:
HAS_DIFFUSERS = False
BATCH_SIZE = 2 BATCH_SIZE = 2
SEQ_LENGTH = 5 SEQ_LENGTH = 5
HEIGHT = 224 HEIGHT = 224
@ -16,6 +21,7 @@ LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8)
TIME_STEP = 2 TIME_STEP = 2
@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed")
def test_vae(): def test_vae():
MODEL_LIST = [ MODEL_LIST = [
diffusers.AutoencoderKL, diffusers.AutoencoderKL,
@ -80,6 +86,7 @@ def test_clip():
trace_model_and_compare_output(model, data_gen) trace_model_and_compare_output(model, data_gen)
@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed")
@pytest.mark.skip(reason='cannot pass the test yet') @pytest.mark.skip(reason='cannot pass the test yet')
def test_unet(): def test_unet():
MODEL_LIST = [ MODEL_LIST = [

Loading…
Cancel
Save