polish code and fix dataloader bugs

pull/2494/head
jiaruifang 2023-01-18 12:00:08 +08:00
parent a4b75b78a0
commit e58cc441e2
3 changed files with 35 additions and 65 deletions

View File

@ -1,5 +1,6 @@
import json
import os
from typing import Optional
import torch
from torch.utils.data import Dataset
@ -11,26 +12,29 @@ from colossalai.registry import DATASETS
@DATASETS.register_module
class WebtextDataset(Dataset):
def __init__(self, path, seq_len=1024) -> None:
def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
super().__init__()
root = os.path.dirname(path)
encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
if os.path.isfile(encoded_data_cache_path):
seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
if seq_len_ == seq_len:
self.data = data
self.attention_mask = attention_mask
return
raw_data = []
with open(path) as f:
for line in f.readlines():
raw_data.append(json.loads(line)['text'])
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.unk_token
encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
self.data = encoded_data['input_ids']
self.attention_mask = encoded_data['attention_mask']
torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
if path is not None:
root = os.path.dirname(path)
encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
if os.path.isfile(encoded_data_cache_path):
seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
if seq_len_ == seq_len:
self.data = data
self.attention_mask = attention_mask
return
raw_data = []
with open(path) as f:
for line in f.readlines():
raw_data.append(json.loads(line)['text'])
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.unk_token
encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
self.data = encoded_data['input_ids']
self.attention_mask = encoded_data['attention_mask']
else:
self.data = torch.randint(0, 50257, (10240, seq_len))
self.attention_mask = torch.ones_like(self.data)
def __len__(self):
return len(self.data)

View File

@ -1,2 +1,3 @@
export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch
DUMMY_DATA=--use_dummy_dataset
colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA

View File

@ -3,6 +3,7 @@ import os
import torch
import torch.nn as nn
from dataset.webtext import WebtextDataset
from titans.model.gpt import GPTLMLoss
import colossalai
@ -39,52 +40,16 @@ def main():
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
logger = get_dist_logger()
if not args.use_dummy_dataset:
data_path = os.environ['DATA']
logger.info(f'Build data loader from path {data_path}', ranks=[0])
from dataset.webtext import WebtextDataset
train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
train_dataloader = utils.get_dataloader(train_ds,
seed=42,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
shuffle=True,
drop_last=True)
else:
# build a dummy train_dataloader
logger.info('Build data loader using dummy data', ranks=[0])
data_path = None if args.use_dummy_dataset else os.environ['DATA']
logger.info(f'Build data loader from path {data_path}', ranks=[0])
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
# 10 iterations
input_ids, attn_mask = get_data(gpc.config.BATCH_SIZE * 10, gpc.config.SEQ_LEN, VOCAB_SIZE)
from torch.utils.data import DataLoader, Dataset
class TextSamplerDataset(Dataset):
def __init__(self, data, seq_len):
super().__init__()
self.data = data
self.seq_len = seq_len
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
return full_seq.cuda()
def __len__(self):
return self.data.size(0) // self.seq_len
def cycle(loader):
while True:
for data in loader:
yield data
train_dataset = TextSamplerDataset(input_ids, gpc.config.SEQ_LEN)
train_dataloader = DataLoader(train_dataset, batch_size=gpc.config.BATCH_SIZE)
train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
train_dataloader = utils.get_dataloader(train_ds,
seed=42,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
shuffle=True,
drop_last=True)
logger.info('Build model', ranks=[0])
use_pipeline = is_using_pp()