@@ -9,6 +9,7 @@ import colossalai
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
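
# Note (not part of the diff): a hedged sketch of how the imports above are
# typically combined inside this test, assuming `import torch` plus the imports
# shown in the hunk. The names `model`, `node`, `pair`, `max_memory` and the
# exact keyword arguments are illustrative assumptions, not taken from the file.
graph = ColoTracer().trace(
    model,
    meta_args={"node": node.to(torch.device("meta")), "pair": pair.to(torch.device("meta"))},
)
gm_meta = torch.fx.GraphModule(model, graph)
MetaInfoProp(gm_meta).propagate(MetaTensor(node), MetaTensor(pair))    # annotate nodes with meta/memory info
codegen = AutoChunkCodeGen(gm_meta, max_memory=max_memory)             # search chunk regions under the memory budget
graph.set_codegen(codegen)                                             # emit chunked forward code instead of the default
gm = ColoGraphModule(model, graph)
gm.recompile()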
@@ -99,6 +100,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
    gpc.destroy()


@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason='torch version is lower than 1.12.0')
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
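
# Note (not part of the diff): a hedged sketch of the parametrized entry point
# that decorators like the ones above usually sit on. The functools.partial /
# mp.spawn wiring and nprocs=1 are assumptions based on the per-rank signature
# _test_autochunk_codegen(rank, msa_len, pair_len, max_memory); the real test may differ.
from functools import partial

import torch.multiprocessing as mp


def test_autochunk_codegen(msa_len, pair_len, max_memory):
    # mp.spawn passes the process rank as the first positional argument; the
    # remaining parameters are bound here as keywords.
    run_func = partial(_test_autochunk_codegen, msa_len=msa_len, pair_len=pair_len, max_memory=max_memory)
    mp.spawn(run_func, nprocs=1)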