From aafc3516a5c07347f58bbc1a52410f74e51b685f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 15:32:19 +0800 Subject: [PATCH] add available --- tests/test_autochunk/test_autochunk_codegen.py | 2 ++ tests/test_autochunk/test_autochunk_search.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index 62763a6d5..c4f5cda67 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -9,6 +9,7 @@ import colossalai from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer +from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor @@ -99,6 +100,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): gpc.destroy() +@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="torch version is lower than 1.12.0") @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_autochunk_search.py index 6f7214633..5026c3ad3 100644 --- a/tests/test_autochunk/test_autochunk_search.py +++ b/tests/test_autochunk/test_autochunk_search.py @@ -8,6 +8,7 @@ import torch.multiprocessing as mp import colossalai from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc +from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor from colossalai.utils import free_port @@ -86,6 +87,7 @@ def 
_test_autochunk_search(rank, msa_len, pair_len, max_memory): gpc.destroy() +@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="torch version is lower than 1.12.0") @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64])