mirror of https://github.com/hpcaitech/ColossalAI
add chunk search test
parent
a005965d2d
commit
d106b271f8
|
@ -0,0 +1,86 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.fx
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
|
from colossalai.fx.profiler import MetaTensor
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||||
|
|
||||||
|
|
||||||
|
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
||||||
|
found_regions = [i["region"] for i in chunk_infos]
|
||||||
|
|
||||||
|
if msa_len == 32 and pair_len == 64:
|
||||||
|
if max_memory is None:
|
||||||
|
target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191), (161, 166), (198, 203), (6, 69)]
|
||||||
|
elif max_memory == 20:
|
||||||
|
target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)]
|
||||||
|
elif max_memory == 25:
|
||||||
|
target_regions = [(144, 154), (369, 370)]
|
||||||
|
elif max_memory == 30:
|
||||||
|
target_regions = [(144, 154)]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
assert len(found_regions) == len(target_regions), "len of found regions %s doesn't equal len of target regions %s" % (str(found_regions), str(target_regions))
|
||||||
|
for region in target_regions:
|
||||||
|
assert region in found_regions, "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory)
|
||||||
|
for region in found_regions:
|
||||||
|
assert region in target_regions, "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
||||||
|
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
|
||||||
|
colossalai.launch(
|
||||||
|
config={},
|
||||||
|
rank=rank,
|
||||||
|
world_size=1,
|
||||||
|
host="localhost",
|
||||||
|
port=free_port(),
|
||||||
|
backend="nccl",
|
||||||
|
)
|
||||||
|
|
||||||
|
# build model and input
|
||||||
|
model = evoformer_base().cuda()
|
||||||
|
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||||
|
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||||
|
|
||||||
|
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||||
|
interp = MetaInfoProp(gm_prop)
|
||||||
|
interp.propagate(
|
||||||
|
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
|
||||||
|
)
|
||||||
|
|
||||||
|
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
|
||||||
|
chunk_infos = codegen.chunk_infos
|
||||||
|
assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)
|
||||||
|
|
||||||
|
gpc.destroy()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||||
|
@pytest.mark.parametrize("msa_len", [32])
|
||||||
|
@pytest.mark.parametrize("pair_len", [64])
|
||||||
|
def test_autochunk_search(msa_len, pair_len, max_memory):
|
||||||
|
run_func = partial(
|
||||||
|
_test_autochunk_search,
|
||||||
|
msa_len=msa_len,
|
||||||
|
pair_len=pair_len,
|
||||||
|
max_memory=max_memory,
|
||||||
|
)
|
||||||
|
mp.spawn(run_func, nprocs=1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
_test_autochunk_search(0, 32, 64, 20)
|
Loading…
Reference in New Issue