|
|
|
@ -1,21 +1,22 @@
|
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from data import build_train_valid_test_data_iterators
|
|
|
|
|
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
|
|
|
|
|
from data.tokenizer import get_padded_vocab_size, initialize_tokenizer
|
|
|
|
|
from loss_func.bert_loss import BertLoss
|
|
|
|
|
from lr_scheduler import AnnealingLR
|
|
|
|
|
from model.bert import BertForPretrain, build_pipeline_bert
|
|
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
|
from colossalai.amp import AMP_TYPE
|
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
from data import build_train_valid_test_data_iterators
|
|
|
|
|
from data.tokenizer import initialize_tokenizer, get_padded_vocab_size
|
|
|
|
|
from data.bert_helper import get_batch_for_sequence_parallel, SequenceParallelDataIterator
|
|
|
|
|
from colossalai.amp import AMP_TYPE
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.utils import MultiTimer, is_using_pp
|
|
|
|
|
from model.bert import BertForPretrain
|
|
|
|
|
from lr_scheduler import AnnealingLR
|
|
|
|
|
from loss_func.bert_loss import BertLoss
|
|
|
|
|
import torch
|
|
|
|
|
from colossalai.engine.schedule import PipelineSchedule
|
|
|
|
|
from colossalai.amp import AMP_TYPE
|
|
|
|
|
from colossalai.nn.optimizer import FusedAdam
|
|
|
|
|
from colossalai.kernel import LayerNorm
|
|
|
|
|
from model.bert import build_pipeline_bert
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.nn.optimizer import FusedAdam
|
|
|
|
|
from colossalai.utils import MultiTimer, is_using_pp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_batch_data(batch_data):
|
|
|
|
@ -28,13 +29,21 @@ def process_batch_data(batch_data):
|
|
|
|
|
return data, label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
|
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
|
|
|
|
|
return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
# initialize
|
|
|
|
|
args = parse_args()
|
|
|
|
|
colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
|
|
|
|
|
|
|
|
|
|
logger = get_dist_logger()
|
|
|
|
|
|
|
|
|
|
# build dataloader
|
|
|
|
|
if not args.synthetic:
|
|
|
|
|
initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
|
|
|
|
|
VOCAB_SIZE = get_padded_vocab_size()
|
|
|
|
|
trainloader, validloader, testloader = build_train_valid_test_data_iterators(
|
|
|
|
@ -52,6 +61,17 @@ def main():
|
|
|
|
|
skip_warmup=True,
|
|
|
|
|
binary_head=False,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
from data.dummy_dataloader import DummyDataloader
|
|
|
|
|
|
|
|
|
|
BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
|
|
|
|
|
VOCAB_SIZE = 30528
|
|
|
|
|
trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
|
|
|
|
|
vocab_size=VOCAB_SIZE,
|
|
|
|
|
seq_length=gpc.config.SEQ_LENGTH)
|
|
|
|
|
validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
|
|
|
|
|
vocab_size=VOCAB_SIZE,
|
|
|
|
|
seq_length=gpc.config.SEQ_LENGTH)
|
|
|
|
|
|
|
|
|
|
logger.info("Dataloaders are built", ranks=[0])
|
|
|
|
|
|
|
|
|
@ -121,11 +141,7 @@ def main():
|
|
|
|
|
logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")
|
|
|
|
|
|
|
|
|
|
# # init
|
|
|
|
|
engine, *dummy = colossalai.initialize(
|
|
|
|
|
model,
|
|
|
|
|
optimizer,
|
|
|
|
|
criterion,
|
|
|
|
|
)
|
|
|
|
|
engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True)
|
|
|
|
|
|
|
|
|
|
# build timer
|
|
|
|
|
timer = MultiTimer()
|
|
|
|
@ -140,6 +156,8 @@ def main():
|
|
|
|
|
train_data_iter = SequenceParallelDataIterator(trainloader)
|
|
|
|
|
valid_data_iter = SequenceParallelDataIterator(validloader)
|
|
|
|
|
|
|
|
|
|
logger.info("start training")
|
|
|
|
|
|
|
|
|
|
for step in range(1, gpc.config.TRAIN_ITERS + 1):
|
|
|
|
|
timer.start('train-iterations')
|
|
|
|
|
engine.train()
|
|
|
|
|