From 7548ca5a54ed117f03247dcb43ec1dd962ae04e0 Mon Sep 17 00:00:00 2001
From: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Date: Mon, 20 Mar 2023 09:59:06 +0800
Subject: [PATCH] [chatgpt] Reward Model Training Process update (#3133)

* add normalize function to value_head in bloom rm
* add normalization to value_function in gpt_rm
* add normalization to value_head of opt_rm
* add Anthropic/hh-rlhf dataset
* Update __init__.py
* Add LogExpLoss in RM training
* Update __init__.py
* update rm trainer to use acc as target
* update example/train_rm
* Update train_rm.sh
* code style
* Update README.md
* Update README.md
* add rm test to ci
* fix tokenizer
* fix typo
* change batchsize to avoid oom in ci
* Update test_ci.sh
---
 .../ChatGPT/chatgpt/dataset/__init__.py       |   4 +-
 .../ChatGPT/chatgpt/dataset/reward_dataset.py |  65 +++++++++-
 .../ChatGPT/chatgpt/models/__init__.py        |   4 +-
 .../ChatGPT/chatgpt/models/bloom/bloom_rm.py  |   1 +
 .../ChatGPT/chatgpt/models/gpt/gpt_rm.py      |   1 +
 applications/ChatGPT/chatgpt/models/loss.py   |  14 ++-
 .../ChatGPT/chatgpt/models/opt/opt_rm.py      |   1 +
 applications/ChatGPT/chatgpt/trainer/rm.py    | 111 +++++++++++-------
 applications/ChatGPT/examples/README.md       |  41 +++++--
 applications/ChatGPT/examples/test_ci.sh      |  20 ++++
 .../ChatGPT/examples/train_reward_model.py    |  93 ++++++++++-----
 applications/ChatGPT/examples/train_rm.sh     |  26 ++--
 12 files changed, 270 insertions(+), 111 deletions(-)

diff --git a/applications/ChatGPT/chatgpt/dataset/__init__.py b/applications/ChatGPT/chatgpt/dataset/__init__.py
index b4599c82b..833930987 100644
--- a/applications/ChatGPT/chatgpt/dataset/__init__.py
+++ b/applications/ChatGPT/chatgpt/dataset/__init__.py
@@ -1,4 +1,4 @@
-from .reward_dataset import RewardDataset
+from .reward_dataset import RmStaticDataset, HhRlhfDataset
 from .utils import is_rank_0
 
-__all__ = ['RewardDataset', 'is_rank_0']
+__all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0']

diff --git a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py
index 8bc850f2d..9ee13490b 100644
--- a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py
+++ b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py
@@ -5,8 +5,8 @@ from tqdm import tqdm
 
 from .utils import is_rank_0
 
-
-class RewardDataset(Dataset):
+# Dahoas/rm-static
+class RmStaticDataset(Dataset):
     """
     Dataset for reward model
 
@@ -14,16 +14,21 @@ class RewardDataset(Dataset):
         dataset: dataset for reward model
         tokenizer: tokenizer for reward model
        max_length: max length of input
+        special_token: special token at the end of the sentence
     """
 
-    def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None:
+    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
         super().__init__()
         self.chosen = []
         self.reject = []
+        if special_token is None:
+            self.end_token = tokenizer.eos_token
+        else:
+            self.end_token = special_token
         for data in tqdm(dataset, disable=not is_rank_0()):
             prompt = data['prompt']
 
-            chosen = prompt + data['chosen'] + "<|endoftext|>"
+            chosen = prompt + data['chosen'] + self.end_token
             chosen_token = tokenizer(chosen,
                                      max_length=max_length,
                                      padding="max_length",
@@ -34,7 +39,7 @@ class RewardDataset(Dataset):
                 "attention_mask": chosen_token['attention_mask']
             })
 
-            reject = prompt + data['rejected'] + "<|endoftext|>"
+            reject = prompt + data['rejected'] + self.end_token
             reject_token = tokenizer(reject,
                                      max_length=max_length,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")
+            self.reject.append({
+                "input_ids": reject_token['input_ids'],
+                "attention_mask": reject_token['attention_mask']
+            })
+
+    def __len__(self):
+        length = len(self.chosen)
+        return length
+
+    def __getitem__(self, idx):
+        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
+            "input_ids"], self.reject[idx]["attention_mask"]
+
+
+# Anthropic/hh-rlhf
+class HhRlhfDataset(Dataset):
+    """
+    Dataset for reward model
+
+    Args:
+        dataset: dataset for reward model
+        tokenizer: tokenizer for reward model
+        max_length: max length of input
+        special_token: special token at the end of the sentence
+    """
+    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
+        super().__init__()
+        self.chosen = []
+        self.reject = []
+        if special_token is None:
+            self.end_token = tokenizer.eos_token
+        else:
+            self.end_token = special_token
+        for data in tqdm(dataset, disable=not is_rank_0()):
+            chosen = data['chosen'] + self.end_token
+            chosen_token = tokenizer(chosen,
+                                     max_length=max_length,
+                                     padding="max_length",
+                                     truncation=True,
+                                     return_tensors="pt")
+            self.chosen.append({
+                "input_ids": chosen_token['input_ids'],
+                "attention_mask": chosen_token['attention_mask']
+            })
+
+            reject = data['rejected'] + self.end_token
             reject_token = tokenizer(reject,
                                      max_length=max_length,
                                      padding="max_length",
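The two dataset classes above expose the same item layout, a (chosen_ids, chosen_mask, reject_ids, reject_mask) tuple; they differ only in that RmStaticDataset prepends the prompt to each response while HhRlhfDataset tokenizes the chosen/rejected conversations as-is. A minimal usage sketch (the load_dataset splits, max_length and the bloom-560m tokenizer are illustrative assumptions, not part of this patch):

```python
# Sketch: building the two reward datasets (illustrative values, not part of this patch).
from datasets import load_dataset
from transformers import AutoTokenizer

from chatgpt.dataset import RmStaticDataset, HhRlhfDataset

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # assumed checkpoint
tokenizer.pad_token = tokenizer.eos_token  # padding="max_length" needs a pad token

# Dahoas/rm-static: records carry 'prompt', 'chosen' and 'rejected' fields.
rm_static = load_dataset("Dahoas/rm-static")
train_data = RmStaticDataset(rm_static["train"], tokenizer, max_length=512)

# Anthropic/hh-rlhf: records carry full 'chosen' and 'rejected' conversations.
hh_rlhf = load_dataset("Anthropic/hh-rlhf")
eval_data = HhRlhfDataset(hh_rlhf["test"], tokenizer, max_length=512)

chosen_ids, c_mask, reject_ids, r_mask = train_data[0]  # each of shape (1, max_length)
```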
diff --git a/applications/ChatGPT/chatgpt/models/__init__.py b/applications/ChatGPT/chatgpt/models/__init__.py
index 376fed8de..b274188a2 100644
--- a/applications/ChatGPT/chatgpt/models/__init__.py
+++ b/applications/ChatGPT/chatgpt/models/__init__.py
@@ -1,4 +1,4 @@
 from .base import Actor, Critic, RewardModel
-from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
+from .loss import PolicyLoss, PPOPtxActorLoss, ValueLoss, LogSigLoss, LogExpLoss
 
-__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss']
+__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']

diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py
index 4dc2646e3..2dba227ff 100644
--- a/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py
+++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py
@@ -33,4 +33,5 @@ class BLOOMRM(RewardModel):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
+        value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
         super().__init__(model, value_head, lora_rank, lora_train_bias)

diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py
index 0132dbf27..19d673de6 100644
--- a/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py
+++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py
@@ -35,4 +35,5 @@ class GPTRM(RewardModel):
             model.gradient_checkpointing_enable()
 
         value_head = nn.Linear(model.config.n_embd, 1)
+        value_head.weight.data.normal_(mean=0.0, std=1/(model.config.n_embd + 1))
         super().__init__(model, value_head, lora_rank, lora_train_bias)

diff --git a/applications/ChatGPT/chatgpt/models/loss.py b/applications/ChatGPT/chatgpt/models/loss.py
index 0ebcfea06..c5b1ccc93 100644
--- a/applications/ChatGPT/chatgpt/models/loss.py
+++ b/applications/ChatGPT/chatgpt/models/loss.py
@@ -93,13 +93,23 @@ class PPOPtxActorLoss(nn.Module):
         return policy_loss + self.pretrain_coef * lm_loss
 
 
-class PairWiseLoss(nn.Module):
+class LogSigLoss(nn.Module):
     """
     Pairwise Loss for Reward Model
+    Details: https://arxiv.org/abs/2203.02155
     """
-
     def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
         probs = torch.sigmoid(chosen_reward - reject_reward)
         log_probs = torch.log(probs)
         loss = -log_probs.mean()
         return loss
+
+
+class LogExpLoss(nn.Module):
+    """
+    Pairwise Loss for Reward Model
+    Details: https://arxiv.org/abs/2204.05862
+    """
+    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
+        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
+        return loss
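LogSigLoss and LogExpLoss are the same pairwise objective written two ways, since log(1 + exp(r_reject - r_chosen)) = -log sigmoid(r_chosen - r_reject); the 'log_sig' / 'log_exp' options mentioned in the README below therefore differ in convention and numerics, not in the quantity being minimised. A standalone check (not part of the patch):

```python
# Numerical check that the two pairwise losses above agree.
import torch
import torch.nn.functional as F

chosen_reward = torch.tensor([1.2, 0.3, -0.5])
reject_reward = torch.tensor([0.4, 0.9, -0.7])

log_sig = -torch.log(torch.sigmoid(chosen_reward - reject_reward)).mean()  # LogSigLoss
log_exp = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()   # LogExpLoss
stable = -F.logsigmoid(chosen_reward - reject_reward).mean()               # stable equivalent

assert torch.allclose(log_sig, log_exp) and torch.allclose(log_sig, stable)
```

Both textbook forms can overflow or underflow when the reward margin becomes very large, so F.logsigmoid is the numerically safer way to evaluate the same quantity.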
diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_rm.py b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py
index 7ad7b3887..ef7f0fb16 100644
--- a/applications/ChatGPT/chatgpt/models/opt/opt_rm.py
+++ b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py
@@ -34,4 +34,5 @@ class OPTRM(RewardModel):
             model.gradient_checkpointing_enable()
 
         value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
+        value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1))
         super().__init__(model, value_head, lora_rank, lora_train_bias)
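BLOOMRM, GPTRM and OPTRM now all initialise the scalar value head the same way: a single nn.Linear from the backbone width (hidden_size, n_embd or word_embed_proj_dim) down to 1, with weights drawn from N(0, 1/(d_model + 1)) so the untrained head emits rewards close to zero. The pattern in isolation (d_model = 1024 is a placeholder for the backbone hidden size):

```python
# Sketch of the value-head initialisation shared by the reward models above.
import torch
import torch.nn as nn

d_model = 1024                                  # placeholder for the backbone hidden size
value_head = nn.Linear(d_model, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (d_model + 1))

hidden = torch.randn(4, d_model)                # a batch of backbone hidden states
reward = value_head(hidden).squeeze(-1)         # shape (4,): one scalar per sample
print(reward)                                   # small values centred around zero
```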
diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py
index c07d65f84..7fa87a649 100644
--- a/applications/ChatGPT/chatgpt/trainer/rm.py
+++ b/applications/ChatGPT/chatgpt/trainer/rm.py
@@ -1,13 +1,12 @@
 from abc import ABC
-
+import pandas as pd
 import loralib as lora
 import torch
-from chatgpt.dataset import RewardDataset
-from chatgpt.models.loss import PairWiseLoss
-from torch.optim import Adam, Optimizer
-from torch.utils.data import DataLoader
+from datetime import datetime
+from torch.optim import Optimizer, lr_scheduler
+from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
-
+
 from .strategies import Strategy
 from .utils import is_rank_0
@@ -20,11 +19,12 @@ class RewardModelTrainer(ABC):
         model (torch.nn.Module): the model to train
         strategy (Strategy): the strategy to use for training
         optim(Optimizer): the optimizer to use for training
-        train_dataset (RewardDataset): the dataset to use for training
-        eval_dataset (RewardDataset): the dataset to use for evaluation
+        loss_fn (callable): the loss function to use for training
+        train_dataset (Dataset): the dataset to use for training
+        valid_dataset (Dataset): the dataset to use for validation
+        eval_dataset (Dataset): the dataset to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
-        optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
     """
 
     def __init__(
         self,
         model,
         strategy: Strategy,
         optim: Optimizer,
-        train_dataset: RewardDataset,
-        eval_dataset: RewardDataset,
+        loss_fn,
+        train_dataset: Dataset,
+        valid_dataset: Dataset,
+        eval_dataset: Dataset,
         batch_size: int = 1,
-        max_epochs: int = 2,
+        max_epochs: int = 1,
     ) -> None:
         super().__init__()
         self.strategy = strategy
         self.epochs = max_epochs
-        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
-        self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
-
+        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+        self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
+        self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
+
+        self.model = strategy.setup_model(model)
-        if "DDP" in str(self.strategy):
-            self.model = self.model.module
-        self.loss_fn = PairWiseLoss()
+        self.loss_fn = loss_fn
         self.optimizer = strategy.setup_optimizer(optim, self.model)
+        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__()//100)
 
-    def fit(self, use_lora):
+
+    def eval_acc(self, dataloader):
+        dist = 0
+        on = 0
+        cnt = 0
+        self.model.eval()
+        with torch.no_grad():
+            for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
+                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+                reject_reward = self.model(reject_ids, attention_mask=r_mask)
+                for i in range(len(chosen_reward)):
+                    cnt += 1
+                    if chosen_reward[i] > reject_reward[i]:
+                        on += 1
+                dist += (chosen_reward - reject_reward).mean().item()
+        dist_mean = dist / len(dataloader)
+        acc = on / cnt
+        self.model.train()
+        return dist_mean, acc
+
+
+    def fit(self):
+        time = datetime.now()
         epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
         for epoch in range(self.epochs):
             step_bar = tqdm(range(self.train_dataloader.__len__()),
@@ -57,37 +85,36 @@ class RewardModelTrainer(ABC):
                             desc='Train step of epoch %d' % epoch,
                             disable=not is_rank_0())
             # train
             self.model.train()
+            cnt = 0
+            acc = 0
+            dist = 0
             for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
-                chosen_ids = chosen_ids.squeeze(1).cuda()
-                c_mask = c_mask.squeeze(1).cuda()
-                reject_ids = reject_ids.squeeze(1).cuda()
-                r_mask = r_mask.squeeze(1).cuda()
+                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                 chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                 reject_reward = self.model(reject_ids, attention_mask=r_mask)
                 loss = self.loss_fn(chosen_reward, reject_reward)
                 self.strategy.backward(loss, self.model, self.optimizer)
                 self.strategy.optimizer_step(self.optimizer)
                 self.optimizer.zero_grad()
+                cnt += 1
+                if cnt == 100:
+                    self.scheduler.step()
+                    dist, acc = self.eval_acc(self.valid_dataloader)
+                    cnt = 0
+                    if is_rank_0():
+                        log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
+                        log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
                 step_bar.update()
-                step_bar.set_postfix({'loss': loss.item()})
-
+                step_bar.set_postfix({'dist': dist, 'acc': acc})
+
             # eval
-            self.model.eval()
-            with torch.no_grad():
-                dist = 0
-                loss_sum = 0
-                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
-                    chosen_ids = chosen_ids.squeeze(1).cuda()
-                    c_mask = c_mask.squeeze(1).cuda()
-                    reject_ids = reject_ids.squeeze(1).cuda()
-                    r_mask = r_mask.squeeze(1).cuda()
-                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
-                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
-                    dist += (chosen_reward - reject_reward).mean().item()
-                    loss = self.loss_fn(chosen_reward, reject_reward)
-                    loss_sum += loss.item()
-                dist_mean = dist / self.eval_dataloader.__len__()
-                loss_mean = loss_sum / self.eval_dataloader.__len__()
+            dist, acc = self.eval_acc(self.eval_dataloader)
+            if is_rank_0():
+                log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
+                log.to_csv('log.csv', mode='a', header=False, index=False)
             epoch_bar.update()
-            step_bar.set_postfix({'loss': loss_mean, 'dist_mean': dist_mean})
+            step_bar.set_postfix({'dist': dist, 'acc': acc})
         step_bar.close()
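With the new constructor the caller chooses the loss function and supplies separate train/valid/eval splits, and fit() no longer takes arguments. A rough sketch of how the updated trainer might be driven; the BLOOMRM, RewardModelTrainer and NaiveStrategy import paths, the bloom-560m checkpoint and the Adam learning rate are assumptions based on the rest of the repository and are not shown in this patch:

```python
# Rough wiring sketch (import paths and hyper-parameters are assumptions).
import torch
from torch.optim import Adam

from chatgpt.models import LogSigLoss                  # exported by models/__init__.py above
from chatgpt.models.bloom import BLOOMRM               # assumed export path
from chatgpt.trainer import RewardModelTrainer         # assumed export path
from chatgpt.trainer.strategies import NaiveStrategy   # assumed export path

strategy = NaiveStrategy()
model = BLOOMRM(pretrained='bigscience/bloom-560m').to(torch.cuda.current_device())
optim = Adam(model.parameters(), lr=5e-6)

trainer = RewardModelTrainer(model=model,
                             strategy=strategy,
                             optim=optim,
                             loss_fn=LogSigLoss(),       # or LogExpLoss()
                             train_dataset=train_data,   # RmStaticDataset / HhRlhfDataset
                             valid_dataset=valid_data,   # instances, e.g. built as in the
                             eval_dataset=eval_data,     # dataset sketch earlier
                             batch_size=4,
                             max_epochs=1)
trainer.fit()   # on rank 0, logs step / loss / dist / acc to a log_*.csv file
```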
diff --git a/applications/ChatGPT/examples/README.md b/applications/ChatGPT/examples/README.md
index 3876d20f0..ce73a5407 100644
--- a/applications/ChatGPT/examples/README.md
+++ b/applications/ChatGPT/examples/README.md
@@ -7,26 +7,42 @@ pip install -r requirements.txt
 ```
 
 ## Train the reward model (Stage 2)
-We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as dataset to train our reward model. It is a dataset of chosen & rejected response of the same prompt.
-
-You can download the dataset from huggingface automatically.
-
 Use these code to train your reward model.
-
 ```shell
-# Naive reward model training
-python train_reward_model.py --pretrain --model --strategy naive
+# Take naive reward model training with opt-350m as an example
+python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive
 
 # use colossalai_zero2
-torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain --model --strategy colossalai_zero2
+torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
 ```
+
+### Features and tricks in RM training
+- We support the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) and [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
+- We support two pairwise loss functions, 'log_sig' (used by OpenAI) and 'log_exp' (used by Anthropic).
+- We monitor valid_acc and pair_dist instead of the raw loss to track training progress.
+- We add a special token to the end of each sequence, which gives better results.
+- We use a cosine-annealing learning-rate scheduler for RM training (see the sketch after this list).
+- We set the value_head to a single linear layer and initialize its weights from an N(0, 1/(d_model + 1)) distribution.
+- We train a Bloom-560m reward model for 1 epoch and find that its test accuracy matches the performance reported in [Anthropic's paper](https://arxiv.org/abs/2112.00861).
+
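The cosine schedule in the list above is created inside the trainer as CosineAnnealingLR(optimizer, len(train_dataloader) // 100) and stepped once every 100 batches, so the learning rate follows half a cosine from its initial value down to roughly zero over one epoch. A self-contained sketch of that behaviour (the Adam optimizer and lr=5e-6 are placeholders):

```python
# Sketch of the cosine learning-rate decay used for RM training (placeholder values).
import torch
from torch.optim import Adam, lr_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optim = Adam([param], lr=5e-6)

num_batches = 10_000                              # len(train_dataloader) for one epoch
scheduler = lr_scheduler.CosineAnnealingLR(optim, num_batches // 100)

for step in range(1, num_batches + 1):
    optim.step()                                  # one training batch
    if step % 100 == 0:                           # the trainer steps the scheduler every 100 batches
        scheduler.step()

print(optim.param_groups[0]['lr'])                # close to zero by the end of the epoch
```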
+### Experiment result
+Model performance in [Anthropic's paper](https://arxiv.org/abs/2112.00861):
+
+[figure: reward model performance reported in the Anthropic paper]
+
+Our training & test results of bloom-560m for 1 epoch:
+
+[figure: training & test results of our bloom-560m reward model]
+
 ## Train with dummy prompt data (Stage 3)
-This script supports 3 strategies:
+This script supports 4 strategies:
 - naive
 - ddp
-- colossalai
+- colossalai_zero2
+- colossalai_gemini
 
 It uses random generated prompt data.
 
@@ -53,7 +69,7 @@ We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts)
 
 You should download `prompts.csv` first.
 
-This script also supports 3 strategies.
+This script also supports 4 strategies.
 
 ```shell
 # display cli help
 python inference.py --model_path --model