[chatgpt] fix rm eval (#2829)

* [chatgpt] fix train_rm bug with LoRA

* [chatgpt] support ColossalAI strategy to train RM

* fix pre-commit

* fix pre-commit (2)

* [chatgpt] fix RM eval typo

* fix RM eval

* fix pre-commit
BlueRum 2023-02-21 11:35:45 +08:00 committed by GitHub
parent 918bc94b6b
commit 3eebc4dff7
6 changed files with 34 additions and 18 deletions

File 1 of 6

@@ -1,3 +1,4 @@
 from .reward_dataset import RewardDataset
+from .utils import is_rank_0
 
-__all__ = ['RewardDataset']
+__all__ = ['RewardDataset', 'is_rank_0']

File 2 of 6

@@ -3,6 +3,8 @@ from typing import Callable
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
+from .utils import is_rank_0
+
 
 class RewardDataset(Dataset):
     """
@@ -18,7 +20,7 @@ class RewardDataset(Dataset):
         super().__init__()
         self.chosen = []
         self.reject = []
-        for data in tqdm(dataset):
+        for data in tqdm(dataset, disable=not is_rank_0()):
             prompt = data['prompt']
             chosen = prompt + data['chosen'] + "<|endoftext|>"

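The disable=not is_rank_0() change only silences duplicate progress bars in multi-process runs; the data contract is unchanged. As a reminder of that contract, here is a self-contained shape sketch of the 4-tuple (chosen_ids, c_mask, reject_ids, r_mask) that the trainer unpacks. The (1, max_length) per-item tensor shape is an assumption inferred from the trainer's squeeze(1) calls, not something shown in this diff.

# Sketch only: PairShapes is a stand-in for the real RewardDataset.
import torch
from torch.utils.data import DataLoader, Dataset


class PairShapes(Dataset):
    """Yields the same 4-tuple layout the real RewardDataset is expected to."""

    def __init__(self, n: int = 8, max_length: int = 16) -> None:
        self.n, self.max_length = n, max_length

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx):
        ids = torch.randint(0, 100, (1, self.max_length))
        mask = torch.ones(1, self.max_length, dtype=torch.long)
        return ids, mask, ids.clone(), mask.clone()


loader = DataLoader(PairShapes(), batch_size=4)
chosen_ids, c_mask, reject_ids, r_mask = next(iter(loader))
print(chosen_ids.shape)             # torch.Size([4, 1, 16])
print(chosen_ids.squeeze(1).shape)  # torch.Size([4, 16]) -- what the model consumes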
File 3 of 6

@@ -0,0 +1,5 @@
+import torch.distributed as dist
+
+
+def is_rank_0() -> bool:
+    return not dist.is_initialized() or dist.get_rank() == 0

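The new helper is the whole distributed-logging fix: it returns True for single-process runs (no process group initialized) and otherwise only for rank 0. A minimal usage sketch; the log() helper below is illustrative and not part of this commit.

# Sketch: gate console output to rank 0 so a multi-process run prints only once.
import torch.distributed as dist
from tqdm import tqdm


def is_rank_0() -> bool:
    # True for single-process runs and for rank 0 of an initialized process group.
    return not dist.is_initialized() or dist.get_rank() == 0


def log(msg: str) -> None:
    # Illustrative helper: only rank 0 prints.
    if is_rank_0():
        print(msg)


# Only rank 0 renders the progress bar; other ranks still iterate, just silently.
for _ in tqdm(range(10), disable=not is_rank_0()):
    pass
log("done")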
File 4 of 6

@@ -23,7 +23,7 @@ class RewardModel(LoRAModule):
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
-        self.model = model
+        self.body = model
         if value_head is not None:
             if value_head.out_features != 1:
                 raise ValueError("The value head of reward model's output dim should be 1!")
@@ -34,7 +34,7 @@ class RewardModel(LoRAModule):
         self.convert_to_lora()
 
     def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        outputs = self.model(sequences, attention_mask=attention_mask)
+        outputs = self.body(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs['last_hidden_state']
         values = self.value_head(last_hidden_states)[:, :-1]
         value = values.mean(dim=1).squeeze(1)    # ensure shape is (B)

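Renaming self.model to self.body distinguishes the transformer backbone from the trainer's own self.model (the whole RewardModel), which is what the updated lora.mark_only_lora_as_trainable(self.model.body) call below relies on. The forward pass itself is unchanged: per-token value predictions are averaged into one scalar reward per sequence. A quick shape check with a stand-in backbone; TinyBackbone is illustrative, not the real pretrained model.

# Sketch: verify the (B,) reward shape produced by the forward logic above.
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Returns a dict with 'last_hidden_state', mimicking a HF transformer body."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, sequences, attention_mask=None):
        return {'last_hidden_state': self.embed(sequences)}    # (B, S, H)


body = TinyBackbone()
value_head = nn.Linear(8, 1)

sequences = torch.randint(0, 100, (4, 16))              # (B=4, S=16)
last_hidden_states = body(sequences)['last_hidden_state']
values = value_head(last_hidden_states)[:, :-1]          # (4, 15, 1): drop the last position
value = values.mean(dim=1).squeeze(1)                    # (4,): one scalar reward per sequence
print(value.shape)    # torch.Size([4])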
File 5 of 6

@@ -1,6 +1,7 @@
 from abc import ABC
 
 import loralib as lora
+import torch
 from chatgpt.dataset import RewardDataset
 from chatgpt.nn import PairWiseLoss
 from torch.optim import Adam, Optimizer
@@ -55,7 +56,8 @@ class RewardModelTrainer(ABC):
             # train
             if use_lora > 0:
                 print("Using Lora")
-                lora.mark_only_lora_as_trainable(self.model.model)
+                lora.mark_only_lora_as_trainable(self.model.body)
             else:
                 self.model.train()
             for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
@@ -74,16 +76,21 @@ class RewardModelTrainer(ABC):
             # eval
             self.model.eval()
-            for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
-                dist = 0
-                chosen_ids = chosen_ids.squeeze(1).cuda()
-                c_mask = c_mask.squeeze(1).cuda()
-                reject_ids = reject_ids.squeeze(1).cuda()
-                r_mask = r_mask.squeeze(1).cuda()
-                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
-                reject_reward = self.model(reject_ids, attention_mask=r_mask)
-                dist += (chosen_reward - reject_reward)
-                dist_mean = dist / self.eval_dataloader.__len__()
+            with torch.no_grad():
+                dist = 0
+                loss_sum = 0
+                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
+                    chosen_ids = chosen_ids.squeeze(1).cuda()
+                    c_mask = c_mask.squeeze(1).cuda()
+                    reject_ids = reject_ids.squeeze(1).cuda()
+                    r_mask = r_mask.squeeze(1).cuda()
+                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
+                    dist += (chosen_reward - reject_reward).mean().item()
+                    loss = self.loss_fn(chosen_reward, reject_reward)
+                    loss_sum += loss.item()
+                dist_mean = dist / self.eval_dataloader.__len__()
+                loss_mean = loss_sum / self.eval_dataloader.__len__()
             epoch_bar.update()
-            step_bar.set_postfix({'loss': loss.item(), 'dist_mean': dist_mean.item()})
+            step_bar.set_postfix({'loss': loss_mean, 'dist_mean': dist_mean})
         step_bar.close()

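The eval fix does three things: the loop now runs under torch.no_grad(), the reward gap is accumulated as a Python float instead of a growing tensor (the old code also reset dist inside the loop, so only the last batch counted), and the reported loss is averaged over the whole eval set rather than taken from the last batch. The same logic as a standalone function, assuming PairWiseLoss is the standard pairwise ranking loss -log sigmoid(r_chosen - r_reject); see chatgpt.nn.PairWiseLoss for the actual definition.

# Sketch only: evaluate() and pairwise_loss() are illustrative names, not repo APIs.
import torch
import torch.nn.functional as F


def pairwise_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # Standard pairwise ranking loss: prefer a higher reward for the chosen answer.
    return -F.logsigmoid(chosen_reward - reject_reward).mean()


@torch.no_grad()
def evaluate(model, eval_dataloader, device: str = 'cuda'):
    model.eval()
    dist_sum, loss_sum = 0.0, 0.0
    for chosen_ids, c_mask, reject_ids, r_mask in eval_dataloader:
        chosen_reward = model(chosen_ids.squeeze(1).to(device),
                              attention_mask=c_mask.squeeze(1).to(device))
        reject_reward = model(reject_ids.squeeze(1).to(device),
                              attention_mask=r_mask.squeeze(1).to(device))
        # Accumulate Python floats so no graph or GPU memory is retained.
        dist_sum += (chosen_reward - reject_reward).mean().item()
        loss_sum += pairwise_loss(chosen_reward, reject_reward).item()
    n = len(eval_dataloader)
    return dist_sum / n, loss_sum / n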
File 6 of 6

@@ -29,7 +29,8 @@ def train(args):
     # configure model
     tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
     tokenizer.pad_token = tokenizer.eos_token
-    model = BLOOMRM(pretrained=args.pretrain).cuda()
+    with strategy.model_init_context():
+        model = BLOOMRM(pretrained=args.pretrain).cuda()
     max_len = 1024
 
     # configure optimizer
@@ -71,8 +72,8 @@ if __name__ == '__main__':
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
-    parser.add_argument('--max_epochs', type=int, default=2)
-    parser.add_argument('--batch_size', type=int, default=1)
+    parser.add_argument('--max_epochs', type=int, default=10)
+    parser.add_argument('--batch_size', type=int, default=4)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     args = parser.parse_args()
     train(args)
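Wrapping model construction in strategy.model_init_context() lets the selected strategy decide how parameters are materialized (for example, ColossalAI strategies can shard or lazily initialize them at construction time). The sketch below only illustrates that contract with a hypothetical no-op strategy; it is not the ColossalAI implementation.

# Illustrative only: NaiveStrategySketch is a made-up stand-in for a real strategy.
from contextlib import contextmanager


class NaiveStrategySketch:

    @contextmanager
    def model_init_context(self):
        # A naive strategy does nothing special; construction happens as usual.
        # A ColossalAI strategy can use this hook to control parameter placement.
        yield


strategy = NaiveStrategySketch()
with strategy.model_init_context():
    pass    # model = BLOOMRM(pretrained=...).cuda() would go here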