mirror of https://github.com/hpcaitech/ColossalAI

add test in import

parent fd818cf144
commit c1492e5013
@@ -6,7 +6,6 @@ import torch.fx
 import torch.multiprocessing as mp
 
 import colossalai
-from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
@@ -16,6 +15,9 @@ from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base
 
+if CODEGEN_AVAILABLE:
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+
 
 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     # for memory test
@@ -6,7 +6,6 @@ import torch.fx
 import torch.multiprocessing as mp
 
 import colossalai
-from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
 from colossalai.core import global_context as gpc
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
@@ -14,6 +13,9 @@ from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base
 
+if CODEGEN_AVAILABLE:
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+
 
 def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
     found_regions = [i["region"] for i in chunk_infos]
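
Both changed files make the same adjustment: the top-level import of AutoChunkCodeGen is removed and re-added under the existing CODEGEN_AVAILABLE guard, so the test modules can still be imported and collected on builds where the codegen backend is unavailable. Below is a minimal sketch of that pattern, assuming pytest is used to skip the guarded tests; the test function name and skip reason are illustrative, not taken from the repository.

import pytest

from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

# Import the codegen-dependent symbol only when the backend is available,
# so importing this module never fails at collection time.
if CODEGEN_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen


@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="codegen is not available on this torch build")
def test_guarded_autochunk_import():
    # AutoChunkCodeGen is only referenced inside tests that run when the guard holds.
    assert AutoChunkCodeGen is not None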