mirror of https://github.com/hpcaitech/ColossalAI
commit e327e95144 (parent d565a24849)
@@ -0,0 +1,43 @@
import json
import os
from typing import Optional

import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer

from colossalai.registry import DATASETS


@DATASETS.register_module
class WebtextDataset(Dataset):

    def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
        super().__init__()
        if path is not None:
            # Reuse a cached encoding if one exists for this sequence length.
            root = os.path.dirname(path)
            encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
            if os.path.isfile(encoded_data_cache_path):
                seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
                if seq_len_ == seq_len:
                    self.data = data
                    self.attention_mask = attention_mask
                    return
            # Each line of the webtext file is a JSON object with a 'text' field.
            raw_data = []
            with open(path) as f:
                for line in f:
                    raw_data.append(json.loads(line)['text'])
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            tokenizer.pad_token = tokenizer.unk_token  # GPT-2 has no pad token by default
            encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
            self.data = encoded_data['input_ids']
            self.attention_mask = encoded_data['attention_mask']
            # Write the cache so later runs with the same seq_len skip tokenization.
            torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
        else:
            # No path given: fall back to random token ids (50257 = GPT-2 vocab size).
            self.data = torch.randint(0, 50257, (10240, seq_len))
            self.attention_mask = torch.ones_like(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Returns (model inputs, labels); for causal language modeling the
        # labels are the input ids themselves.
        return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]
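For reference, a minimal usage sketch for the dataset above (not part of the commit). The batch size and the choice of path=None are illustrative assumptions; with path=None the class falls back to random token ids, so the sketch runs without the webtext file or a tokenizer download:

    from torch.utils.data import DataLoader

    # Assumes WebtextDataset from the file above is importable.
    dataset = WebtextDataset(path=None, seq_len=1024)  # dummy data: random token ids
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Each batch is an (inputs, labels) pair: a dict holding 'input_ids' and
    # 'attention_mask', plus the same token ids again as labels.
    inputs, labels = next(iter(loader))
    print(inputs['input_ids'].shape, labels.shape)  # torch.Size([8, 1024]) for both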
@@ -1,2 +1,3 @@
 export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
-colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch
+DUMMY_DATA=--use_dummy_dataset
+colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA
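The new --use_dummy_dataset flag presumably makes train_gpt.py build the dataset with path=None, hitting the random-data branch above. A hedged sketch of that wiring (the argument parsing and env-var lookup here are assumptions for illustration, not code from this commit):

    import argparse
    import os

    # Hypothetical glue code, not from this commit: assumes train_gpt.py reads
    # the DATA env var and exposes a --use_dummy_dataset flag.
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_dummy_dataset', action='store_true')
    args, _ = parser.parse_known_args()

    data_path = None if args.use_dummy_dataset else os.environ['DATA']
    train_ds = WebtextDataset(path=data_path, seq_len=1024)

Note that --nproc_per_node also drops from 4 to 2, so the dummy-data smoke test needs only two GPUs.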