ColossalAI/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py

64 lines
1.4 KiB
Python
Raw Normal View History

from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
# diffusers is an optional dependency: when it cannot be imported the model
# list is left empty and HAS_REPO is False, which the skipif marker on the
# test below uses to skip instead of failing at collection time.
try:
    from diffusers import UNet2DModel

    MODELS = [UNet2DModel]
    HAS_REPO = True
except Exception:
    # Deliberately broader than ImportError so that a broken install (which
    # may surface as another exception type) still degrades to "skip", but
    # no longer a bare ``except:`` -- that would also swallow
    # KeyboardInterrupt and SystemExit.
    MODELS = []
    HAS_REPO = False
from test_autochunk_diffuser_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
# Dimensions used to build the sample tensor fed to the model under test.
BATCH_SIZE = 1
IN_CHANNELS = 3
HEIGHT = 448
WIDTH = 448

# Shape of the "sample" input, laid out as (batch, channels, height, width);
# both spatial dimensions are reduced by an integer factor of 7 (448 -> 64).
LATENTS_SHAPE = (
    BATCH_SIZE,
    IN_CHANNELS,
    HEIGHT // 7,
    WIDTH // 7,
)
def get_data(shape: tuple) -> Tuple[List, List]:
    """Build the argument lists handed to ``run_test`` for one model run.

    Args:
        shape: Shape of the random ``sample`` tensor to generate.

    Returns:
        A ``(meta_args, concrete_args)`` pair. ``meta_args`` is a list with a
        single ``("sample", tensor)`` entry whose tensor is drawn from
        ``torch.randn(shape)``; ``concrete_args`` is a list with a single
        ``("timestep", 50)`` entry.
    """
    concrete_args = [("timestep", 50)]
    meta_args = [("sample", torch.randn(shape))]
    return meta_args, concrete_args
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    # The condition covers two independent prerequisites, so the reason must
    # mention both; previously it only blamed the torch version even when
    # diffusers was the missing piece.
    reason="requires torch >= 1.12.0 and an importable diffusers package",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [None, 150, 300])
def test_evoformer_block(model, shape, max_memory):
    """Run the autochunk ``run_test`` helper on a diffusers model in a subprocess.

    NOTE(review): the name mentions "evoformer" although the parametrized
    model is ``UNet2DModel``; it is kept unchanged so existing test IDs and
    selections keep working.

    Args:
        model: Model class taken from ``MODELS``.
        shape: Shape of the random sample tensor passed to ``get_data``.
        max_memory: Memory budget forwarded to ``run_test``; ``None`` is one
            of the parametrized values (units are defined by ``run_test`` --
            not visible from this file).
    """
    # The input data is built once in the parent process and bound into the
    # callable; mp.spawn supplies the process index as the first positional
    # argument of ``run_test``.
    run_func = partial(
        run_test,
        max_memory=max_memory,
        model=model,
        data=get_data(shape),
    )
    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
    # Manual entry point: run a single configuration in-process (rank 0, no
    # memory limit) and print the memory report only.
    if not HAS_REPO:
        # Without this guard a failed diffusers import surfaces as a
        # confusing NameError on UNet2DModel below.
        raise SystemExit("diffusers could not be imported; it is required to run this script")
    run_test(
        rank=0,
        data=get_data(LATENTS_SHAPE),
        max_memory=None,
        model=UNet2DModel,
        print_code=False,
        print_mem=True,
        print_est_mem=False,
        print_progress=False,
    )