# ColossalAI/applications/Chat/coati/dataset/reward_dataset.py
#
# 115 lines
# 3.9 KiB
# Python

from typing import Callable
from torch.utils.data import Dataset
from tqdm import tqdm
from .utils import is_rank_0
# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Dataset for reward model training on prompt/chosen/rejected records.

    Each record is read with the keys ``"prompt"``, ``"chosen"`` and
    ``"rejected"``. The prompt is concatenated in front of both responses, the
    end token is appended, and both sides are tokenized with truncation and
    padding to ``max_length``.

    Args:
        dataset: iterable of records for the reward model. It is traversed
            exactly once, so one-shot iterables (e.g. generators) work too.
        tokenizer: tokenizer for the reward model; it is called on a list of
            strings, and its ``eos_token`` is used when ``special_token`` is None.
        max_length: max length of input (sequences are padded/truncated to it)
        special_token: special token appended at the end of each sequence;
            defaults to ``tokenizer.eos_token``.
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Build both sides in a single pass. Iterating ``dataset`` twice would
        # produce an empty ``reject`` list for one-shot iterables, leaving the
        # chosen/rejected pairs misaligned.
        chosen, reject = [], []
        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data["prompt"]
            chosen.append(prompt + data["chosen"] + self.end_token)
            reject.append(prompt + data["rejected"] + self.end_token)

        self.chosen = self._encode(tokenizer, chosen, max_length)
        self.reject = self._encode(tokenizer, reject, max_length)

    @staticmethod
    def _encode(tokenizer: Callable, texts, max_length: int) -> dict:
        """Tokenize ``texts`` to fixed length; keep only input ids and attention mask."""
        tokens = tokenizer(texts,
                           max_length=max_length,
                           padding="max_length",
                           truncation=True,
                           return_tensors="pt")
        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
        }

    def __len__(self) -> int:
        """Return the number of (chosen, rejected) pairs."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, reject_ids, reject_mask)`` at ``idx``."""
        return (self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx],
                self.reject["input_ids"][idx], self.reject["attention_mask"][idx])
# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """
    Dataset for reward model training on chosen/rejected records.

    Each record is read with the keys ``"chosen"`` and ``"rejected"``; no
    separate prompt field is read, so each string is used as the full input.
    The end token is appended and both sides are tokenized with truncation and
    padding to ``max_length``.

    Args:
        dataset: iterable of records for the reward model. It is traversed
            exactly once, so one-shot iterables (e.g. generators) work too.
        tokenizer: tokenizer for the reward model; it is called on a list of
            strings, and its ``eos_token`` is used when ``special_token`` is None.
        max_length: max length of input (sequences are padded/truncated to it)
        special_token: special token appended at the end of each sequence;
            defaults to ``tokenizer.eos_token``.
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Build both sides in a single pass. Iterating ``dataset`` twice would
        # produce an empty ``reject`` list for one-shot iterables, leaving the
        # chosen/rejected pairs misaligned.
        chosen, reject = [], []
        for data in tqdm(dataset, disable=not is_rank_0()):
            chosen.append(data["chosen"] + self.end_token)
            reject.append(data["rejected"] + self.end_token)

        self.chosen = self._encode(tokenizer, chosen, max_length)
        self.reject = self._encode(tokenizer, reject, max_length)

    @staticmethod
    def _encode(tokenizer: Callable, texts, max_length: int) -> dict:
        """Tokenize ``texts`` to fixed length; keep only input ids and attention mask."""
        tokens = tokenizer(texts,
                           max_length=max_length,
                           padding="max_length",
                           truncation=True,
                           return_tensors="pt")
        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
        }

    def __len__(self) -> int:
        """Return the number of (chosen, rejected) pairs."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, reject_ids, reject_mask)`` at ``idx``."""
        return (self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx],
                self.reject["input_ids"][idx], self.reject["attention_mask"][idx])