from typing import List, Tuple import pytest import torch try: from timm.models.vision_transformer import vit_large_patch16_384 as vit MODELS = [vit] HAS_REPO = True except: MODELS = [] HAS_REPO = False from test_autochunk_vit_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_data() -> Tuple[List, List]: data = torch.rand(1, 3, 384, 384) meta_args = {"x": data} return data, meta_args @pytest.mark.skipif( not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) @clear_cache_before_run() @parameterize("model", MODELS) @parameterize("max_memory", [None, 32, 40]) def test_evoformer_block(model, max_memory): spawn( run_test, 1, max_memory=max_memory, model=model, data=get_data(), ) if __name__ == "__main__": run_test( rank=0, data=get_data(), max_memory=None, model=vit, print_code=False, print_mem=False, print_est_mem=False, print_progress=False, )