# ColossalAI/examples/language/roberta/pretraining/run_pretraining.py
#
# 268 lines
# 11 KiB
# Python

import math
import os
import time
from functools import partial
import torch
from tqdm import tqdm
import os
import time
from functools import partial
from transformers import AutoTokenizer
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
from tqdm import tqdm
from transformers import AutoTokenizer
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger
# ZeRO stage handed to ``zero_model_wrapper`` for each supported ColossalAI plan.
_ZERO_STAGE_BY_DISTPLAN = {"CAI_ZeRO1": 1, "CAI_ZeRO2": 2, "CAI_Gemini": 3}

# Number of samples used to derive the LR schedule length; described by the
# original author as "the length of the entire dataset".
_TOTAL_TRAIN_SAMPLES = 144003367


def _build_model_and_optimizer(args, logger, world_size):
    """Build the model and its optimizer according to ``args.distplan``.

    For plans starting with ``"CAI"`` the model is created inside a
    ``ColoInitContext`` and both model and optimizer are wrapped with the
    ColossalAI ZeRO wrappers. Any other plan builds the model and optimizer
    without wrapping.

    Args:
        args: Parsed command-line namespace (reads ``distplan``, ``shardinit``,
            ``tp_degree``, ``placement`` and ``lr``).
        logger: Logger forwarded to ``get_model`` and used for status messages.
        world_size: Number of processes in the default process group.

    Returns:
        tuple: ``(config, model, numel, optimizer)``.

    Raises:
        RuntimeError: If ``args.shardinit`` is set for a plan other than
            ``CAI_Gemini``, or if a ``CAI*`` plan is not a supported one.
    """
    distplan = args.distplan

    if not distplan.startswith("CAI"):
        config, model, numel = get_model(args, logger)
        logger.info("no_zero")
        # Previously no optimizer was created on this path, so building the LR
        # scheduler afterwards always failed with UnboundLocalError.
        # NOTE(review): this path was unreachable before and is unverified --
        # confirm that get_model places the model on the GPU before relying on it.
        optimizer = get_optimizer(model, lr=args.lr)
        return config, model, numel, optimizer

    # Validate the plan before the (expensive) model construction.
    if args.shardinit and distplan != "CAI_Gemini":
        raise RuntimeError("You can only use shardinit with CAI_Gemini")
    if distplan not in _ZERO_STAGE_BY_DISTPLAN:
        raise RuntimeError(
            f"Unsupported distplan {distplan!r}; expected one of {sorted(_ZERO_STAGE_BY_DISTPLAN)}")
    zero_stage = _ZERO_STAGE_BY_DISTPLAN[distplan]

    # All params must use the same process group.
    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

    # Build the model in half precision on the current device.
    with ColoInitContext(device=get_current_device(),
                         dtype=torch.half,
                         default_dist_spec=default_dist_spec,
                         default_pg=shard_pg):
        config, model, numel = get_model(args, logger)

    # Assign running configurations.
    if distplan == "CAI_Gemini":
        gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
                             device=get_current_device(),
                             placement_policy=args.placement,
                             pin_memory=True,
                             hidden_dim=model.config.hidden_size,
                             search_range_mb=128)
        optim_config = dict(gpu_margin_mem_ratio=0.)
    else:
        gemini_config = None
        optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)

    # Build the optimizer on the unwrapped model, then wrap model and optimizer.
    optimizer = get_optimizer(model, lr=args.lr)
    model = zero_model_wrapper(model, zero_stage, gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

    logger.info(get_mem_info(prefix='After init optim, '))
    return config, model, numel, optimizer


def _resume_training_state(args, optimizer, total_steps):
    """Restore optimizer and LR-scheduler state from ``args.load_optimizer_lr``.

    Args:
        args: Parsed command-line namespace (reads ``load_optimizer_lr``).
        optimizer: Optimizer whose state is overwritten in place.
        total_steps: Total number of scheduler steps for the whole run.

    Returns:
        tuple: ``(lr_scheduler, start_epoch, start_shard)`` where ``start_shard``
        is the shard after the one recorded in the checkpoint.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    ckpt_file = args.load_optimizer_lr
    if not os.path.exists(ckpt_file):
        raise FileNotFoundError(f"optimizer/lr_scheduler checkpoint not found: {ckpt_file}")

    # torch.load unpickles the file: only resume from checkpoints you trust.
    saved = torch.load(ckpt_file, map_location='cpu')
    saved['lr_scheduler']['last_epoch'] -= 1
    optimizer.load_state_dict(saved['optimizer'])
    lr_scheduler = get_lr_scheduler(optimizer,
                                    total_steps=total_steps,
                                    last_epoch=saved['lr_scheduler']['last_epoch'])

    # The state was loaded onto the CPU; move its tensors to the current GPU.
    # The original author notes that without this the model itself would have
    # to be moved to the GPU because of optimizer.step().
    cuda_device = f"cuda:{torch.cuda.current_device()}"
    for param_state in optimizer.state.values():
        for key, value in param_state.items():
            if isinstance(value, torch.Tensor):
                param_state[key] = value.cuda(cuda_device)

    lr_scheduler.load_state_dict(saved['lr_scheduler'])
    return lr_scheduler, saved['epoch'], saved['shard'] + 1


def main():
    """Run masked-language-model pretraining from the command-line arguments.

    Launches the ColossalAI distributed context, builds the model, optimizer,
    LR scheduler and loss, optionally resumes from a saved optimizer/scheduler
    checkpoint, then trains shard by shard. After every shard the model is
    evaluated and a checkpoint is saved; after every epoch the mean shard
    evaluation loss is logged.

    Raises:
        RuntimeError: If ``args.distplan`` / ``args.shardinit`` are inconsistent.
        FileNotFoundError: If resuming is requested but the checkpoint is missing.
    """
    args = parse_args()
    launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)

    if args.vscode_debug:
        colossalai.launch(config={},
                          rank=args.rank,
                          world_size=args.world_size,
                          host=args.host,
                          port=args.port,
                          backend=args.backend)
        args.local_rank = -1
        args.log_interval = 1
    else:
        colossalai.launch_from_torch(config={})
        args.local_rank = int(os.environ["LOCAL_RANK"])
        logger.info(
            f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
            f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}'
        )

    log_args(logger, args)
    args.tokenizer = tokenizer
    args.logger = logger
    set_global_variables(launch_time, args.tensorboard_path)

    world_size = torch.distributed.get_world_size()
    is_main_process = torch.distributed.get_rank() == 0

    # Build model and optimizer.
    config, model, numel, optimizer = _build_model_and_optimizer(args, logger, world_size)

    if is_main_process:
        # makedirs also creates a missing args.ckpt_path parent directory.
        os.makedirs(os.path.join(args.ckpt_path, launch_time), exist_ok=True)

    logger.info(f'Model numel: {numel}')

    get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)

    steps_per_epoch = (_TOTAL_TRAIN_SAMPLES // world_size // args.train_micro_batch_size_per_gpu //
                       args.gradient_accumulation_steps // args.refresh_bucket_size)
    total_steps = steps_per_epoch * args.epoch

    lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)

    start_epoch = 0
    start_shard = 0
    # NOTE: global_step is not restored on resume; it restarts from 0.
    global_step = 0
    if args.resume_train:
        lr_scheduler, start_epoch, start_shard = _resume_training_state(args, optimizer, total_steps)
        logger.info(
            f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
        )

    criterion = LossForPretraining(config.vocab_size)

    # Build dataloader.
    pretrain_dataset_provider = NvidiaBertDatasetProvider(args)

    logger.info(get_mem_info(prefix='After init model, '))

    eval_loss = 0
    train_loss = 0
    timers = get_timers()
    timers('interval_time').start()
    timers('epoch_time').start()
    timers('shard_time').start()

    cuda_device = f"cuda:{torch.cuda.current_device()}"
    # The wrapped optimizers are driven through optimizer.backward(loss); fall
    # back to loss.backward() for an optimizer that does not provide it.
    optimizer_backward = getattr(optimizer, "backward", None)

    for epoch in range(start_epoch, args.epoch):
        num_shards = len(os.listdir(args.data_path_prefix))
        evaluated_shards = 0

        for shard in range(start_shard, num_shards):
            dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
            # Prefetching the next shard here was disabled upstream because it
            # "may cause cpu memory overload".
            if is_main_process:
                batches = tqdm(dataset_iterator,
                               total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
                               colour='cyan',
                               smoothing=1)
            else:
                batches = dataset_iterator

            model.train()

            for batch_data in batches:
                input_ids = batch_data[0].cuda(cuda_device)
                attention_mask = batch_data[1].cuda(cuda_device)
                token_type_ids = batch_data[2].cuda(cuda_device)
                mlm_label = batch_data[3].cuda(cuda_device)

                output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

                loss = criterion(output.logits, mlm_label)
                pretrain_dataset_provider.prefetch_batch()

                if optimizer_backward is not None:
                    optimizer_backward(loss)
                else:
                    loss.backward()
                train_loss += loss.float().item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                if global_step % args.log_interval == 0:
                    if is_main_process:
                        # interval_time is never reset, so this is the time since
                        # training started and the per-iteration value is a running mean.
                        elapsed_time = timers('interval_time').elapsed(reset=False)
                        elapsed_time_per_iteration = elapsed_time / global_step
                        _, tflops, _ = throughput_calculator(numel, args, config, elapsed_time, global_step,
                                                             world_size)

                        cur_loss = train_loss / args.log_interval
                        current_lr = lr_scheduler.get_last_lr()[0]
                        ppl = math.exp(cur_loss)
                        log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
                                  f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}'
                        logger.info(log_str, print_=False)

                        if args.wandb:
                            tensorboard_log = get_tensorboard_writer()
                            tensorboard_log.log_train(
                                {
                                    'lr': current_lr,
                                    'loss': cur_loss,
                                    'ppl': ppl,
                                    'mins_batch': elapsed_time_per_iteration
                                }, global_step)

                    # Start a fresh accumulation window on every rank.
                    train_loss = 0

            logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins')
            logger.info('*' * 100)

            eval_loss += evaluate(model, args, logger, global_step, criterion)
            evaluated_shards += 1
            save_ckpt(model, optimizer, lr_scheduler,
                      os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
                      shard, global_step)

        # Average over the shards evaluated in this epoch: a resumed epoch
        # starts at start_shard and therefore covers fewer than num_shards.
        eval_loss /= max(evaluated_shards, 1)
        logger.info(
            f'epoch {epoch} | shard_length {num_shards} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
            + f' | eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
        logger.info('-' * 100)
        if args.wandb and is_main_process:
            tensorboard_log = get_tensorboard_writer()
            tensorboard_log.log_eval({
                'all_eval_shard_loss': eval_loss,
            }, epoch)
        # Only the first (possibly resumed) epoch starts mid-way through the shards.
        start_shard = 0
        eval_loss = 0

    pretrain_dataset_provider.release_shard()

    logger.info('Congratulation, training has finished!!!')
# Script entry point: start pretraining only when this file is executed directly.
if __name__ == "__main__":
    main()