|
|
|
@ -6,7 +6,6 @@ import torch.fx
|
|
|
|
|
import torch.multiprocessing as mp
|
|
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
|
|
|
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
|
|
|
@ -14,6 +13,9 @@ from colossalai.fx.profiler import MetaTensor
|
|
|
|
|
from colossalai.utils import free_port
|
|
|
|
|
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
|
|
|
|
|
|
|
|
|
if CODEGEN_AVAILABLE:
|
|
|
|
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
|
|
|
|
found_regions = [i["region"] for i in chunk_infos]
|
|
|
|
|