mirror of https://github.com/hpcaitech/ColossalAI
oahzxl
2 years ago
committed by
GitHub
5 changed files with 22 additions and 179 deletions
@ -0,0 +1 @@
|
||||
Subproject commit 19ce840650fd865bd3684684dac051ec3a7bc762 |
@ -1,153 +0,0 @@
|
||||
# Copyright 2023 HPC-AI Tech Inc. |
||||
# Copyright 2021 AlQuraishi Laboratory |
||||
# Copyright 2021 DeepMind Technologies Limited |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import argparse |
||||
import os |
||||
import time |
||||
|
||||
import fastfold |
||||
import numpy as np |
||||
import torch |
||||
import torch.multiprocessing as mp |
||||
from fastfold.config import model_config |
||||
from fastfold.data import data_transforms |
||||
from fastfold.model.fastnn import set_chunk_size |
||||
from fastfold.model.hub import AlphaFold |
||||
from fastfold.utils.inject_fastnn import inject_fastnn |
||||
from fastfold.utils.tensor_utils import tensor_tree_map |
||||
|
||||
if int(torch.__version__.split(".")[0]) >= 1 and int(torch.__version__.split(".")[1]) > 11: |
||||
torch.backends.cuda.matmul.allow_tf32 = True |
||||
|
||||
|
||||
def random_template_feats(n_templ, n): |
||||
b = [] |
||||
batch = { |
||||
"template_mask": np.random.randint(0, 2, (*b, n_templ)), |
||||
"template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)), |
||||
"template_pseudo_beta": np.random.rand(*b, n_templ, n, 3), |
||||
"template_aatype": np.random.randint(0, 22, (*b, n_templ, n)), |
||||
"template_all_atom_mask": np.random.randint(0, 2, (*b, n_templ, n, 37)), |
||||
"template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3) * 10, |
||||
"template_torsion_angles_sin_cos": np.random.rand(*b, n_templ, n, 7, 2), |
||||
"template_alt_torsion_angles_sin_cos": np.random.rand(*b, n_templ, n, 7, 2), |
||||
"template_torsion_angles_mask": np.random.rand(*b, n_templ, n, 7), |
||||
} |
||||
batch = {k: v.astype(np.float32) for k, v in batch.items()} |
||||
batch["template_aatype"] = batch["template_aatype"].astype(np.int64) |
||||
return batch |
||||
|
||||
|
||||
def random_extra_msa_feats(n_extra, n): |
||||
b = [] |
||||
batch = { |
||||
"extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64), |
||||
"extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32), |
||||
"extra_deletion_value": np.random.rand(*b, n_extra, n).astype(np.float32), |
||||
"extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32), |
||||
} |
||||
return batch |
||||
|
||||
|
||||
def generate_batch(n_res): |
||||
batch = {} |
||||
tf = torch.randint(21, size=(n_res,)) |
||||
batch["target_feat"] = torch.nn.functional.one_hot(tf, 22).float() |
||||
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1) |
||||
batch["residue_index"] = torch.arange(n_res) |
||||
batch["msa_feat"] = torch.rand((128, n_res, 49)) |
||||
t_feats = random_template_feats(4, n_res) |
||||
batch.update({k: torch.tensor(v) for k, v in t_feats.items()}) |
||||
extra_feats = random_extra_msa_feats(5120, n_res) |
||||
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()}) |
||||
batch["msa_mask"] = torch.randint(low=0, high=2, size=(128, n_res)).float() |
||||
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float() |
||||
batch.update(data_transforms.make_atom14_masks(batch)) |
||||
batch["no_recycling_iters"] = torch.tensor(2.) |
||||
|
||||
add_recycling_dims = lambda t: (t.unsqueeze(-1).expand(*t.shape, 3)) |
||||
batch = tensor_tree_map(add_recycling_dims, batch) |
||||
|
||||
return batch |
||||
|
||||
|
||||
def inference_model(rank, world_size, result_q, batch, args): |
||||
os.environ['RANK'] = str(rank) |
||||
os.environ['LOCAL_RANK'] = str(rank) |
||||
os.environ['WORLD_SIZE'] = str(world_size) |
||||
# init distributed for Dynamic Axial Parallelism |
||||
fastfold.distributed.init_dap() |
||||
torch.cuda.set_device(rank) |
||||
config = model_config(args.model_name) |
||||
if args.chunk_size: |
||||
config.globals.chunk_size = args.chunk_size |
||||
|
||||
config.globals.inplace = args.inplace |
||||
config.globals.is_multimer = False |
||||
model = AlphaFold(config) |
||||
|
||||
model = inject_fastnn(model) |
||||
model = model.eval() |
||||
model = model.cuda() |
||||
|
||||
set_chunk_size(model.globals.chunk_size) |
||||
|
||||
with torch.no_grad(): |
||||
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} |
||||
t = time.perf_counter() |
||||
out = model(batch) |
||||
print(f"Inference time: {time.perf_counter() - t}") |
||||
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) |
||||
|
||||
result_q.put(out) |
||||
|
||||
torch.distributed.barrier() |
||||
torch.cuda.synchronize() |
||||
|
||||
|
||||
def inference_monomer_model(args): |
||||
batch = generate_batch(args.n_res) |
||||
manager = mp.Manager() |
||||
result_q = manager.Queue() |
||||
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args)) |
||||
out = result_q.get() |
||||
|
||||
# get unrelexed pdb and save |
||||
# batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) |
||||
# plddt = out["plddt"] |
||||
# plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1) |
||||
# unrelaxed_protein = protein.from_prediction(features=batch, |
||||
# result=out, |
||||
# b_factors=plddt_b_factors) |
||||
# with open('demo_unrelex.pdb', 'w+') as fp: |
||||
# fp.write(unrelaxed_protein) |
||||
|
||||
|
||||
def main(args): |
||||
inference_monomer_model(args) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument("--gpus", type=int, default=1, help="""Number of GPUs with which to run inference""") |
||||
parser.add_argument("--n_res", type=int, default=50, help="virtual residue number of random data") |
||||
parser.add_argument("--model_name", type=str, default="model_1", help="model name of alphafold") |
||||
parser.add_argument('--chunk_size', type=int, default=None) |
||||
parser.add_argument('--inplace', default=False, action='store_true') |
||||
|
||||
args = parser.parse_args() |
||||
|
||||
main(args) |
Loading…
Reference in new issue