from typing import Callable

from torch.utils.data import Dataset
from tqdm import tqdm

from .utils import is_rank_0


class _PairwiseRewardDataset(Dataset):
    """
    Shared storage, tokenization and indexing for pairwise reward-model datasets.

    Subclasses populate ``self.chosen`` and ``self.reject`` as parallel lists of
    ``{"input_ids": ..., "attention_mask": ...}`` dicts by calling ``_add_pair``
    once per (chosen, rejected) text pair.

    Args:
        tokenizer: tokenizer for reward model; must expose ``eos_token`` when
            ``special_token`` is None
        special_token: special token at the end of sentence (defaults to the
            tokenizer's ``eos_token``)
    """

    def __init__(self, tokenizer: Callable, special_token=None) -> None:
        super().__init__()
        self.chosen = []
        self.reject = []
        # An explicit terminator wins; otherwise fall back to the tokenizer's EOS token.
        self.end_token = tokenizer.eos_token if special_token is None else special_token

    @staticmethod
    def _encode(tokenizer: Callable, text: str, max_length: int) -> dict:
        """Tokenize one string, padded/truncated to ``max_length``, into an ids/mask dict."""
        encoded = tokenizer(text, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
        # NOTE(review): with return_tensors="pt" on a single string these are presumably
        # (1, max_length) tensors; downstream collation must account for the leading dim.
        return {"input_ids": encoded['input_ids'], "attention_mask": encoded['attention_mask']}

    def _add_pair(self, tokenizer: Callable, max_length: int, chosen_text: str, rejected_text: str) -> None:
        """Append the tokenized chosen and rejected texts, each terminated with ``end_token``."""
        # The end token is appended *before* truncation, so a text longer than
        # max_length can lose it (assuming the tokenizer truncates on the right).
        self.chosen.append(self._encode(tokenizer, chosen_text + self.end_token, max_length))
        self.reject.append(self._encode(tokenizer, rejected_text + self.end_token, max_length))

    def __len__(self):
        return len(self.chosen)

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, reject_ids, reject_mask)`` for sample ``idx``."""
        chosen, reject = self.chosen[idx], self.reject[idx]
        return chosen["input_ids"], chosen["attention_mask"], reject["input_ids"], reject["attention_mask"]


# Dahoas/rm-static
class RmStaticDataset(_PairwiseRewardDataset):
    """
    Dataset for reward model built from Dahoas/rm-static style records.

    Each record must provide ``'prompt'``, ``'chosen'`` and ``'rejected'`` strings;
    the prompt is prepended to both responses before tokenization.

    Args:
        dataset: iterable of records for reward model
        tokenizer: tokenizer for reward model
        max_length: max length of input (sequences are padded/truncated to it)
        special_token: special token at the end of sentence
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__(tokenizer, special_token)
        # Progress bar only on rank 0 to avoid duplicated output in distributed runs.
        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data['prompt']
            self._add_pair(tokenizer, max_length, prompt + data['chosen'], prompt + data['rejected'])


# Anthropic/hh-rlhf
class HhRlhfDataset(_PairwiseRewardDataset):
    """
    Dataset for reward model built from Anthropic/hh-rlhf style records.

    Each record must provide ``'chosen'`` and ``'rejected'`` strings, which are
    used as-is (no separate prompt field is read).

    Args:
        dataset: iterable of records for reward model
        tokenizer: tokenizer for reward model
        max_length: max length of input (sequences are padded/truncated to it)
        special_token: special token at the end of sentence
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__(tokenizer, special_token)
        # Progress bar only on rank 0 to avoid duplicated output in distributed runs.
        for data in tqdm(dataset, disable=not is_rank_0()):
            self._add_pair(tokenizer, max_length, data['chosen'], data['rejected'])