add test in import

pull/2364/head
oahzxl 2 years ago
parent fd818cf144
commit c1492e5013

@@ -6,7 +6,6 @@ import torch.fx
 import torch.multiprocessing as mp

 import colossalai
-from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
@@ -16,6 +15,9 @@ from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base

+if CODEGEN_AVAILABLE:
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+

 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     # for memory test
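
Both files get the same fix: the top-level AutoChunkCodeGen import is moved behind the CODEGEN_AVAILABLE flag, so the test module can still be imported when the activation-checkpoint codegen backend is unavailable (e.g. on older torch versions). A minimal sketch of the resulting import block, assuming colossalai is installed:

    # Guarded import: only pull in AutoChunkCodeGen when codegen support exists.
    from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

    if CODEGEN_AVAILABLE:
        from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen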

@@ -6,7 +6,6 @@ import torch.fx
 import torch.multiprocessing as mp

 import colossalai
-from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
 from colossalai.core import global_context as gpc
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
@@ -14,6 +13,9 @@ from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base

+if CODEGEN_AVAILABLE:
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+

 def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
     found_regions = [i["region"] for i in chunk_infos]
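
With the import guarded this way, tests that actually exercise AutoChunkCodeGen also need to be skipped when the flag is false. A hedged sketch of that guard; the decorator reason, test name, and body are illustrative only, not taken from this diff:

    import pytest

    from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

    # Hypothetical test guard: skip when the codegen backend is not available.
    @pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="codegen backend not available")
    def test_autochunk_codegen():
        # The real tests trace the evoformer model and run AutoChunkCodeGen on
        # the traced graph; body omitted here.
        ...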
