mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
d106b271f8
commit
d5c4f0bf95
|
@ -8,8 +8,6 @@ import torch.multiprocessing as mp
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.fx import ColoTracer
|
|
||||||
from colossalai.fx.graph_module import ColoGraphModule
|
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
from colossalai.fx.profiler import MetaTensor
|
from colossalai.fx.profiler import MetaTensor
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
|
@ -33,11 +31,30 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
assert len(found_regions) == len(target_regions), "len of found regions %s doesn't equal len of target regions %s" % (str(found_regions), str(target_regions))
|
assert len(found_regions) == len(
|
||||||
|
target_regions
|
||||||
|
), "len of found regions %s doesn't equal len of target regions %s" % (
|
||||||
|
str(found_regions),
|
||||||
|
str(target_regions),
|
||||||
|
)
|
||||||
for region in target_regions:
|
for region in target_regions:
|
||||||
assert region in found_regions, "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory)
|
assert (
|
||||||
|
region in found_regions
|
||||||
|
), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
|
||||||
|
str(region),
|
||||||
|
msa_len,
|
||||||
|
pair_len,
|
||||||
|
max_memory,
|
||||||
|
)
|
||||||
for region in found_regions:
|
for region in found_regions:
|
||||||
assert region in target_regions, "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory)
|
assert (
|
||||||
|
region in target_regions
|
||||||
|
), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
|
||||||
|
str(region),
|
||||||
|
msa_len,
|
||||||
|
pair_len,
|
||||||
|
max_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
||||||
|
|
Loading…
Reference in New Issue