You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/applications/Chat/examples/download_model.py

80 lines
2.7 KiB

import argparse
import dataclasses
import os
# NOTE(review): the stdlib `parser` module was removed in Python 3.10 and nothing
# in this file appears to use it; this import fails on Python 3.10+.
import parser
from typing import List, Optional

import tqdm
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer
@dataclasses.dataclass
class HFRepoFiles:
    """A Hugging Face Hub repository plus the list of files to fetch from it."""

    # Hub repository id handed to huggingface_hub, e.g. "gpt2".
    repo_id: str
    # Repo-relative file names fetched one by one by ``download``.
    files: List[str]

    def download(self, dir_path: str):
        """Download each file listed in ``files`` into ``dir_path``."""
        for file in self.files:
            # The returned local path is not needed; the file is placed under dir_path.
            hf_hub_download(self.repo_id, file, local_dir=dir_path)

    def download_all(self, dir_path: Optional[str] = None):
        """Download a full snapshot of the repository.

        Args:
            dir_path: Directory to place the snapshot in. ``None`` (the default,
                matching the previous behaviour) passes no local directory, so
                ``snapshot_download`` uses its own default location. Callers that
                later read the files from a specific directory should pass it here.
        """
        snapshot_download(self.repo_id, local_dir=dir_path)
def test_init(model: str, dir_path: str):
if model == "gpt2":
config = GPT2Config.from_pretrained(dir_path)
actor = GPTActor(config=config)
critic = GPTCritic(config=config)
reward_model = GPTRM(config=config)
GPT2Tokenizer.from_pretrained(dir_path)
elif model == "bloom":
config = BloomConfig.from_pretrained(dir_path)
actor = BLOOMActor(config=config)
critic = BLOOMCritic(config=config)
reward_model = BLOOMRM(config=config)
BloomTokenizerFast.from_pretrained(dir_path)
elif model == "opt":
config = AutoConfig.from_pretrained(dir_path)
actor = OPTActor(config=config)
critic = OPTCritic(config=config)
reward_model = OPTRM(config=config)
AutoTokenizer.from_pretrained(dir_path)
else:
raise NotImplementedError(f"Model {model} not implemented")
if __name__ == "__main__":
    # Download the gpt2 / bloom / opt repos listed below into --model-dir (only
    # the listed config/tokenizer files with --config-only, otherwise a full
    # snapshot), then check that each one loads via test_init().
    import sys

    # Named arg_parser so it does not rebind the module-level `parser` import.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model-dir", type=str, default="test_models")
    arg_parser.add_argument("--config-only", default=False, action="store_true")
    args = arg_parser.parse_args()

    # An existing target directory is treated as "already downloaded".
    if os.path.exists(args.model_dir):
        print(f"[INFO]: {args.model_dir} already exists")
        sys.exit(0)

    repo_list = {
        "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
        "bloom": HFRepoFiles(
            repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
        ),
        "opt": HFRepoFiles(
            repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
        ),
    }

    # makedirs so a nested --model-dir path (e.g. "cache/test_models") also works.
    os.makedirs(args.model_dir)
    for model_name, repo in tqdm.tqdm(repo_list.items()):
        dir_path = os.path.join(args.model_dir, model_name)
        if args.config_only:
            os.mkdir(dir_path)
            repo.download(dir_path)
        else:
            # Materialise the snapshot in dir_path (not only the HF cache) so that
            # test_init below finds the config/tokenizer files where it looks.
            snapshot_download(repo.repo_id, local_dir=dir_path)
        test_init(model_name, dir_path)