2023-03-28 12:25:36 +00:00
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
from .utils import is_rank_0
|
|
|
|
|
|
|
|
|
2023-05-23 07:28:20 +00:00
|
|
|
# Dahoas/rm-static
|
2023-03-28 12:25:36 +00:00
|
|
|
class RmStaticDataset(Dataset):
    """
    Dataset for reward model built from prompt/chosen/rejected records.

    Each record is indexed with the keys "prompt", "chosen" and "rejected". For every
    record, two texts are built: prompt + chosen + end token, and prompt + rejected +
    end token. Both sets of texts are tokenized with padding to ``max_length`` and
    truncation enabled.

    Args:
        dataset: dataset for reward model; an iterable of records supporting
            ``record["prompt"]``, ``record["chosen"]`` and ``record["rejected"]``.
            It is iterated exactly once.
        tokenizer: tokenizer for reward model; called on a list of strings and must
            expose ``eos_token`` when ``special_token`` is None
        max_length: max length of input (every sequence is padded/truncated to it)
        special_token: special token at the end of sentence; when None,
            ``tokenizer.eos_token`` is used instead
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Collect both text lists in a single pass: iterating `dataset` twice would
        # silently yield no rejected texts for one-shot iterables (e.g. generators)
        # and doubles the iteration work.
        chosen, reject = [], []
        for data in tqdm(dataset, disable=not is_rank_0()):
            chosen.append(data["prompt"] + data["chosen"] + self.end_token)
            reject.append(data["prompt"] + data["rejected"] + self.end_token)

        # NOTE(review): with truncation=True, texts longer than max_length are cut by
        # the tokenizer, which (with the usual right-side truncation) drops the
        # appended end token — confirm this is acceptable for the reward model.
        self.chosen = self._tokenize(tokenizer, chosen, max_length)
        self.reject = self._tokenize(tokenizer, reject, max_length)

    @staticmethod
    def _tokenize(tokenizer: Callable, texts, max_length: int) -> dict:
        """Tokenize ``texts`` to fixed length, keeping only "input_ids" and "attention_mask"."""
        tokens = tokenizer(texts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
        return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}

    def __len__(self):
        """Return the number of examples (rows of the chosen input_ids)."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """
        Return one example as a 4-tuple:
        (chosen input_ids, chosen attention_mask, rejected input_ids, rejected attention_mask).
        """
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Anthropic/hh-rlhf
|
|
|
|
class HhRlhfDataset(Dataset):
    """
    Dataset for reward model built from chosen/rejected records.

    Each record is indexed with the keys "chosen" and "rejected". Unlike a
    prompt-based layout, no separate prompt field is read: each text is used as-is
    with the end token appended. Both sets of texts are tokenized with padding to
    ``max_length`` and truncation enabled.

    Args:
        dataset: dataset for reward model; an iterable of records supporting
            ``record["chosen"]`` and ``record["rejected"]``. It is iterated exactly once.
        tokenizer: tokenizer for reward model; called on a list of strings and must
            expose ``eos_token`` when ``special_token`` is None
        max_length: max length of input (every sequence is padded/truncated to it)
        special_token: special token at the end of sentence; when None,
            ``tokenizer.eos_token`` is used instead
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Collect both text lists in a single pass: iterating `dataset` twice would
        # silently yield no rejected texts for one-shot iterables (e.g. generators)
        # and doubles the iteration work.
        chosen, reject = [], []
        for data in tqdm(dataset, disable=not is_rank_0()):
            chosen.append(data["chosen"] + self.end_token)
            reject.append(data["rejected"] + self.end_token)

        # NOTE(review): with truncation=True, texts longer than max_length are cut by
        # the tokenizer, which (with the usual right-side truncation) drops the
        # appended end token — confirm this is acceptable for the reward model.
        self.chosen = self._tokenize(tokenizer, chosen, max_length)
        self.reject = self._tokenize(tokenizer, reject, max_length)

    @staticmethod
    def _tokenize(tokenizer: Callable, texts, max_length: int) -> dict:
        """Tokenize ``texts`` to fixed length, keeping only "input_ids" and "attention_mask"."""
        tokens = tokenizer(texts, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
        return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}

    def __len__(self):
        """Return the number of examples (rows of the chosen input_ids)."""
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        """
        Return one example as a 4-tuple:
        (chosen input_ids, chosen attention_mask, rejected input_ids, rejected attention_mask).
        """
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )
|