From 2e16f842a9e5b1fb54e7e41070e9d2bb5cd64d7c Mon Sep 17 00:00:00 2001
From: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Date: Wed, 22 Feb 2023 16:58:11 +0800
Subject: [PATCH] [chatgpt]support opt & gpt for rm training (#2876)

---
 applications/ChatGPT/chatgpt/nn/bloom_rm.py   |  1 -
 applications/ChatGPT/chatgpt/nn/gpt_rm.py     |  9 +++-
 applications/ChatGPT/chatgpt/nn/opt_rm.py     | 10 +++--
 .../ChatGPT/examples/train_reward_model.py    | 41 ++++++++++++++-----
 applications/ChatGPT/examples/train_rm.sh     |  4 +-
 5 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/applications/ChatGPT/chatgpt/nn/bloom_rm.py b/applications/ChatGPT/chatgpt/nn/bloom_rm.py
index 0d4dd43fa..12c37957d 100644
--- a/applications/ChatGPT/chatgpt/nn/bloom_rm.py
+++ b/applications/ChatGPT/chatgpt/nn/bloom_rm.py
@@ -1,6 +1,5 @@
 from typing import Optional
 
-import torch
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
 
diff --git a/applications/ChatGPT/chatgpt/nn/gpt_rm.py b/applications/ChatGPT/chatgpt/nn/gpt_rm.py
index c6c41a45a..fcfb61cd4 100644
--- a/applications/ChatGPT/chatgpt/nn/gpt_rm.py
+++ b/applications/ChatGPT/chatgpt/nn/gpt_rm.py
@@ -15,12 +15,16 @@ class GPTRM(RewardModel):
         pretrained (str): Pretrained model name or path.
         config (GPT2Config): Model config.
         checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): Rank of the low-rank approximation.
+        lora_train_bias (str): LoRA bias training mode.
     """
 
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
             model = GPT2Model.from_pretrained(pretrained)
         elif config is not None:
@@ -29,5 +33,6 @@ class GPTRM(RewardModel):
             model = GPT2Model(GPT2Config())
         if checkpoint:
             model.gradient_checkpointing_enable()
+
         value_head = nn.Linear(model.config.n_embd, 1)
-        super().__init__(model, value_head)
+        super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/nn/opt_rm.py b/applications/ChatGPT/chatgpt/nn/opt_rm.py
index 150f832e0..5f518a3cc 100644
--- a/applications/ChatGPT/chatgpt/nn/opt_rm.py
+++ b/applications/ChatGPT/chatgpt/nn/opt_rm.py
@@ -1,8 +1,7 @@
 from typing import Optional
 
 import torch.nn as nn
-from transformers.models.opt.configuration_opt import OPTConfig
-from transformers.models.opt.modeling_opt import OPTModel
+from transformers import OPTConfig, OPTModel
 
 from .reward_model import RewardModel
 
@@ -14,6 +13,7 @@ class OPTRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (OPTConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): Rank of the low-rank approximation.
         lora_train_bias (str): LoRA bias training mode.
""" @@ -21,6 +21,7 @@ class OPTRM(RewardModel): def __init__(self, pretrained: Optional[str] = None, config: Optional[OPTConfig] = None, + checkpoint: bool = False, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: if pretrained is not None: @@ -29,5 +30,8 @@ class OPTRM(RewardModel): model = OPTModel(config) else: model = OPTModel(OPTConfig()) - value_head = nn.Linear(model.config.hidden_size, 1) + if checkpoint: + model.gradient_checkpointing_enable() + + value_head = nn.Linear(model.config.word_embed_proj_dim, 1) super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py index 57d47b695..bf2071793 100644 --- a/applications/ChatGPT/examples/train_reward_model.py +++ b/applications/ChatGPT/examples/train_reward_model.py @@ -3,12 +3,13 @@ import argparse import loralib as lora import torch from chatgpt.dataset import RewardDataset -from chatgpt.nn import BLOOMRM +from chatgpt.nn import BLOOMRM, GPTRM, OPTRM from chatgpt.trainer import RewardModelTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from datasets import load_dataset from torch.optim import Adam -from transformers import BloomTokenizerFast +from transformers import AutoTokenizer, BloomTokenizerFast +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from colossalai.nn.optimizer import HybridAdam @@ -27,11 +28,30 @@ def train(args): raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model - tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) - tokenizer.pad_token = tokenizer.eos_token with strategy.model_init_context(): - model = BLOOMRM(pretrained=args.pretrain).cuda() - max_len = 1024 + if args.model == 'bloom': + model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + elif args.model == 'opt': + model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + elif args.model == 'gpt2': + model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + else: + raise ValueError(f'Unsupported model "{args.model}"') + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + else: + raise ValueError(f'Unsupported model "{args.model}"') + tokenizer.pad_token = tokenizer.eos_token + + max_len = 512 # configure optimizer if args.strategy.startswith('colossalai'): @@ -58,10 +78,10 @@ def train(args): trainer.fit(use_lora=args.lora_rank) - if args.lora_rank > 0: - torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path) - else: - torch.save(trainer.model, args.save_path) + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True) + # save optimizer checkpoint on all ranks + strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False) if __name__ == '__main__': @@ -69,6 +89,7 @@ if __name__ == '__main__': parser.add_argument('--strategy', choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom') 
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
     parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
diff --git a/applications/ChatGPT/examples/train_rm.sh b/applications/ChatGPT/examples/train_rm.sh
index ed91deee2..6e11a148b 100755
--- a/applications/ChatGPT/examples/train_rm.sh
+++ b/applications/ChatGPT/examples/train_rm.sh
@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
 
 set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
-torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2
+# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
+torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2
+# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
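
Note: the hunks above change the public constructor signatures of GPTRM and OPTRM and the tokenizer selection in the example script. For reference, below is a minimal usage sketch of the patched reward-model API; it is not part of the commit, and the forward signature (token ids plus attention mask in, one scalar reward per sequence out) is assumed from the RewardModel base class rather than shown in this diff.

# Hypothetical usage sketch of the patched reward-model API (not part of the commit).
import torch
from transformers import AutoTokenizer

from chatgpt.nn import GPTRM  # OPTRM / BLOOMRM follow the same constructor shape

# lora_rank and lora_train_bias are the constructor arguments added by this patch.
model = GPTRM(pretrained='gpt2', lora_rank=4, lora_train_bias='none').cuda()

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

batch = tokenizer(['prompt text + candidate response'],
                  padding=True, truncation=True, max_length=512,
                  return_tensors='pt').to('cuda')

# Assumption: RewardModel.forward accepts (input_ids, attention_mask) and
# returns one scalar reward per sequence from the nn.Linear value head.
with torch.no_grad():
    reward = model(batch['input_ids'], attention_mask=batch['attention_mask'])
print(reward.shape)  # expected: torch.Size([1])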