ColossalAI/applications/Chat/examples/download_model.py

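"""Download model configs/tokenizers (and optionally full snapshots) from the
Hugging Face Hub, then smoke-test that coati's actor, critic, and reward
models can be instantiated from the downloaded files."""
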
import argparse
import dataclasses
import os
from typing import List

import tqdm
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer


@dataclasses.dataclass
class HFRepoFiles:
    repo_id: str
    files: List[str]

    def download(self, dir_path: str):
        """Fetch only the listed files into `dir_path`."""
        for file in self.files:
            hf_hub_download(self.repo_id, file, local_dir=dir_path)

    def download_all(self, dir_path: str):
        """Mirror the whole repository into `dir_path`."""
        snapshot_download(self.repo_id, local_dir=dir_path)
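
# Illustrative example (not executed; the repo id and path are only for show):
#   HFRepoFiles(repo_id="gpt2", files=["config.json"]).download("test_models/gpt2")
# hf_hub_download fetches the listed files one by one, while snapshot_download
# mirrors the entire repository into the target directory.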


def test_init(model: str, dir_path: str):
    """Smoke test: build the actor, critic, reward model, and tokenizer from
    the downloaded files to verify they are complete and loadable."""
    if model == "gpt2":
        config = GPT2Config.from_pretrained(dir_path)
        actor = GPTActor(config=config)
        critic = GPTCritic(config=config)
        reward_model = GPTRM(config=config)
        tokenizer = GPT2Tokenizer.from_pretrained(dir_path)
    elif model == "bloom":
        config = BloomConfig.from_pretrained(dir_path)
        actor = BLOOMActor(config=config)
        critic = BLOOMCritic(config=config)
        reward_model = BLOOMRM(config=config)
        tokenizer = BloomTokenizerFast.from_pretrained(dir_path)
    elif model == "opt":
        config = AutoConfig.from_pretrained(dir_path)
        actor = OPTActor(config=config)
        critic = OPTCritic(config=config)
        reward_model = OPTRM(config=config)
        tokenizer = AutoTokenizer.from_pretrained(dir_path)
    else:
        raise NotImplementedError(f"Model {model} not implemented")
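

# Illustrative example (not executed): after fetching the gpt2 files,
#   test_init("gpt2", "test_models/gpt2")
# builds the networks from the local config with freshly (randomly)
# initialized weights, so it verifies that config/tokenizer files are present
# and parseable rather than checking any checkpoint.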

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", type=str, default="test_models")
    parser.add_argument("--config-only", default=False, action="store_true")
    args = parser.parse_args()

    if os.path.exists(args.model_dir):
        print(f"[INFO]: {args.model_dir} already exists")
        exit(0)

    repo_list = {
        "gpt2": HFRepoFiles(
            repo_id="gpt2",
            files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"],
        ),
        "bloom": HFRepoFiles(
            repo_id="bigscience/bloom-560m",
            files=["config.json", "tokenizer.json", "tokenizer_config.json"],
        ),
        "opt": HFRepoFiles(
            repo_id="facebook/opt-350m",
            files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"],
        ),
    }
    os.mkdir(args.model_dir)
    for model_name in tqdm.tqdm(repo_list):
        dir_path = os.path.join(args.model_dir, model_name)
        if args.config_only:
            # Fetch only config/tokenizer files; weights are not needed to
            # instantiate the networks from config.
            os.mkdir(dir_path)
            repo_list[model_name].download(dir_path)
        else:
            repo_list[model_name].download_all(dir_path)
        test_init(model_name, dir_path)
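
# Usage sketch (assuming the coati package and its dependencies are installed):
#   python download_model.py --model-dir test_models --config-only
# With --config-only only the listed config/tokenizer files are fetched;
# without it each repository is mirrored in full before the smoke test runs.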