# File: ColossalAI/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py
#
# Web-view metadata (not code): 55 lines, 1.2 KiB, Python; Raw | Normal View | History

from typing import List, Tuple
import pytest
import torch
# Optional dependency: timm provides the ViT model exercised by this test.
try:
    from timm.models.vision_transformer import vit_large_patch16_384 as vit

    MODELS = [vit]
    HAS_REPO = True
except Exception:
    # Best-effort: if timm is missing (or fails to import), parameterize no
    # models and record HAS_REPO = False so the test can be skipped.
    # ``except Exception`` instead of a bare ``except:`` so that
    # KeyboardInterrupt / SystemExit are not swallowed during collection.
    MODELS = []
    HAS_REPO = False
from test_autochunk_vit_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_data() -> Tuple[torch.Tensor, dict]:
    """Build a random input batch and its meta-args dict.

    Returns:
        A ``(data, meta_args)`` pair. ``data`` is a random tensor of shape
        ``(1, 3, 384, 384)``; ``meta_args`` is ``{"x": data}`` and holds the
        very same tensor object. The key ``"x"`` is presumably the model's
        forward argument name (confirm against ``run_test``).
    """
    data = torch.rand(1, 3, 384, 384)
    meta_args = {"x": data}
    return data, meta_args
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0 or timm is not installed",
)
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("max_memory", [None, 32, 40])
def test_evoformer_block(model, max_memory):
    """Run the autochunk check for every model / memory-budget combination.

    NOTE: despite its name (kept so existing test selections keep working),
    this test exercises the ViT model(s) listed in ``MODELS``, not an
    Evoformer block.

    Args:
        model: an entry of ``MODELS``.
        max_memory: memory budget forwarded to ``run_test``; ``None`` and the
            values 32 / 40 are tried. The unit is not visible here — TODO
            confirm in ``run_test``.
    """
    spawn(
        run_test,
        1,  # presumably the number of processes to launch — confirm in spawn()
        max_memory=max_memory,
        model=model,
        data=get_data(),
    )
if __name__ == "__main__":
    # Manual single run (rank 0, no memory budget, all debug printing off).
    # ``vit`` is only bound when the optional timm import succeeded, so fail
    # with a clear message instead of a NameError when it did not.
    if not HAS_REPO:
        raise SystemExit("timm is required to run this script")
    run_test(
        rank=0,
        data=get_data(),
        max_memory=None,
        model=vit,
        print_code=False,
        print_mem=False,
        print_est_mem=False,
        print_progress=False,
    )