[chatgpt] support OPT & GPT for RM training (#2876)

pull/2880/head
BlueRum 2023-02-22 16:58:11 +08:00 committed by GitHub
parent c52edcf0eb
commit 2e16f842a9
5 changed files with 48 additions and 17 deletions


@@ -1,6 +1,5 @@
from typing import Optional
import torch
import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, BloomModel


@@ -15,12 +15,16 @@ class GPTRM(RewardModel):
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False) -> None:
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
@@ -29,5 +33,6 @@ class GPTRM(RewardModel):
model = GPT2Model(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head)
super().__init__(model, value_head, lora_rank, lora_train_bias)
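For reference, once these hunks are applied the GPTRM constructor reads roughly as sketched below. This is assembled from the diff; the imports and the RewardModel base class come from the unchanged parts of the file.

class GPTRM(RewardModel):
    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Build the GPT-2 backbone from a pretrained name/path, a config, or defaults.
        if pretrained is not None:
            model = GPT2Model.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2Model(config)
        else:
            model = GPT2Model(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        # Scalar value head on top of the transformer hidden size (n_embd for GPT-2).
        value_head = nn.Linear(model.config.n_embd, 1)
        # The new LoRA settings are forwarded to the RewardModel base class.
        super().__init__(model, value_head, lora_rank, lora_train_bias)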


@@ -1,8 +1,7 @@
from typing import Optional
import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel
from transformers import OPTConfig, OPTModel
from .reward_model import RewardModel
@@ -14,6 +13,7 @@ class OPTRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,6 +21,7 @@ class OPTRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
@@ -29,5 +30,8 @@ class OPTRM(RewardModel):
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
value_head = nn.Linear(model.config.hidden_size, 1)
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias)
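The OPT value head is sized by word_embed_proj_dim rather than hidden_size because OPTModel projects its hidden states back to the embedding dimension before returning them (for facebook/opt-350m that is 512, while hidden_size is 1024). A minimal usage sketch, assuming the opt-350m checkpoint used by the training script below; the LoRA rank is just an example value:

# Build an OPT reward model with LoRA adapters of rank 4 (illustrative).
model = OPTRM(pretrained='facebook/opt-350m', lora_rank=4).cuda()
# OPTModel's outputs have word_embed_proj_dim features (512 for opt-350m),
# which is why the value head above uses that dimension, not hidden_size.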


@@ -3,12 +3,13 @@ import argparse
import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM
from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import BloomTokenizerFast
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
@@ -27,11 +28,30 @@ def train(args):
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
with strategy.model_init_context():
model = BLOOMRM(pretrained=args.pretrain).cuda()
max_len = 1024
if args.model == 'bloom':
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'opt':
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
elif args.model == 'gpt2':
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
tokenizer.pad_token = tokenizer.eos_token
max_len = 512
# configure optimizer
if args.strategy.startswith('colossalai'):
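Since each branch above repeats the same build-then-pair pattern, the model/tokenizer selection can also be read as a lookup table. The sketch below is purely illustrative (a hypothetical refactor, not part of this commit); the helper dicts rm_classes and tokenizer_names are invented names, while the classes, tokenizer identifiers, and strategy call all come from the diff.

# Hypothetical table-driven equivalent of the branching added above.
rm_classes = {'bloom': BLOOMRM, 'opt': OPTRM, 'gpt2': GPTRM}
tokenizer_names = {'bloom': args.pretrain, 'opt': 'facebook/opt-350m', 'gpt2': 'gpt2'}
if args.model not in rm_classes:
    raise ValueError(f'Unsupported model "{args.model}"')
with strategy.model_init_context():
    model = rm_classes[args.model](pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
tokenizer = AutoTokenizer.from_pretrained(tokenizer_names[args.model])
tokenizer.pad_token = tokenizer.eos_token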
@@ -58,10 +78,10 @@
trainer.fit(use_lora=args.lora_rank)
if args.lora_rank > 0:
torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
else:
torch.save(trainer.model, args.save_path)
# save model checkpoint after fitting on only rank0
strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True)
# save optimizer checkpoint on all ranks
strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
if __name__ == '__main__':
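Note that when args.lora_rank > 0, only the LoRA parameters are written to args.save_path (via loralib's lora_state_dict), alongside the full checkpoint saved by strategy.save_model. A reloading sketch under that assumption; the pretrained name is illustrative, and strict=False is the standard loralib pattern for loading a LoRA-only state dict:

# Rebuild the same architecture, then overlay the saved LoRA weights.
model = GPTRM(pretrained='gpt2', lora_rank=args.lora_rank)  # 'gpt2' is an example; match the training run
state = torch.load(args.save_path)
model.load_state_dict(state['model_state_dict'], strict=False)  # strict=False: only LoRA params are present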
@@ -69,6 +89,7 @@ if __name__ == '__main__':
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')


@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2