pull/2364/head
oahzxl 2 years ago
parent d106b271f8
commit d5c4f0bf95

@ -8,8 +8,6 @@ import torch.multiprocessing as mp
import colossalai
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from colossalai.utils import free_port
@ -32,12 +30,31 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
raise NotImplementedError()
else:
raise NotImplementedError()
assert len(found_regions) == len(target_regions), "len of found regions %s doesn't equal len of target regions %s" % (str(found_regions), str(target_regions))
assert len(found_regions) == len(
target_regions
), "len of found regions %s doesn't equal len of target regions %s" % (
str(found_regions),
str(target_regions),
)
for region in target_regions:
assert region in found_regions, "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory)
assert (
region in found_regions
), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
str(region),
msa_len,
pair_len,
max_memory,
)
for region in found_regions:
assert region in target_regions, "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory)
assert (
region in target_regions
), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
str(region),
msa_len,
pair_len,
max_memory,
)
def _test_autochunk_search(rank, msa_len, pair_len, max_memory):

Loading…
Cancel
Save