mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] gpt example titans bug #2493
parent 8208fd023a
commit a4b75b78a0
@@ -12,11 +12,11 @@ TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
 
 # if you do no want zero, just comment out this dictionary
 zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
-            optimizer_config=dict(initial_scale=2**16))
+            optimizer_config=dict(initial_scale=2**5))
 
 optimizer = dict(
     type=HybridAdam,
-    lr=0.00015,
+    lr=0.000015,
     weight_decay=1e-2,
 )
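The hunk above lowers the ZeRO optimizer's initial loss scale from 2**16 to 2**5 and the learning rate from 1.5e-4 to 1.5e-5. As a rough illustration of why a smaller initial scale helps, here is a generic sketch of dynamic loss scaling in mixed-precision training; it is not ColossalAI's implementation, and scaled_step is a hypothetical helper:

# Generic sketch of dynamic loss scaling (not ColossalAI's ZeRO optimizer code).
# Gradients are computed on loss * scale; if any gradient overflows, the update
# is skipped and the scale is halved. A very large initial scale such as 2**16
# can waste many early steps on skipped updates, which is why the hotfix starts
# from 2**5 instead.
import torch


def scaled_step(model, optimizer, loss, scale):
    optimizer.zero_grad()
    (loss * scale).backward()
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if any(not torch.isfinite(g).all() for g in grads):
        return scale / 2    # overflow: skip this update and back off the scale
    for g in grads:
        g.div_(scale)       # unscale gradients before the real optimizer step
    optimizer.step()
    return scale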
@@ -0,0 +1,39 @@
+import json
+import os
+
+import torch
+from torch.utils.data import Dataset
+from transformers import GPT2Tokenizer
+
+from colossalai.registry import DATASETS
+
+
+@DATASETS.register_module
+class WebtextDataset(Dataset):
+
+    def __init__(self, path, seq_len=1024) -> None:
+        super().__init__()
+        root = os.path.dirname(path)
+        encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
+        if os.path.isfile(encoded_data_cache_path):
+            seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
+            if seq_len_ == seq_len:
+                self.data = data
+                self.attention_mask = attention_mask
+                return
+        raw_data = []
+        with open(path) as f:
+            for line in f.readlines():
+                raw_data.append(json.loads(line)['text'])
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        tokenizer.pad_token = tokenizer.unk_token
+        encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
+        self.data = encoded_data['input_ids']
+        self.attention_mask = encoded_data['attention_mask']
+        torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]
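For reference, a minimal usage sketch of the new WebtextDataset (the jsonl path below is hypothetical; each line of the file is expected to be a JSON object with a 'text' field, since the loader reads json.loads(line)['text']):

from torch.utils.data import DataLoader

# Hypothetical path to an OpenWebText-style jsonl file.
dataset = WebtextDataset('/data/webtext/train.jsonl', seq_len=1024)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Each item is ({'input_ids': ..., 'attention_mask': ...}, labels), where the
# labels are the input_ids themselves (causal language modeling).
inputs, labels = next(iter(loader))
print(inputs['input_ids'].shape, inputs['attention_mask'].shape, labels.shape)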
@@ -30,7 +30,7 @@ VOCAB_SIZE = 50257
 def main():
     parser = colossalai.get_default_parser()
     parser.add_argument('--from_torch', default=False, action='store_true')
-    parser.add_argument('--use_dummy_dataset', default=True, action='store_true')
+    parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
     args = parser.parse_args()
     disable_existing_loggers()
     if args.from_torch:
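With action='store_true', the old default=True meant --use_dummy_dataset could never be turned off from the command line, so the real webtext path was unreachable; the hotfix makes the dummy dataset opt-in. How the flag is consumed is not shown in this hunk, so the sketch below is only an assumption about that selection logic (build_dataset and the DATA environment variable are hypothetical names):

import os

import torch
from torch.utils.data import TensorDataset


def build_dataset(args, seq_len=1024, vocab_size=50257):
    if args.use_dummy_dataset:
        # Synthetic random tokens: enough to smoke-test the pipeline without data.
        input_ids = torch.randint(0, vocab_size, (1024, seq_len))
        return TensorDataset(input_ids, input_ids)
    # Real data path; the DATA environment variable is assumed for illustration.
    return WebtextDataset(os.environ['DATA'], seq_len=seq_len)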