[chatgpt]add reward model code for deberta (#3199)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
Yuanchen 2023-03-22 19:09:39 +08:00 committed by GitHub
parent 1e1b9d2fea
commit 9998d5ef64
7 changed files with 93 additions and 4 deletions


@@ -0,0 +1,4 @@
from .deberta_critic import DebertaCritic
from .deberta_rm import DebertaRM
__all__ = ['DebertaCritic', 'DebertaRM']


@@ -0,0 +1,36 @@
from typing import Optional
import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model
from ..base import Critic
class DebertaCritic(Critic):
    """
    Deberta Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (DebertaV2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA decomposition.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias)
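A minimal usage sketch of the new critic (not part of the diff; the small config values below are made up for illustration): it builds DebertaCritic from a fresh DebertaV2Config instead of a pretrained checkpoint, and the forward pass is inherited from the shared Critic base class.

# Illustrative only, not from this commit: construct a tiny critic from scratch.
# The config hyperparameters are invented small values for a quick smoke test.
from transformers import DebertaV2Config
from chatgpt.models.deberta import DebertaCritic

tiny_config = DebertaV2Config(hidden_size=128,
                              num_hidden_layers=2,
                              num_attention_heads=2,
                              intermediate_size=256)
critic = DebertaCritic(config=tiny_config, lora_rank=4)  # forward() comes from the base Critic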


@@ -0,0 +1,37 @@
from typing import Optional
import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model
from ..base import RewardModel
class DebertaRM(RewardModel):
    """
    Deberta Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (DebertaV2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA decomposition.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
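A hedged scoring sketch for the new reward model (not part of the diff): it assumes, as with the other RM wrappers in this repo, that the base RewardModel's forward takes token ids plus an attention mask and returns one scalar reward per sequence.

# Illustrative only: score a single dialogue with DebertaRM.
# Assumption: RewardModel.forward(sequences, attention_mask) -> reward tensor of shape (batch,).
import torch
from transformers import DebertaV2Tokenizer
from chatgpt.models.deberta import DebertaRM

tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')  # needs sentencepiece
reward_model = DebertaRM(pretrained='microsoft/deberta-v3-large').eval()

text = "Human: How do I bake bread? Assistant: Mix flour, water, salt and yeast, then let it rise."
inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
with torch.no_grad():
    reward = reward_model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
print(reward.item())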


@@ -1 +1,2 @@
pandas>=1.4.1
sentencepiece


@@ -88,4 +88,10 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
    --test True --lora_rank 4
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
    --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
    --strategy colossalai_zero2 --loss_fn 'log_sig' \
    --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
    --test True --lora_rank 4
rm -rf ${BASE}/rm_ckpt.pt
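The new CI command above trains the DeBERTa reward model with the 'log_sig' pairwise loss. As a rough sketch of what such a ranking loss computes (illustrative; the repo's own loss implementation may differ in detail), it pushes the reward of the chosen response above the reward of the rejected one:

# Illustrative pairwise ranking loss in the spirit of the 'log_sig' option
# (assumption: written for this note, not copied from the repo's loss module).
import torch
import torch.nn.functional as F

def log_sig_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # -log(sigmoid(r_chosen - r_reject)), averaged over the batch
    return -F.logsigmoid(chosen_reward - reject_reward).mean()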


@@ -8,12 +8,13 @@ from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMRM
from chatgpt.models.gpt import GPTRM
from chatgpt.models.opt import OPTRM
from chatgpt.models.deberta import DebertaRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from random import randint
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
@@ -39,6 +40,8 @@ def train(args):
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'gpt2':
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'deberta':
model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -54,6 +57,8 @@ def train(args):
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif args.model == 'deberta':
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
else:
raise ValueError(f'Unsupported model "{args.model}"')
max_len = args.max_len
@@ -119,7 +124,7 @@ if __name__ == '__main__':
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False)


@@ -1,7 +1,7 @@
set_n_least_used_CUDA_VISIBLE_DEVICES 1
python train_reward_model.py --pretrain '/home/lczht/data2/bloom-560m' \
    --model 'bloom' \
python train_reward_model.py --pretrain 'microsoft/deberta-v3-large' \
    --model 'deberta' \
    --strategy naive \
    --loss_fn 'log_exp' \
    --save_path 'rmstatic.pt' \