mirror of https://github.com/hpcaitech/ColossalAI
parent
d565a24849
commit
e327e95144
|
@ -12,11 +12,11 @@ TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
|
||||||
|
|
||||||
# if you do no want zero, just comment out this dictionary
|
# if you do no want zero, just comment out this dictionary
|
||||||
zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
|
zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
|
||||||
optimizer_config=dict(initial_scale=2**16))
|
optimizer_config=dict(initial_scale=2**5))
|
||||||
|
|
||||||
optimizer = dict(
|
optimizer = dict(
|
||||||
type=HybridAdam,
|
type=HybridAdam,
|
||||||
lr=0.00015,
|
lr=0.000015,
|
||||||
weight_decay=1e-2,
|
weight_decay=1e-2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from transformers import GPT2Tokenizer
|
||||||
|
|
||||||
|
from colossalai.registry import DATASETS
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module
|
||||||
|
class WebtextDataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if path is not None:
|
||||||
|
root = os.path.dirname(path)
|
||||||
|
encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
|
||||||
|
if os.path.isfile(encoded_data_cache_path):
|
||||||
|
seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
|
||||||
|
if seq_len_ == seq_len:
|
||||||
|
self.data = data
|
||||||
|
self.attention_mask = attention_mask
|
||||||
|
return
|
||||||
|
raw_data = []
|
||||||
|
with open(path) as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
raw_data.append(json.loads(line)['text'])
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||||
|
tokenizer.pad_token = tokenizer.unk_token
|
||||||
|
encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
|
||||||
|
self.data = encoded_data['input_ids']
|
||||||
|
self.attention_mask = encoded_data['attention_mask']
|
||||||
|
else:
|
||||||
|
self.data = torch.randint(0, 50257, (10240, seq_len))
|
||||||
|
self.attention_mask = torch.ones_like(self.data)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]
|
|
@ -1,2 +1,3 @@
|
||||||
export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
|
export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
|
||||||
colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch
|
DUMMY_DATA=--use_dummy_dataset
|
||||||
|
colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA
|
||||||
|
|
|
@ -3,6 +3,7 @@ import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from dataset.webtext import WebtextDataset
|
||||||
from titans.model.gpt import GPTLMLoss
|
from titans.model.gpt import GPTLMLoss
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
@ -30,7 +31,7 @@ VOCAB_SIZE = 50257
|
||||||
def main():
|
def main():
|
||||||
parser = colossalai.get_default_parser()
|
parser = colossalai.get_default_parser()
|
||||||
parser.add_argument('--from_torch', default=False, action='store_true')
|
parser.add_argument('--from_torch', default=False, action='store_true')
|
||||||
parser.add_argument('--use_dummy_dataset', default=True, action='store_true')
|
parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
if args.from_torch:
|
if args.from_torch:
|
||||||
|
@ -39,52 +40,16 @@ def main():
|
||||||
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
|
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|
||||||
if not args.use_dummy_dataset:
|
data_path = None if args.use_dummy_dataset else os.environ['DATA']
|
||||||
data_path = os.environ['DATA']
|
logger.info(f'Build data loader from path {data_path}', ranks=[0])
|
||||||
logger.info(f'Build data loader from path {data_path}', ranks=[0])
|
|
||||||
from dataset.webtext import WebtextDataset
|
|
||||||
train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
|
|
||||||
train_dataloader = utils.get_dataloader(train_ds,
|
|
||||||
seed=42,
|
|
||||||
batch_size=gpc.config.BATCH_SIZE,
|
|
||||||
pin_memory=True,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=True)
|
|
||||||
else:
|
|
||||||
# build a dummy train_dataloader
|
|
||||||
logger.info('Build data loader using dummy data', ranks=[0])
|
|
||||||
|
|
||||||
def get_data(batch_size, seq_len, vocab_size):
|
train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
|
||||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
train_dataloader = utils.get_dataloader(train_ds,
|
||||||
attention_mask = torch.ones_like(input_ids)
|
seed=42,
|
||||||
return input_ids, attention_mask
|
batch_size=gpc.config.BATCH_SIZE,
|
||||||
|
pin_memory=True,
|
||||||
# 10 iterations
|
shuffle=True,
|
||||||
input_ids, attn_mask = get_data(gpc.config.BATCH_SIZE * 10, gpc.config.SEQ_LEN, VOCAB_SIZE)
|
drop_last=True)
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
|
|
||||||
class TextSamplerDataset(Dataset):
|
|
||||||
|
|
||||||
def __init__(self, data, seq_len):
|
|
||||||
super().__init__()
|
|
||||||
self.data = data
|
|
||||||
self.seq_len = seq_len
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
|
|
||||||
full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
|
|
||||||
return full_seq.cuda()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.data.size(0) // self.seq_len
|
|
||||||
|
|
||||||
def cycle(loader):
|
|
||||||
while True:
|
|
||||||
for data in loader:
|
|
||||||
yield data
|
|
||||||
|
|
||||||
train_dataset = TextSamplerDataset(input_ids, gpc.config.SEQ_LEN)
|
|
||||||
train_dataloader = DataLoader(train_dataset, batch_size=gpc.config.BATCH_SIZE)
|
|
||||||
|
|
||||||
logger.info('Build model', ranks=[0])
|
logger.info('Build model', ranks=[0])
|
||||||
use_pipeline = is_using_pp()
|
use_pipeline = is_using_pp()
|
||||||
|
|
Loading…
Reference in New Issue