import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class NetflixDataset(Dataset):
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []

        # Pull the show descriptions from the Hugging Face Hub.
        self.txt_list = load_dataset(
            "hugginglearners/netflix-shows", split="train"
        )["description"]

        # Pad every example to the length of the longest description.
        self.max_length = max(
            len(self.tokenizer.encode(description)) for description in self.txt_list
        )

        for txt in self.txt_list:
            # Wrap each description in the BOS/EOS special tokens the
            # tokenizer is configured with.
            encodings_dict = self.tokenizer(
                "<|startoftext|>" + txt + "<|endoftext|>",
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
            )
            self.input_ids.append(torch.tensor(encodings_dict["input_ids"]))
            self.attn_masks.append(torch.tensor(encodings_dict["attention_mask"]))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]


def netflix_collator(data):
    # For causal language modeling the labels are the input ids themselves;
    # the model shifts them by one position internally when computing the loss.
    return {
        "input_ids": torch.stack([x[0] for x in data]),
        "attention_mask": torch.stack([x[1] for x in data]),
        "labels": torch.stack([x[0] for x in data]),
    }
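For context, here is a minimal sketch of how this dataset and collator might be plugged into a Hugging Face Trainer. The GPT-2 checkpoint, the special-token configuration ("<|startoftext|>", "<|endoftext|>", "<|pad|>"), the output directory, and the hyperparameters are all illustrative assumptions, since none of that setup appears in the snippet above.

# Minimal usage sketch; model/tokenizer setup is assumed, not shown above.
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments

# The tokenizer must know the special tokens used when building the dataset.
tokenizer = GPT2Tokenizer.from_pretrained(
    "gpt2",
    bos_token="<|startoftext|>",
    eos_token="<|endoftext|>",
    pad_token="<|pad|>",
)

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))  # account for the added special tokens

dataset = NetflixDataset(tokenizer)

training_args = TrainingArguments(
    output_dir="./gpt2-netflix",        # hypothetical output path
    num_train_epochs=1,                 # illustrative hyperparameters
    per_device_train_batch_size=4,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=netflix_collator,     # supplies labels = input_ids
)
trainer.train()

Because netflix_collator copies input_ids into labels, the model computes a standard causal-LM loss; GPT2LMHeadModel shifts the labels internally, so no manual offset is needed.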