from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm

from .utils import is_rank_0


# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Dataset for reward model training on Dahoas/rm-static.

    Each example pairs a prompt's preferred ("chosen") response with its
    "rejected" response; both are tokenized up front to a fixed length so
    the reward model can score them side by side.

    Args:
        dataset: raw dataset with "prompt", "chosen" and "rejected" fields
        tokenizer: tokenizer used to encode the sequences
        max_length: maximum token length of each encoded input
        special_token: token appended to the end of each sequence
            (defaults to the tokenizer's EOS token)
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        # Concatenate prompt + chosen response, append the end token, and
        # tokenize everything eagerly with fixed-length padding/truncation.
        chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        chosen_token = tokenizer(
            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

        # Same treatment for the rejected responses.
        reject = [
            data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())
        ]
        reject_token = tokenizer(
            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        # (chosen_ids, chosen_mask, reject_ids, reject_mask)
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )


# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
    """
    Dataset for reward model training on Anthropic/hh-rlhf.

    Unlike Dahoas/rm-static, the "chosen" and "rejected" fields already
    contain the full dialogue (prompt included), so no prompt is prepended.

    Args:
        dataset: raw dataset with "chosen" and "rejected" fields
        tokenizer: tokenizer used to encode the sequences
        max_length: maximum token length of each encoded input
        special_token: token appended to the end of each sequence
            (defaults to the tokenizer's EOS token)
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        chosen_token = tokenizer(
            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

        reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        reject_token = tokenizer(
            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        return self.chosen["input_ids"].shape[0]

    def __getitem__(self, idx):
        # (chosen_ids, chosen_mask, reject_ids, reject_mask)
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )
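

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the tokenizer choice, split name,
# and hyperparameters below are assumptions, not part of this module):
#
#     from datasets import load_dataset
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token
#
#     data = load_dataset("Dahoas/rm-static", split="train")
#     train_dataset = RmStaticDataset(data, tokenizer, max_length=512)
#     loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
#
#     # Each batch unpacks into four fixed-length (max_length) tensors:
#     chosen_ids, chosen_mask, reject_ids, reject_mask = next(iter(loader))
#
# HhRlhfDataset is used the same way with load_dataset("Anthropic/hh-rlhf").
# ---------------------------------------------------------------------------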