diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_autochunk_search.py index c824a43ab..6f7214633 100644 --- a/tests/test_autochunk/test_autochunk_search.py +++ b/tests/test_autochunk/test_autochunk_search.py @@ -8,8 +8,6 @@ import torch.multiprocessing as mp import colossalai from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc -from colossalai.fx import ColoTracer -from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor from colossalai.utils import free_port @@ -32,12 +30,31 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): raise NotImplementedError() else: raise NotImplementedError() - - assert len(found_regions) == len(target_regions), "len of found regions %s doesn't equal len of target regions %s" % (str(found_regions), str(target_regions)) + + assert len(found_regions) == len( + target_regions + ), "len of found regions %s doesn't equal len of target regions %s" % ( + str(found_regions), + str(target_regions), + ) for region in target_regions: - assert region in found_regions, "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory) + assert ( + region in found_regions + ), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % ( + str(region), + msa_len, + pair_len, + max_memory, + ) for region in found_regions: - assert region in target_regions, "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory) + assert ( + region in target_regions + ), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % ( + str(region), + msa_len, + pair_len, + max_memory, + ) def _test_autochunk_search(rank, msa_len, pair_len, max_memory):