mirror of https://github.com/hpcaitech/ColossalAI
parent
79079a9d0c
commit
4f5ef73a43
|
@ -2,3 +2,6 @@
|
||||||
path = inference
|
path = inference
|
||||||
url = https://github.com/hpcaitech/EnergonAI.git
|
url = https://github.com/hpcaitech/EnergonAI.git
|
||||||
branch = main
|
branch = main
|
||||||
|
[submodule "examples/tutorial/fastfold/FastFold"]
|
||||||
|
path = examples/tutorial/fastfold/FastFold
|
||||||
|
url = https://github.com/hpcaitech/FastFold
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 19ce840650fd865bd3684684dac051ec3a7bc762
|
|
@ -2,23 +2,21 @@
|
||||||
|
|
||||||
## Table of contents
|
## Table of contents
|
||||||
|
|
||||||
- [Overview](#๐-overview)
|
- [FastFold Inference](#fastfold-inference)
|
||||||
- [Quick Start](#๐-quick-start)
|
- [Table of contents](#table-of-contents)
|
||||||
- [Dive into FastFold](#๐-dive-into-fastfold)
|
- [๐ Overview](#-overview)
|
||||||
|
- [๐ Quick Start](#-quick-start)
|
||||||
|
- [๐ Dive into FastFold](#-dive-into-fastfold)
|
||||||
|
|
||||||
## ๐ Overview
|
## ๐ Overview
|
||||||
|
|
||||||
This example lets you to quickly try out the inference of FastFold.
|
This example lets you to try out the inference of FastFold.
|
||||||
|
|
||||||
**NOTE: We use random data and random parameters in this example.**
|
|
||||||
|
|
||||||
|
|
||||||
## ๐ Quick Start
|
## ๐ Quick Start
|
||||||
|
|
||||||
1. Install FastFold
|
1. Install FastFold
|
||||||
|
|
||||||
We highly recommend installing an Anaconda or Miniconda environment and install PyTorch with conda.
|
We highly recommend you to install FastFold with conda.
|
||||||
|
|
||||||
```
|
```
|
||||||
git clone https://github.com/hpcaitech/FastFold
|
git clone https://github.com/hpcaitech/FastFold
|
||||||
cd FastFold
|
cd FastFold
|
||||||
|
@ -27,15 +25,19 @@ conda activate fastfold
|
||||||
python setup.py install
|
python setup.py install
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run the inference scripts.
|
2. Download datasets.
|
||||||
|
|
||||||
```bash
|
It may take ~900GB space to keep datasets.
|
||||||
python inference.py --gpus=1 --n_res=256 --chunk_size=None --inplace
|
|
||||||
```
|
```
|
||||||
+ `gpus` means the DAP size
|
./scripts/download_all_data.sh data/
|
||||||
+ `n_res` means the length of residue sequence
|
```
|
||||||
+ `chunk_size` introduces a memory-saving technology at the cost of speed, None means not using, 16 may be a good trade off for long sequences.
|
|
||||||
+ `inplace` introduces another memory-saving technology with zero cost, drop `--inplace` if you do not want it.
|
3. Run the inference scripts.
|
||||||
|
|
||||||
|
```
|
||||||
|
bash inference.sh
|
||||||
|
```
|
||||||
|
You can find predictions under the `outputs` dir.
|
||||||
|
|
||||||
## ๐ Dive into FastFold
|
## ๐ Dive into FastFold
|
||||||
|
|
||||||
|
|
|
@ -1,153 +0,0 @@
|
||||||
# Copyright 2023 HPC-AI Tech Inc.
|
|
||||||
# Copyright 2021 AlQuraishi Laboratory
|
|
||||||
# Copyright 2021 DeepMind Technologies Limited
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
import fastfold
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from fastfold.config import model_config
|
|
||||||
from fastfold.data import data_transforms
|
|
||||||
from fastfold.model.fastnn import set_chunk_size
|
|
||||||
from fastfold.model.hub import AlphaFold
|
|
||||||
from fastfold.utils.inject_fastnn import inject_fastnn
|
|
||||||
from fastfold.utils.tensor_utils import tensor_tree_map
|
|
||||||
|
|
||||||
if int(torch.__version__.split(".")[0]) >= 1 and int(torch.__version__.split(".")[1]) > 11:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
|
|
||||||
|
|
||||||
def random_template_feats(n_templ, n):
|
|
||||||
b = []
|
|
||||||
batch = {
|
|
||||||
"template_mask": np.random.randint(0, 2, (*b, n_templ)),
|
|
||||||
"template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
|
|
||||||
"template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
|
|
||||||
"template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
|
|
||||||
"template_all_atom_mask": np.random.randint(0, 2, (*b, n_templ, n, 37)),
|
|
||||||
"template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3) * 10,
|
|
||||||
"template_torsion_angles_sin_cos": np.random.rand(*b, n_templ, n, 7, 2),
|
|
||||||
"template_alt_torsion_angles_sin_cos": np.random.rand(*b, n_templ, n, 7, 2),
|
|
||||||
"template_torsion_angles_mask": np.random.rand(*b, n_templ, n, 7),
|
|
||||||
}
|
|
||||||
batch = {k: v.astype(np.float32) for k, v in batch.items()}
|
|
||||||
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def random_extra_msa_feats(n_extra, n):
|
|
||||||
b = []
|
|
||||||
batch = {
|
|
||||||
"extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64),
|
|
||||||
"extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
|
|
||||||
"extra_deletion_value": np.random.rand(*b, n_extra, n).astype(np.float32),
|
|
||||||
"extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
|
|
||||||
}
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def generate_batch(n_res):
|
|
||||||
batch = {}
|
|
||||||
tf = torch.randint(21, size=(n_res,))
|
|
||||||
batch["target_feat"] = torch.nn.functional.one_hot(tf, 22).float()
|
|
||||||
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
|
|
||||||
batch["residue_index"] = torch.arange(n_res)
|
|
||||||
batch["msa_feat"] = torch.rand((128, n_res, 49))
|
|
||||||
t_feats = random_template_feats(4, n_res)
|
|
||||||
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
|
|
||||||
extra_feats = random_extra_msa_feats(5120, n_res)
|
|
||||||
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
|
|
||||||
batch["msa_mask"] = torch.randint(low=0, high=2, size=(128, n_res)).float()
|
|
||||||
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
|
|
||||||
batch.update(data_transforms.make_atom14_masks(batch))
|
|
||||||
batch["no_recycling_iters"] = torch.tensor(2.)
|
|
||||||
|
|
||||||
add_recycling_dims = lambda t: (t.unsqueeze(-1).expand(*t.shape, 3))
|
|
||||||
batch = tensor_tree_map(add_recycling_dims, batch)
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def inference_model(rank, world_size, result_q, batch, args):
|
|
||||||
os.environ['RANK'] = str(rank)
|
|
||||||
os.environ['LOCAL_RANK'] = str(rank)
|
|
||||||
os.environ['WORLD_SIZE'] = str(world_size)
|
|
||||||
# init distributed for Dynamic Axial Parallelism
|
|
||||||
fastfold.distributed.init_dap()
|
|
||||||
torch.cuda.set_device(rank)
|
|
||||||
config = model_config(args.model_name)
|
|
||||||
if args.chunk_size:
|
|
||||||
config.globals.chunk_size = args.chunk_size
|
|
||||||
|
|
||||||
config.globals.inplace = args.inplace
|
|
||||||
config.globals.is_multimer = False
|
|
||||||
model = AlphaFold(config)
|
|
||||||
|
|
||||||
model = inject_fastnn(model)
|
|
||||||
model = model.eval()
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
set_chunk_size(model.globals.chunk_size)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
|
|
||||||
t = time.perf_counter()
|
|
||||||
out = model(batch)
|
|
||||||
print(f"Inference time: {time.perf_counter() - t}")
|
|
||||||
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
|
|
||||||
|
|
||||||
result_q.put(out)
|
|
||||||
|
|
||||||
torch.distributed.barrier()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
|
|
||||||
def inference_monomer_model(args):
|
|
||||||
batch = generate_batch(args.n_res)
|
|
||||||
manager = mp.Manager()
|
|
||||||
result_q = manager.Queue()
|
|
||||||
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
|
|
||||||
out = result_q.get()
|
|
||||||
|
|
||||||
# get unrelexed pdb and save
|
|
||||||
# batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
|
|
||||||
# plddt = out["plddt"]
|
|
||||||
# plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
|
|
||||||
# unrelaxed_protein = protein.from_prediction(features=batch,
|
|
||||||
# result=out,
|
|
||||||
# b_factors=plddt_b_factors)
|
|
||||||
# with open('demo_unrelex.pdb', 'w+') as fp:
|
|
||||||
# fp.write(unrelaxed_protein)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
inference_monomer_model(args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--gpus", type=int, default=1, help="""Number of GPUs with which to run inference""")
|
|
||||||
parser.add_argument("--n_res", type=int, default=50, help="virtual residue number of random data")
|
|
||||||
parser.add_argument("--model_name", type=str, default="model_1", help="model name of alphafold")
|
|
||||||
parser.add_argument('--chunk_size', type=int, default=None)
|
|
||||||
parser.add_argument('--inplace', default=False, action='store_true')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
main(args)
|
|
|
@ -1,10 +0,0 @@
|
||||||
set -euxo pipefail
|
|
||||||
|
|
||||||
git clone https://github.com/hpcaitech/FastFold
|
|
||||||
cd FastFold
|
|
||||||
pip install -r requirements/requirements.txt
|
|
||||||
python setup.py install
|
|
||||||
pip install -r requirements/test_requirements.txt
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
python inference.py
|
|
Loadingโฆ
Reference in New Issue