mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.2 KiB
54 lines
1.2 KiB
from typing import List, Tuple
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
# timm is an optional dependency: when it cannot be imported, MODELS is left
# empty and HAS_REPO is False so the test below is skipped instead of erroring.
try:
    from timm.models.vision_transformer import vit_large_patch16_384 as vit

    MODELS = [vit]
    HAS_REPO = True
except Exception:
    # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate,
    # while any import-time failure of timm (missing or incompatible) falls back.
    MODELS = []
    HAS_REPO = False
|
|
|
|
from test_autochunk_vit_utils import run_test
|
|
|
|
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
|
from colossalai.testing import clear_cache_before_run, parameterize, spawn
|
|
|
|
|
|
def get_data() -> Tuple[torch.Tensor, dict]:
    """Build a random input batch and its meta-args mapping.

    Returns:
        A pair ``(data, meta_args)``:

        - ``data`` is a ``torch.rand`` tensor of shape ``(1, 3, 384, 384)``.
          This is one 3-channel 384x384 image, sized to match
          ``vit_large_patch16_384``.
        - ``meta_args`` is ``{'x': data}``. It presumably maps the model's
          forward argument name to the tensor (TODO: confirm against
          ``run_test``).
    """
    data = torch.rand(1, 3, 384, 384)
    meta_args = {'x': data}
    return data, meta_args
|
|
|
|
|
|
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    # Both conditions gate the skip, so the message names both requirements.
    reason="requires autochunk support (torch >= 1.12.0) and timm to be installed",
)
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("max_memory", [None, 32, 40])
def test_evoformer_block(model, max_memory):
    """Run the autochunk ViT check in a single spawned process.

    The decorators call this once per combination of ``model`` (from MODELS)
    and ``max_memory`` (``None``, 32 or 40).

    NOTE(review): despite the "evoformer" name, this exercises the timm ViT
    model(s) in MODELS. The name is kept so existing test selectors still
    match.
    """
    spawn(
        run_test,
        1,  # number of processes to spawn
        max_memory=max_memory,
        model=model,
        data=get_data(),
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the file directly bypasses the pytest skip marker. Check the same
    # preconditions here so a missing timm yields a clear message rather than
    # a NameError on `vit`.
    if not (AUTOCHUNK_AVAILABLE and HAS_REPO):
        raise SystemExit("autochunk support (torch >= 1.12.0) and timm are required to run this script")
    run_test(
        rank=0,
        data=get_data(),
        max_memory=None,
        model=vit,
        print_code=False,
        print_mem=False,
        print_est_mem=False,
        print_progress=False,
    )
|