mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
38 lines
1.3 KiB
38 lines
1.3 KiB
import torch
|
|
from torch.utils.data import Dataset
|
|
from datasets import load_dataset
|
|
|
|
|
|
class NetflixDataset(Dataset):
|
|
|
|
def __init__(self, tokenizer):
|
|
|
|
super().__init__()
|
|
|
|
self.tokenizer = tokenizer
|
|
self.input_ids = []
|
|
self.attn_masks = []
|
|
self.labels = []
|
|
self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")['description']
|
|
self.max_length = max([len(self.tokenizer.encode(description)) for description in netflix_descriptions])
|
|
|
|
for txt in self.txt_list:
|
|
encodings_dict = self.tokenizer('</s>' + txt + '</s>',
|
|
truncation=True,
|
|
max_length=self.max_length,
|
|
padding="max_length")
|
|
self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
|
|
self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
|
|
|
|
def __len__(self):
|
|
return len(self.input_ids)
|
|
|
|
def __getitem__(self, idx):
|
|
return self.input_ids[idx], self.attn_masks[idx]
|
|
|
|
|
|
def netflix_collator(data):
|
|
return {'input_ids': torch.stack([x[0] for x in data]),
|
|
'attention_mask': torch.stack([x[1] for x in data]),
|
|
'labels': torch.stack([x[0] for x in data])}
|