last version of benchmark

pull/2364/head
oahzxl 2023-01-05 17:54:25 +08:00
parent 55cb713f36
commit 71e72c4890
1 changed files with 6 additions and 4 deletions

View File

@ -93,22 +93,24 @@ def _build_openfold():
def benchmark_evoformer(): def benchmark_evoformer():
# init data and model # init data and model
msa_len = 300 msa_len = 256
pair_len = 800 pair_len = 2048
node = torch.randn(1, msa_len, pair_len, 256).cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = evoformer_base().cuda() model = evoformer_base().cuda()
# build autochunk model # build autochunk model
max_memory = 3000 # MB max_memory = 10000 # MB fit memory mode
# max_memory = None # min memory mode
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
# build openfold # build openfold
chunk_size = 64
openfold = _build_openfold() openfold = _build_openfold()
# benchmark # benchmark
_benchmark_evoformer(model, node, pair, "base") _benchmark_evoformer(model, node, pair, "base")
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=4) _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer(autochunk, node, pair, "autochunk") _benchmark_evoformer(autochunk, node, pair, "autochunk")