import json import os from typing import Optional import torch from torch.utils.data import Dataset from transformers import GPT2Tokenizer from colossalai.legacy.registry import DATASETS @DATASETS.register_module class WebtextDataset(Dataset): def __init__(self, path: Optional[str] = None, seq_len=1024) -> None: super().__init__() if path is not None: root = os.path.dirname(path) encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') if os.path.isfile(encoded_data_cache_path): seq_len_, data, attention_mask = torch.load(encoded_data_cache_path) if seq_len_ == seq_len: self.data = data self.attention_mask = attention_mask return raw_data = [] with open(path) as f: for line in f.readlines(): raw_data.append(json.loads(line)['text']) tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer.pad_token = tokenizer.unk_token encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') self.data = encoded_data['input_ids'] self.attention_mask = encoded_data['attention_mask'] else: self.data = torch.randint(0, 50257, (10240, seq_len)) self.attention_mask = torch.ones_like(self.data) def __len__(self): return len(self.data) def __getitem__(self, index): return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]