ColossalAI/applications/Chat/coati/dataset/reward_dataset.py

89 lines
3.2 KiB
Python

from typing import Callable
from torch.utils.data import Dataset
from tqdm import tqdm
from .utils import is_rank_0
# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """Pairwise (chosen / rejected) dataset for reward-model training.

    Every record of ``dataset`` is read through the keys ``"prompt"``,
    ``"chosen"`` and ``"rejected"``. Two texts are built per record,
    ``prompt + chosen + end_token`` and ``prompt + rejected + end_token``,
    and each list of texts is tokenized in one batched tokenizer call.

    Args:
        dataset: iterable of records indexable by ``"prompt"``, ``"chosen"``
            and ``"rejected"``. It is traversed exactly once, so a one-shot
            iterable (e.g. a generator) is acceptable.
        tokenizer: callable applied to a list of strings; its result must
            provide ``"input_ids"`` and ``"attention_mask"``. Its
            ``eos_token`` attribute is used when ``special_token`` is None.
        max_length: forwarded to the tokenizer as ``max_length``.
        special_token: text appended to the end of every sample; defaults
            to ``tokenizer.eos_token``.
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Build both text lists in a single pass: iterating ``dataset`` twice
        # would leave the second list empty for one-shot iterables and doubles
        # the read cost for lazily loaded datasets.
        chosen = []
        reject = []
        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data["prompt"]
            chosen.append(prompt + data["chosen"] + self.end_token)
            reject.append(prompt + data["rejected"] + self.end_token)

        # NOTE(review): with truncation=True an over-long sample may lose the
        # appended end token, depending on the tokenizer's truncation side --
        # confirm this is intended.
        chosen_token = tokenizer(
            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

        reject_token = tokenizer(
            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        """Return the number of samples (first dimension of the chosen ``input_ids``)."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, reject_ids, reject_mask)`` for sample ``idx``."""
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )
# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """Pairwise (chosen / rejected) dataset for reward-model training.

    Every record of ``dataset`` is read through the keys ``"chosen"`` and
    ``"rejected"``. Two texts are built per record, ``chosen + end_token``
    and ``rejected + end_token``, and each list of texts is tokenized in one
    batched tokenizer call.

    Args:
        dataset: iterable of records indexable by ``"chosen"`` and
            ``"rejected"``. It is traversed exactly once, so a one-shot
            iterable (e.g. a generator) is acceptable.
        tokenizer: callable applied to a list of strings; its result must
            provide ``"input_ids"`` and ``"attention_mask"``. Its
            ``eos_token`` attribute is used when ``special_token`` is None.
        max_length: forwarded to the tokenizer as ``max_length``.
        special_token: text appended to the end of every sample; defaults
            to ``tokenizer.eos_token``.
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Build both text lists in a single pass: iterating ``dataset`` twice
        # would leave the second list empty for one-shot iterables and doubles
        # the read cost for lazily loaded datasets.
        chosen = []
        reject = []
        for data in tqdm(dataset, disable=not is_rank_0()):
            chosen.append(data["chosen"] + self.end_token)
            reject.append(data["rejected"] + self.end_token)

        # NOTE(review): with truncation=True an over-long sample may lose the
        # appended end token, depending on the tokenizer's truncation side --
        # confirm this is intended.
        chosen_token = tokenizer(
            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

        reject_token = tokenizer(
            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        """Return the number of samples (first dimension of the chosen ``input_ids``)."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, reject_ids, reject_mask)`` for sample ``idx``."""
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )