mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt]Reward Model Training Process update (#3133)
* add normalize function to value_head in bloom rm * add normalization to value_function in gpt_rm * add normalization to value_head of opt_rm * add Anthropic/hh-rlhf dataset * Update __init__.py * Add LogExpLoss in RM training * Update __init__.py * update rm trainer to use acc as target * update example/train_rm * Update train_rm.sh * code style * Update README.md * Update README.md * add rm test to ci * fix tokenier * fix typo * change batchsize to avoid oom in ci * Update test_ci.shpull/3159/head
parent
1e58d31bb7
commit
7548ca5a54
|
@ -1,4 +1,4 @@
|
||||||
from .reward_dataset import RewardDataset
|
from .reward_dataset import RmStaticDataset, HhRlhfDataset
|
||||||
from .utils import is_rank_0
|
from .utils import is_rank_0
|
||||||
|
|
||||||
__all__ = ['RewardDataset', 'is_rank_0']
|
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0']
|
||||||
|
|
|
@ -5,8 +5,8 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from .utils import is_rank_0
|
from .utils import is_rank_0
|
||||||
|
|
||||||
|
# Dahaos/rm-static
|
||||||
class RewardDataset(Dataset):
|
class RmStaticDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset for reward model
|
Dataset for reward model
|
||||||
|
|
||||||
|
@ -14,16 +14,21 @@ class RewardDataset(Dataset):
|
||||||
dataset: dataset for reward model
|
dataset: dataset for reward model
|
||||||
tokenizer: tokenizer for reward model
|
tokenizer: tokenizer for reward model
|
||||||
max_length: max length of input
|
max_length: max length of input
|
||||||
|
special_token: special token at the end of sentence
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None:
|
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.chosen = []
|
self.chosen = []
|
||||||
self.reject = []
|
self.reject = []
|
||||||
|
if special_token is None:
|
||||||
|
self.end_token = tokenizer.eos_token
|
||||||
|
else:
|
||||||
|
self.end_token = special_token
|
||||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||||
prompt = data['prompt']
|
prompt = data['prompt']
|
||||||
|
|
||||||
chosen = prompt + data['chosen'] + "<|endoftext|>"
|
chosen = prompt + data['chosen'] + self.end_token
|
||||||
chosen_token = tokenizer(chosen,
|
chosen_token = tokenizer(chosen,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
|
@ -34,7 +39,57 @@ class RewardDataset(Dataset):
|
||||||
"attention_mask": chosen_token['attention_mask']
|
"attention_mask": chosen_token['attention_mask']
|
||||||
})
|
})
|
||||||
|
|
||||||
reject = prompt + data['rejected'] + "<|endoftext|>"
|
reject = prompt + data['rejected'] + self.end_token
|
||||||
|
reject_token = tokenizer(reject,
|
||||||
|
max_length=max_length,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt")
|
||||||
|
self.reject.append({
|
||||||
|
"input_ids": reject_token['input_ids'],
|
||||||
|
"attention_mask": reject_token['attention_mask']
|
||||||
|
})
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
length = len(self.chosen)
|
||||||
|
return length
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
||||||
|
"input_ids"], self.reject[idx]["attention_mask"]
|
||||||
|
|
||||||
|
# Anthropic/hh-rlhf
|
||||||
|
class HhRlhfDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset for reward model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: dataset for reward model
|
||||||
|
tokenizer: tokenizer for reward model
|
||||||
|
max_length: max length of input
|
||||||
|
special_token: special token at the end of sentence
|
||||||
|
"""
|
||||||
|
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.chosen = []
|
||||||
|
self.reject = []
|
||||||
|
if special_token is None:
|
||||||
|
self.end_token = tokenizer.eos_token
|
||||||
|
else:
|
||||||
|
self.end_token = special_token
|
||||||
|
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||||
|
chosen = data['chosen'] + self.end_token
|
||||||
|
chosen_token = tokenizer(chosen,
|
||||||
|
max_length=max_length,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt")
|
||||||
|
self.chosen.append({
|
||||||
|
"input_ids": chosen_token['input_ids'],
|
||||||
|
"attention_mask": chosen_token['attention_mask']
|
||||||
|
})
|
||||||
|
|
||||||
|
reject = data['rejected'] + self.end_token
|
||||||
reject_token = tokenizer(reject,
|
reject_token = tokenizer(reject,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .base import Actor, Critic, RewardModel
|
from .base import Actor, Critic, RewardModel
|
||||||
from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
|
from .loss import PolicyLoss, PPOPtxActorLoss, ValueLoss, LogSigLoss, LogExpLoss
|
||||||
|
|
||||||
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss']
|
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
|
||||||
|
|
|
@ -33,4 +33,5 @@ class BLOOMRM(RewardModel):
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||||
|
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||||
|
|
|
@ -35,4 +35,5 @@ class GPTRM(RewardModel):
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
value_head = nn.Linear(model.config.n_embd, 1)
|
value_head = nn.Linear(model.config.n_embd, 1)
|
||||||
|
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.n_embd + 1))
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||||
|
|
|
@ -93,13 +93,23 @@ class PPOPtxActorLoss(nn.Module):
|
||||||
return policy_loss + self.pretrain_coef * lm_loss
|
return policy_loss + self.pretrain_coef * lm_loss
|
||||||
|
|
||||||
|
|
||||||
class PairWiseLoss(nn.Module):
|
class LogSigLoss(nn.Module):
|
||||||
"""
|
"""
|
||||||
Pairwise Loss for Reward Model
|
Pairwise Loss for Reward Model
|
||||||
|
Details: https://arxiv.org/abs/2203.02155
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
||||||
probs = torch.sigmoid(chosen_reward - reject_reward)
|
probs = torch.sigmoid(chosen_reward - reject_reward)
|
||||||
log_probs = torch.log(probs)
|
log_probs = torch.log(probs)
|
||||||
loss = -log_probs.mean()
|
loss = -log_probs.mean()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class LogExpLoss(nn.Module):
|
||||||
|
"""
|
||||||
|
Pairwise Loss for Reward Model
|
||||||
|
Details: https://arxiv.org/abs/2204.05862
|
||||||
|
"""
|
||||||
|
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
||||||
|
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
|
||||||
|
return loss
|
||||||
|
|
|
@ -34,4 +34,5 @@ class OPTRM(RewardModel):
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||||
|
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1))
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
import pandas as pd
|
||||||
import loralib as lora
|
import loralib as lora
|
||||||
import torch
|
import torch
|
||||||
from chatgpt.dataset import RewardDataset
|
from datetime import datetime
|
||||||
from chatgpt.models.loss import PairWiseLoss
|
from torch.optim import Optimizer, lr_scheduler
|
||||||
from torch.optim import Adam, Optimizer
|
from torch.utils.data import DataLoader, Dataset
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .strategies import Strategy
|
from .strategies import Strategy
|
||||||
from .utils import is_rank_0
|
from .utils import is_rank_0
|
||||||
|
|
||||||
|
@ -20,11 +19,12 @@ class RewardModelTrainer(ABC):
|
||||||
model (torch.nn.Module): the model to train
|
model (torch.nn.Module): the model to train
|
||||||
strategy (Strategy): the strategy to use for training
|
strategy (Strategy): the strategy to use for training
|
||||||
optim(Optimizer): the optimizer to use for training
|
optim(Optimizer): the optimizer to use for training
|
||||||
train_dataset (RewardDataset): the dataset to use for training
|
loss_fn (callable): the loss function to use for training
|
||||||
eval_dataset (RewardDataset): the dataset to use for evaluation
|
train_dataset (Dataset): the dataset to use for training
|
||||||
|
valid_dataset (Dataset): the dataset to use for validation
|
||||||
|
eval_dataset (Dataset): the dataset to use for evaluation
|
||||||
batch_size (int, defaults to 1): the batch size while training
|
batch_size (int, defaults to 1): the batch size while training
|
||||||
max_epochs (int, defaults to 2): the number of epochs to train
|
max_epochs (int, defaults to 2): the number of epochs to train
|
||||||
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -32,24 +32,52 @@ class RewardModelTrainer(ABC):
|
||||||
model,
|
model,
|
||||||
strategy: Strategy,
|
strategy: Strategy,
|
||||||
optim: Optimizer,
|
optim: Optimizer,
|
||||||
train_dataset: RewardDataset,
|
loss_fn,
|
||||||
eval_dataset: RewardDataset,
|
train_dataset: Dataset,
|
||||||
|
valid_dataset: Dataset,
|
||||||
|
eval_dataset: Dataset,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
max_epochs: int = 2,
|
max_epochs: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.strategy = strategy
|
self.strategy = strategy
|
||||||
self.epochs = max_epochs
|
self.epochs = max_epochs
|
||||||
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
|
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||||
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
|
self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
|
||||||
self.model = strategy.setup_model(model)
|
self.model = strategy.setup_model(model)
|
||||||
if "DDP" in str(self.strategy):
|
self.loss_fn = loss_fn
|
||||||
self.model = self.model.module
|
|
||||||
self.loss_fn = PairWiseLoss()
|
|
||||||
self.optimizer = strategy.setup_optimizer(optim, self.model)
|
self.optimizer = strategy.setup_optimizer(optim, self.model)
|
||||||
|
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__()//100)
|
||||||
|
|
||||||
def fit(self, use_lora):
|
|
||||||
|
def eval_acc(self, dataloader):
|
||||||
|
dist = 0
|
||||||
|
on = 0
|
||||||
|
cnt = 0
|
||||||
|
self.model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
|
||||||
|
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
||||||
|
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
|
||||||
|
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
|
||||||
|
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
||||||
|
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
||||||
|
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
||||||
|
for i in range(len(chosen_reward)):
|
||||||
|
cnt += 1
|
||||||
|
if chosen_reward[i] > reject_reward[i]:
|
||||||
|
on += 1
|
||||||
|
dist += (chosen_reward - reject_reward).mean().item()
|
||||||
|
dist_mean = dist / len(dataloader)
|
||||||
|
acc = on / cnt
|
||||||
|
self.model.train()
|
||||||
|
return dist_mean, acc
|
||||||
|
|
||||||
|
|
||||||
|
def fit(self):
|
||||||
|
time = datetime.now()
|
||||||
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
|
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
|
||||||
for epoch in range(self.epochs):
|
for epoch in range(self.epochs):
|
||||||
step_bar = tqdm(range(self.train_dataloader.__len__()),
|
step_bar = tqdm(range(self.train_dataloader.__len__()),
|
||||||
|
@ -57,37 +85,36 @@ class RewardModelTrainer(ABC):
|
||||||
disable=not is_rank_0())
|
disable=not is_rank_0())
|
||||||
# train
|
# train
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
cnt = 0
|
||||||
|
acc = 0
|
||||||
|
dist = 0
|
||||||
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
|
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
|
||||||
chosen_ids = chosen_ids.squeeze(1).cuda()
|
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
||||||
c_mask = c_mask.squeeze(1).cuda()
|
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
|
||||||
reject_ids = reject_ids.squeeze(1).cuda()
|
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
|
||||||
r_mask = r_mask.squeeze(1).cuda()
|
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
||||||
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
||||||
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
||||||
loss = self.loss_fn(chosen_reward, reject_reward)
|
loss = self.loss_fn(chosen_reward, reject_reward)
|
||||||
self.strategy.backward(loss, self.model, self.optimizer)
|
self.strategy.backward(loss, self.model, self.optimizer)
|
||||||
self.strategy.optimizer_step(self.optimizer)
|
self.strategy.optimizer_step(self.optimizer)
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
cnt += 1
|
||||||
|
if cnt == 100:
|
||||||
|
self.scheduler.step()
|
||||||
|
dist, acc = self.eval_acc(self.valid_dataloader)
|
||||||
|
cnt = 0
|
||||||
|
if is_rank_0():
|
||||||
|
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
|
||||||
|
log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
|
||||||
step_bar.update()
|
step_bar.update()
|
||||||
step_bar.set_postfix({'loss': loss.item()})
|
step_bar.set_postfix({'dist': dist, 'acc': acc})
|
||||||
|
|
||||||
# eval
|
# eval
|
||||||
self.model.eval()
|
dist, acc = self.eval_acc(self.eval_dataloader)
|
||||||
with torch.no_grad():
|
if is_rank_0():
|
||||||
dist = 0
|
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
|
||||||
loss_sum = 0
|
log.to_csv('log.csv', mode='a', header=False, index=False)
|
||||||
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
|
|
||||||
chosen_ids = chosen_ids.squeeze(1).cuda()
|
|
||||||
c_mask = c_mask.squeeze(1).cuda()
|
|
||||||
reject_ids = reject_ids.squeeze(1).cuda()
|
|
||||||
r_mask = r_mask.squeeze(1).cuda()
|
|
||||||
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
|
||||||
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
|
||||||
dist += (chosen_reward - reject_reward).mean().item()
|
|
||||||
loss = self.loss_fn(chosen_reward, reject_reward)
|
|
||||||
loss_sum += loss.item()
|
|
||||||
dist_mean = dist / self.eval_dataloader.__len__()
|
|
||||||
loss_mean = loss_sum / self.eval_dataloader.__len__()
|
|
||||||
epoch_bar.update()
|
epoch_bar.update()
|
||||||
step_bar.set_postfix({'loss': loss_mean, 'dist_mean': dist_mean})
|
step_bar.set_postfix({'dist': dist, 'acc': acc})
|
||||||
step_bar.close()
|
step_bar.close()
|
||||||
|
|
|
@ -7,26 +7,42 @@ pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
## Train the reward model (Stage 2)
|
## Train the reward model (Stage 2)
|
||||||
We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as dataset to train our reward model. It is a dataset of chosen & rejected response of the same prompt.
|
|
||||||
|
|
||||||
You can download the dataset from huggingface automatically.
|
|
||||||
|
|
||||||
Use these code to train your reward model.
|
Use these code to train your reward model.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# Naive reward model training
|
# Take naive reward model training with opt-350m as example
|
||||||
python train_reward_model.py --pretrain <your model path> --model <your model type> --strategy naive
|
python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive
|
||||||
# use colossalai_zero2
|
# use colossalai_zero2
|
||||||
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain <your model path> --model <your model type> --strategy colossalai_zero2
|
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Features and tricks in RM training
|
||||||
|
- We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
|
||||||
|
- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic).
|
||||||
|
- We change the loss to valid_acc and pair_dist to monitor progress during training.
|
||||||
|
- We add special token to the end of the sequence to get better result.
|
||||||
|
- We use cosine-reducing lr-scheduler for RM training.
|
||||||
|
- We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution.
|
||||||
|
- We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2112.00861).
|
||||||
|
|
||||||
|
### Experiment result
|
||||||
|
Model performance in [Anthropics paper](https://arxiv.org/abs/2112.00861):
|
||||||
|
|
||||||
|
<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225263321-8d64c3a8-6877-4cc8-9b61-0e1c52d3d94f.png">
|
||||||
|
|
||||||
|
<div align=left>Our training & test result of bloom-560m for 1 epoch:
|
||||||
|
|
||||||
|
<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225262950-a7f0a686-25de-44ec-98f2-11b83ea86674.png">
|
||||||
|
|
||||||
|
<div align=left>
|
||||||
|
|
||||||
## Train with dummy prompt data (Stage 3)
|
## Train with dummy prompt data (Stage 3)
|
||||||
|
|
||||||
This script supports 3 strategies:
|
This script supports 4 kinds of strategies:
|
||||||
|
|
||||||
- naive
|
- naive
|
||||||
- ddp
|
- ddp
|
||||||
- colossalai
|
- colossalai_zero2
|
||||||
|
- colossalai_gemini
|
||||||
|
|
||||||
It uses random generated prompt data.
|
It uses random generated prompt data.
|
||||||
|
|
||||||
|
@ -53,7 +69,7 @@ We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-cha
|
||||||
|
|
||||||
You should download `prompts.csv` first.
|
You should download `prompts.csv` first.
|
||||||
|
|
||||||
This script also supports 3 strategies.
|
This script also supports 4 strategies.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# display cli help
|
# display cli help
|
||||||
|
@ -75,6 +91,9 @@ python inference.py --model_path <your actor model path> --model <your model typ
|
||||||
python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
|
python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Attention
|
||||||
|
The examples is just a demo for testing our progress of RM and PPO training.
|
||||||
|
|
||||||
|
|
||||||
#### data
|
#### data
|
||||||
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
|
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
|
||||||
|
|
|
@ -69,3 +69,23 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
|
||||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
|
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
|
||||||
|
|
||||||
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
||||||
|
|
||||||
|
# train rm
|
||||||
|
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||||
|
--pretrain 'facebook/opt-350m' --model 'opt' \
|
||||||
|
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
||||||
|
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
||||||
|
--test True --lora_rank 4
|
||||||
|
|
||||||
|
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||||
|
--pretrain 'gpt2' --model 'gpt2' \
|
||||||
|
--strategy colossalai_gemini --loss_fn 'log_exp'\
|
||||||
|
--dataset 'Dahoas/rm-static' --test True --lora_rank 4
|
||||||
|
|
||||||
|
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||||
|
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
|
||||||
|
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
||||||
|
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
||||||
|
--test True --lora_rank 4
|
||||||
|
|
||||||
|
rm -rf ${BASE}/rm_ckpt.pt
|
||||||
|
|
|
@ -2,7 +2,8 @@ import argparse
|
||||||
|
|
||||||
import loralib as lora
|
import loralib as lora
|
||||||
import torch
|
import torch
|
||||||
from chatgpt.dataset import RewardDataset
|
from chatgpt.dataset import HhRlhfDataset, RmStaticDataset
|
||||||
|
from chatgpt.models import LogSigLoss, LogExpLoss
|
||||||
from chatgpt.models.base import RewardModel
|
from chatgpt.models.base import RewardModel
|
||||||
from chatgpt.models.bloom import BLOOMRM
|
from chatgpt.models.bloom import BLOOMRM
|
||||||
from chatgpt.models.gpt import GPTRM
|
from chatgpt.models.gpt import GPTRM
|
||||||
|
@ -10,13 +11,13 @@ from chatgpt.models.opt import OPTRM
|
||||||
from chatgpt.trainer import RewardModelTrainer
|
from chatgpt.trainer import RewardModelTrainer
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from random import randint
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
# configure strategy
|
# configure strategy
|
||||||
if args.strategy == 'naive':
|
if args.strategy == 'naive':
|
||||||
|
@ -33,57 +34,85 @@ def train(args):
|
||||||
# configure model
|
# configure model
|
||||||
with strategy.model_init_context():
|
with strategy.model_init_context():
|
||||||
if args.model == 'bloom':
|
if args.model == 'bloom':
|
||||||
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
elif args.model == 'opt':
|
elif args.model == 'opt':
|
||||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
elif args.model == 'gpt2':
|
elif args.model == 'gpt2':
|
||||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
|
if args.model_path is not None:
|
||||||
|
state_dict = torch.load(args.model_path)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
if args.model == 'gpt2':
|
if args.model == 'gpt2':
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
elif args.model == 'bloom':
|
elif args.model == 'bloom':
|
||||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'opt':
|
elif args.model == 'opt':
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
max_len = args.max_len
|
||||||
|
|
||||||
max_len = 512
|
|
||||||
|
|
||||||
# configure optimizer
|
# configure optimizer
|
||||||
if args.strategy.startswith('colossalai'):
|
if args.strategy.startswith('colossalai'):
|
||||||
optim = HybridAdam(model.parameters(), lr=5e-5)
|
optim = HybridAdam(model.parameters(), lr=1.5e-5)
|
||||||
else:
|
else:
|
||||||
optim = Adam(model.parameters(), lr=5e-5)
|
optim = Adam(model.parameters(), lr=1.5e-5)
|
||||||
|
|
||||||
|
# configure loss function
|
||||||
|
if args.loss_fn == 'log_sig':
|
||||||
|
loss_fn = LogSigLoss()
|
||||||
|
elif args.loss_fn == 'log_exp':
|
||||||
|
loss_fn = LogExpLoss()
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
|
||||||
|
|
||||||
# prepare for data and dataset
|
# prepare for data and dataset
|
||||||
data = load_dataset(args.dataset)
|
if args.subset is not None:
|
||||||
train_data = data["train"]
|
data = load_dataset(args.dataset, data_dir=args.subset)
|
||||||
eval_data = data['test']
|
else:
|
||||||
train_dataset = RewardDataset(train_data, tokenizer, max_len)
|
data = load_dataset(args.dataset)
|
||||||
eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
|
|
||||||
|
if args.test:
|
||||||
|
train_data = data['train'].select(range(100))
|
||||||
|
eval_data = data['test'].select(range(10))
|
||||||
|
else:
|
||||||
|
train_data = data['train']
|
||||||
|
eval_data = data['test']
|
||||||
|
valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data)//10)))
|
||||||
|
|
||||||
|
if args.dataset == 'Dahoas/rm-static':
|
||||||
|
train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
|
||||||
|
valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len)
|
||||||
|
eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len)
|
||||||
|
elif args.dataset == 'Anthropic/hh-rlhf':
|
||||||
|
train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
|
||||||
|
valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len)
|
||||||
|
eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unsupported dataset "{args.dataset}"')
|
||||||
|
|
||||||
trainer = RewardModelTrainer(model=model,
|
trainer = RewardModelTrainer(model=model,
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
optim=optim,
|
optim=optim,
|
||||||
|
loss_fn = loss_fn,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
|
valid_dataset=valid_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
max_epochs=args.max_epochs)
|
max_epochs=args.max_epochs)
|
||||||
|
|
||||||
trainer.fit(use_lora=args.lora_rank)
|
trainer.fit()
|
||||||
|
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True)
|
strategy.save_model(trainer.model, args.save_path, only_rank0=True)
|
||||||
# save optimizer checkpoint on all ranks
|
# save optimizer checkpoint on all ranks
|
||||||
strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
|
if args.need_optim_ckpt:
|
||||||
|
strategy.save_optimizer(trainer.optimizer, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -92,10 +121,18 @@ if __name__ == '__main__':
|
||||||
default='naive')
|
default='naive')
|
||||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
|
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
parser.add_argument('--pretrain', type=str, default=None)
|
||||||
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
|
parser.add_argument('--model_path', type=str, default=None)
|
||||||
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
|
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||||
|
parser.add_argument('--dataset', type=str,
|
||||||
|
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
|
||||||
|
default='Dahoas/rm-static')
|
||||||
|
parser.add_argument('--subset', type=str, default=None)
|
||||||
|
parser.add_argument('--save_path', type=str, default='rm_ckpt.pt')
|
||||||
parser.add_argument('--max_epochs', type=int, default=1)
|
parser.add_argument('--max_epochs', type=int, default=1)
|
||||||
parser.add_argument('--batch_size', type=int, default=4)
|
parser.add_argument('--batch_size', type=int, default=1)
|
||||||
|
parser.add_argument('--max_len', type=int, default=512)
|
||||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||||
|
parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
|
||||||
|
parser.add_argument('--test', type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
|
|
@ -1,20 +1,8 @@
|
||||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
set_n_least_used_CUDA_VISIBLE_DEVICES 1
|
||||||
local n=${1:-"9999"}
|
|
||||||
echo "GPU Memory Usage:"
|
|
||||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
|
||||||
| tail -n +2 \
|
|
||||||
| nl -v 0 \
|
|
||||||
| tee /dev/tty \
|
|
||||||
| sort -g -k 2 \
|
|
||||||
| awk '{print $1}' \
|
|
||||||
| head -n $n)
|
|
||||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
|
||||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
|
||||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
python train_reward_model.py --pretrain '/home/lczht/data2/bloom-560m' \
|
||||||
|
--model 'bloom' \
|
||||||
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
|
--strategy naive \
|
||||||
torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2
|
--loss_fn 'log_exp'\
|
||||||
# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
|
--save_path 'rmstatic.pt' \
|
||||||
|
--test True
|
||||||
|
|
Loading…
Reference in New Issue