2023-01-31 08:00:06 +00:00
|
|
|
from typing import List, Tuple
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
# Optional dependency detection: the UNet autochunk test needs ``diffusers``
# (and ``packaging`` for the version gate). On any failure the test is disabled
# via HAS_REPO / an empty MODELS list instead of breaking test collection.
try:
    import diffusers
    from packaging import version

    # Model classes fed to ``@parameterize("model", MODELS)``.
    MODELS = [diffusers.UNet2DModel]
    HAS_REPO = True
    # The UNet test is skipped for diffusers releases newer than 0.10.2.
    SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2")
except Exception:
    # ``except Exception`` rather than a bare ``except:`` so KeyboardInterrupt /
    # SystemExit still propagate, while any import/version problem (not
    # installed, broken install, unparsable version) still disables the test.
    MODELS = []
    HAS_REPO = False
    SKIP_UNET_TEST = False
|
2023-01-31 08:00:06 +00:00
|
|
|
|
2023-02-02 07:06:43 +00:00
|
|
|
from test_autochunk_diffuser_utils import run_test
|
2023-01-31 08:00:06 +00:00
|
|
|
|
|
|
|
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
2023-04-06 06:51:35 +00:00
|
|
|
from colossalai.testing import clear_cache_before_run, parameterize, spawn
|
2023-01-31 08:00:06 +00:00
|
|
|
|
2023-02-07 08:32:45 +00:00
|
|
|
# Geometry of the random input used by the UNet autochunk test.
BATCH_SIZE = 1
HEIGHT = 448
WIDTH = 448
IN_CHANNELS = 3
# (batch, channels, HEIGHT // 7, WIDTH // 7) -> (1, 3, 64, 64) with the values above.
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS) + tuple(side // 7 for side in (HEIGHT, WIDTH))
|
|
|
|
|
|
|
|
|
|
|
|
def get_data(shape: tuple) -> Tuple[List, List]:
    """Build the keyword inputs for the UNet forward call.

    Args:
        shape: Shape of the random ``sample`` tensor (drawn with ``torch.randn``).

    Returns:
        A ``(meta_args, concrete_args)`` pair of ``(name, value)`` lists:
        ``meta_args`` is ``[("sample", <random tensor of shape>)]`` and
        ``concrete_args`` is ``[("timestep", 50)]``.
    """
    fixed_inputs = [("timestep", 50)]
    tensor_inputs = [("sample", torch.randn(shape))]
    return tensor_inputs, fixed_inputs
|
|
|
|
|
|
|
|
|
2023-06-15 09:38:42 +00:00
|
|
|
@pytest.mark.skipif(
    SKIP_UNET_TEST,
    reason="diffusers version > 0.10.2",
)
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="autochunk is unavailable (torch version is lower than 1.12.0) or diffusers is not installed",
)
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("shape", [LATENTS_SHAPE])
@parameterize("max_memory", [None, 150, 300])
def test_evoformer_block(model, shape, max_memory):
    """Run the autochunk check (``run_test``) on a diffusers UNet model.

    ``parameterize`` presumably invokes this body once per listed value of
    ``model``, ``shape`` and ``max_memory``; each invocation hands ``run_test``
    to ``spawn`` with ``1`` (presumably a single process — confirm against
    ``colossalai.testing.spawn``).

    NOTE(review): despite its name (apparently copied from the Evoformer
    test), this exercises the models in ``MODELS`` (``diffusers.UNet2DModel``).
    The name is kept so existing ``-k`` selections keep matching.
    """
    spawn(
        run_test,
        1,
        max_memory=max_memory,
        model=model,
        data=get_data(shape),
    )
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The bare name ``UNet2DModel`` is never imported into this module, so the
    # original ``model=UNet2DModel`` raised NameError. Use the class collected
    # in MODELS by the guarded diffusers import instead, and exit with a clear
    # message when diffusers is unavailable (MODELS is then empty).
    if not MODELS:
        raise SystemExit("diffusers is not installed; cannot run the UNet autochunk test")
    run_test(
        rank=0,
        data=get_data(LATENTS_SHAPE),
        max_memory=None,
        model=MODELS[0],
        print_code=False,
        print_mem=True,
        print_est_mem=False,
        print_progress=False,
    )
|