update codegen test

pull/2364/head
oahzxl 2023-01-09 14:53:04 +08:00
parent 74b81395a2
commit 3abbaf8bc6
1 changed file with 16 additions and 7 deletions


@@ -1,3 +1,5 @@
+from functools import partial
+
 import pytest
 import torch
 import torch.fx
@@ -46,7 +48,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     )


-def _test_autochunk_codegen(rank):
+def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
     colossalai.launch(
         config={},
@@ -59,8 +61,6 @@ def _test_autochunk_codegen(rank):
     # build model and input
     model = evoformer_base().cuda()
-    msa_len = 32
-    pair_len = 64
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
@@ -85,7 +85,7 @@ def _test_autochunk_codegen(rank):
         MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
     )
-    codegen = AutoChunkCodeGen(gm_prop)
+    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()
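The gist of this hunk: the new max_memory parameter now flows from pytest into AutoChunkCodeGen instead of the codegen always running with its default budget. A sketch of the wiring, using only the calls visible in this diff (the profiling step that produces gm_prop is omitted; the comment on None is an assumption, not confirmed by this commit):

codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)  # max_memory=None presumably means no explicit cap
graph.set_codegen(codegen)          # have the FX graph emit AutoChunk's generated forward code
gm = ColoGraphModule(model, graph)  # rebuild the module around the chunk-annotated graph
gm.recompile()                      # regenerate forward() so the new codegen takes effect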
@@ -99,9 +99,18 @@ def _test_autochunk_codegen(rank):
     gpc.destroy()


-def test_autochunk_codegen():
-    mp.spawn(_test_autochunk_codegen, nprocs=1)
+@pytest.mark.parametrize("max_memory", [None, 20, 24, 28, 32])
+@pytest.mark.parametrize("msa_len", [32])
+@pytest.mark.parametrize("pair_len", [64])
+def test_autochunk_codegen(msa_len, pair_len, max_memory):
+    run_func = partial(
+        _test_autochunk_codegen,
+        msa_len=msa_len,
+        pair_len=pair_len,
+        max_memory=max_memory,
+    )
+    mp.spawn(run_func, nprocs=1)


 if __name__ == "__main__":
-    _test_autochunk_codegen(0)
+    _test_autochunk_codegen(0, 32, 64, None)
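A note on the spawn pattern the new test uses: torch.multiprocessing.spawn calls its target as fn(rank, ...), passing only the rank it assigns, so the pytest parameters have to be pre-bound, which is what functools.partial does above. A standalone sketch of that pattern (the worker and the values here are illustrative, not part of this commit):

from functools import partial

import torch.multiprocessing as mp


def _worker(rank, msa_len, pair_len, max_memory):
    # mp.spawn supplies only `rank`; everything else was bound via partial.
    print(f"rank={rank}, msa_len={msa_len}, pair_len={pair_len}, max_memory={max_memory}")


if __name__ == "__main__":
    run_func = partial(_worker, msa_len=32, pair_len=64, max_memory=None)
    mp.spawn(run_func, nprocs=1)  # fn must be picklable; a partial of a top-level function is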