mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
5a916c0adb
commit
7a23deb584
|
@ -34,15 +34,23 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title):
|
|||
|
||||
|
||||
def benchmark_evoformer():
|
||||
# data
|
||||
# init data and model
|
||||
msa_len = 300
|
||||
pair_len = 800
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
|
||||
# build gm model
|
||||
max_memory = 3000 # MB
|
||||
model = evoformer_base().cuda()
|
||||
|
||||
# build autochunk model
|
||||
max_memory = 3000 # MB
|
||||
autochunk = _build_autochunk(model, max_memory, node, pair)
|
||||
|
||||
# benchmark
|
||||
_benchmark_evoformer(model, node, pair, "openfold")
|
||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||
|
||||
|
||||
def _build_autochunk(model, max_memory, node, pair):
|
||||
# trace the module and replace codegen
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
|
@ -70,9 +78,7 @@ def benchmark_evoformer():
|
|||
# print
|
||||
code = graph.python_code("self").src
|
||||
print(code)
|
||||
|
||||
_benchmark_evoformer(gm, node, pair, "autochunk")
|
||||
_benchmark_evoformer(model, node, pair, "openfold")
|
||||
return gm
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue