@@ -1,21 +1,22 @@
import argparse

import torch
from data import build_train_valid_test_data_iterators
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
from data.tokenizer import get_padded_vocab_size, initialize_tokenizer
from loss_func.bert_loss import BertLoss
from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert

import colossalai
from colossalai.amp import AMP_TYPE
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from data import build_train_valid_test_data_iterators
from data.tokenizer import initialize_tokenizer, get_padded_vocab_size
from data.bert_helper import get_batch_for_sequence_parallel, SequenceParallelDataIterator
from colossalai.amp import AMP_TYPE
from colossalai.logging import get_dist_logger
from colossalai.utils import MultiTimer, is_using_pp
from model.bert import BertForPretrain
from lr_scheduler import AnnealingLR
from loss_func.bert_loss import BertLoss
import torch
from colossalai.engine.schedule import PipelineSchedule
from colossalai.amp import AMP_TYPE
from colossalai.nn.optimizer import FusedAdam
from colossalai.kernel import LayerNorm
from model.bert import build_pipeline_bert
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import FusedAdam
from colossalai.utils import MultiTimer, is_using_pp


def process_batch_data(batch_data):
@@ -28,30 +29,49 @@ def process_batch_data(batch_data):
    return data, label


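# Command-line switch between the real preprocessed dataset and synthetic data.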
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--synthetic', action="store_true", help="whether to use synthetic data")
    return parser.parse_args()


def main():
    # initialize
    args = parse_args()
    colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')

    logger = get_dist_logger()

    # build dataloader
    initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
    VOCAB_SIZE = get_padded_vocab_size()
    trainloader, validloader, testloader = build_train_valid_test_data_iterators(
        train_iters=gpc.config.TRAIN_ITERS,
        global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
        eval_interval=gpc.config.EVAL_INTERVAL,
        eval_iters=gpc.config.EVAL_ITERS,
        data_prefix=[gpc.config.DATA_PATH],
        data_impl='mmap',
        splits_string='949,50,1',
        max_seq_length=gpc.config.SEQ_LENGTH,
        masked_lm_prob=0.15,
        short_seq_prob=0.1,
        seed=1234,
        skip_warmup=True,
        binary_head=False,
    )
    if not args.synthetic:
        initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
        VOCAB_SIZE = get_padded_vocab_size()
        trainloader, validloader, testloader = build_train_valid_test_data_iterators(
            train_iters=gpc.config.TRAIN_ITERS,
            global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
            eval_interval=gpc.config.EVAL_INTERVAL,
            eval_iters=gpc.config.EVAL_ITERS,
            data_prefix=[gpc.config.DATA_PATH],
            data_impl='mmap',
            splits_string='949,50,1',
            max_seq_length=gpc.config.SEQ_LENGTH,
            masked_lm_prob=0.15,
            short_seq_prob=0.1,
            seed=1234,
            skip_warmup=True,
            binary_head=False,
        )
    else:
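        # Synthetic mode: no preprocessed corpus is needed. Each data-parallel rank
        # builds dummy loaders (presumably yielding random token ids) sized to its
        # share of the global batch.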
        from data.dummy_dataloader import DummyDataloader

        BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
        VOCAB_SIZE = 30528
        trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                      vocab_size=VOCAB_SIZE,
                                      seq_length=gpc.config.SEQ_LENGTH)
        validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                      vocab_size=VOCAB_SIZE,
                                      seq_length=gpc.config.SEQ_LENGTH)

    logger.info("Dataloaders are built", ranks=[0])
@@ -121,11 +141,7 @@ def main():
    logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")

    # # init
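    # colossalai.initialize wraps the model, optimizer and criterion into an
    # engine driven by the settings in config.py (e.g. the AMP mode).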
    engine, *dummy = colossalai.initialize(
        model,
        optimizer,
        criterion,
    )
    engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True)

    # build timer
    timer = MultiTimer()
@@ -140,6 +156,8 @@ def main():
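    # Iterator wrappers the engine/schedule can consume batch by batch; the
    # sequence-dimension splitting itself presumably happens in
    # get_batch_for_sequence_parallel.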
    train_data_iter = SequenceParallelDataIterator(trainloader)
    valid_data_iter = SequenceParallelDataIterator(validloader)

    logger.info("start training")

    for step in range(1, gpc.config.TRAIN_ITERS + 1):
        timer.start('train-iterations')
        engine.train()