From c1492e5013709e49093e497c3b7a6ec4bb10b9d4 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 10 Jan 2023 11:20:28 +0800 Subject: [PATCH] add test in import --- tests/test_autochunk/test_autochunk_codegen.py | 4 +++- tests/test_autochunk/test_autochunk_search.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index 53f62077c..28999706b 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -6,7 +6,6 @@ import torch.fx import torch.multiprocessing as mp import colossalai -from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE @@ -16,6 +15,9 @@ from colossalai.fx.profiler import MetaTensor from colossalai.utils import free_port from tests.test_autochunk.evoformer.evoformer import evoformer_base +if CODEGEN_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): # for memory test diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_autochunk_search.py index 5026c3ad3..eb2bf4560 100644 --- a/tests/test_autochunk/test_autochunk_search.py +++ b/tests/test_autochunk/test_autochunk_search.py @@ -6,7 +6,6 @@ import torch.fx import torch.multiprocessing as mp import colossalai -from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -14,6 +13,9 @@ from colossalai.fx.profiler import MetaTensor from colossalai.utils import free_port from tests.test_autochunk.evoformer.evoformer import evoformer_base +if CODEGEN_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): found_regions = [i["region"] for i in chunk_infos]