diff --git a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py index 8ef81cb0a..7bf533039 100644 --- a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py +++ b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py @@ -12,11 +12,11 @@ TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) # if you do no want zero, just comment out this dictionary zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), - optimizer_config=dict(initial_scale=2**16)) + optimizer_config=dict(initial_scale=2**5)) optimizer = dict( type=HybridAdam, - lr=0.00015, + lr=0.000015, weight_decay=1e-2, ) diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py new file mode 100644 index 000000000..64f5944a9 --- /dev/null +++ b/examples/language/gpt/titans/dataset/webtext.py @@ -0,0 +1,43 @@ +import json +import os +from typing import Optional + +import torch +from torch.utils.data import Dataset +from transformers import GPT2Tokenizer + +from colossalai.registry import DATASETS + + +@DATASETS.register_module +class WebtextDataset(Dataset): + + def __init__(self, path: Optional[str] = None, seq_len=1024) -> None: + super().__init__() + if path is not None: + root = os.path.dirname(path) + encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') + if os.path.isfile(encoded_data_cache_path): + seq_len_, data, attention_mask = torch.load(encoded_data_cache_path) + if seq_len_ == seq_len: + self.data = data + self.attention_mask = attention_mask + return + raw_data = [] + with open(path) as f: + for line in f.readlines(): + raw_data.append(json.loads(line)['text']) + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.unk_token + encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') + self.data = encoded_data['input_ids'] + self.attention_mask = encoded_data['attention_mask'] + else: + self.data = torch.randint(0, 50257, (10240, seq_len)) + self.attention_mask = torch.ones_like(self.data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index] diff --git a/examples/language/gpt/titans/run.sh b/examples/language/gpt/titans/run.sh index 157bd377a..a1a7fc737 100644 --- a/examples/language/gpt/titans/run.sh +++ b/examples/language/gpt/titans/run.sh @@ -1,2 +1,3 @@ export DATA=/data/scratch/gpt_data/small-gpt-dataset.json -colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch +DUMMY_DATA=--use_dummy_dataset +colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py index 1380b4b3a..66225d6c8 100644 --- a/examples/language/gpt/titans/train_gpt.py +++ b/examples/language/gpt/titans/train_gpt.py @@ -3,6 +3,7 @@ import os import torch import torch.nn as nn +from dataset.webtext import WebtextDataset from titans.model.gpt import GPTLMLoss import colossalai @@ -30,7 +31,7 @@ VOCAB_SIZE = 50257 def main(): parser = colossalai.get_default_parser() parser.add_argument('--from_torch', default=False, action='store_true') - parser.add_argument('--use_dummy_dataset', default=True, action='store_true') + parser.add_argument('--use_dummy_dataset', default=False, action='store_true') args = parser.parse_args() disable_existing_loggers() if args.from_torch: @@ -39,52 +40,16 @@ def main(): colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42) logger = get_dist_logger() - if not args.use_dummy_dataset: - data_path = os.environ['DATA'] - logger.info(f'Build data loader from path {data_path}', ranks=[0]) - from dataset.webtext import WebtextDataset - train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN) - train_dataloader = utils.get_dataloader(train_ds, - seed=42, - batch_size=gpc.config.BATCH_SIZE, - pin_memory=True, - shuffle=True, - drop_last=True) - else: - # build a dummy train_dataloader - logger.info('Build data loader using dummy data', ranks=[0]) + data_path = None if args.use_dummy_dataset else os.environ['DATA'] + logger.info(f'Build data loader from path {data_path}', ranks=[0]) - def get_data(batch_size, seq_len, vocab_size): - input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) - attention_mask = torch.ones_like(input_ids) - return input_ids, attention_mask - - # 10 iterations - input_ids, attn_mask = get_data(gpc.config.BATCH_SIZE * 10, gpc.config.SEQ_LEN, VOCAB_SIZE) - from torch.utils.data import DataLoader, Dataset - - class TextSamplerDataset(Dataset): - - def __init__(self, data, seq_len): - super().__init__() - self.data = data - self.seq_len = seq_len - - def __getitem__(self, index): - rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) - full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long() - return full_seq.cuda() - - def __len__(self): - return self.data.size(0) // self.seq_len - - def cycle(loader): - while True: - for data in loader: - yield data - - train_dataset = TextSamplerDataset(input_ids, gpc.config.SEQ_LEN) - train_dataloader = DataLoader(train_dataset, batch_size=gpc.config.BATCH_SIZE) + train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN) + train_dataloader = utils.get_dataloader(train_ds, + seed=42, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + shuffle=True, + drop_last=True) logger.info('Build model', ranks=[0]) use_pipeline = is_using_pp()