import math
import os
import time
from functools import partial

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec

from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger


def main():
    args = parse_args()
    launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)

    if args.vscode_debug:
        colossalai.launch(config={},
                          rank=args.rank,
                          world_size=args.world_size,
                          host=args.host,
                          port=args.port,
                          backend=args.backend)
        args.local_rank = -1
        args.log_interval = 1
    else:
        colossalai.launch_from_torch(config={})    # args.colossal_config
        args.local_rank = int(os.environ["LOCAL_RANK"])
        logger.info(
            f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
            f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}'
        )

    log_args(logger, args)
    args.tokenizer = tokenizer
    args.logger = logger
    set_global_variables(launch_time, args.tensorboard_path)

    world_size = torch.distributed.get_world_size()
    init_dev = get_current_device()

    # build model, optimizer and criterion
    if args.distplan.startswith("CAI"):
        # all parameters must use the same process group
        world_size = torch.distributed.get_world_size()
        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

        if args.shardinit and args.distplan != "CAI_Gemini":
            raise RuntimeError("You can only use shardinit with CAI_Gemini")

        # build the BERT model
        with ColoInitContext(device=get_current_device(),
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            config, model, numel = get_model(args, logger)

        # assign running configurations
        gemini_config = None
        if args.distplan.startswith("CAI_ZeRO"):
            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
        elif args.distplan == "CAI_Gemini":
            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
                                 device=get_current_device(),
                                 placement_policy=args.placement,
                                 pin_memory=True,
                                 hidden_dim=model.config.hidden_size,
                                 search_range_mb=128)
            optim_config = dict(gpu_margin_mem_ratio=0.)
        else:
            raise RuntimeError(f"Unsupported distplan: {args.distplan}")
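        # The chosen distplan is mapped onto a ZeRO stage below (CAI_ZeRO1 -> 1,
        # CAI_ZeRO2 -> 2, CAI_Gemini -> 3), and the model and optimizer are wrapped
        # with zero_model_wrapper / zero_optim_wrapper using the configs assembled above.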
        # build a highly optimized gpu/cpu optimizer
        optimizer = get_optimizer(model, lr=args.lr)

        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError(f"Unsupported distplan: {args.distplan}")

        # wrap your model and optimizer
        model = zero_model_wrapper(model, zero_stage, gemini_config)
        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

        logger.info(get_mem_info(prefix='After init optim, '))
    else:
        config, model, numel = get_model(args, logger)
        logger.info("no_zero")

    if torch.distributed.get_rank() == 0:
        os.mkdir(os.path.join(args.ckpt_path, launch_time))

    logger.info(f'Model numel: {numel}')

    get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)

    # 144003367 is the length of the entire dataset
    steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu \
        // args.gradient_accumulation_steps // args.refresh_bucket_size    # len(dataloader)
    total_steps = steps_per_epoch * args.epoch

    lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)

    start_epoch = 0
    start_shard = 0
    global_step = 0
    if args.resume_train:
        assert os.path.exists(args.load_optimizer_lr)
        o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu')
        o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1
        optimizer.load_state_dict(o_l_state_dict['optimizer'])
        # o_l_state_dict['lr_scheduler']['last_epoch']
        lr_scheduler = get_lr_scheduler(optimizer,
                                        total_steps=total_steps,
                                        last_epoch=o_l_state_dict['lr_scheduler']['last_epoch'])
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}")
        # If you remove the loop above, you must move the model to the GPU instead,
        # because optimizer.step() expects parameters and optimizer states on the same device.
        lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler'])

        start_epoch = o_l_state_dict['epoch']
        start_shard = o_l_state_dict['shard'] + 1
        # global_step = o_l_state_dict['global_step'] + 1
        logger.info(
            f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
        )
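    # Training consumes the dataset shard by shard: each epoch walks over every shard
    # file under args.data_path_prefix, runs an evaluation pass after each shard, and
    # saves a checkpoint named after the epoch, shard and launch time.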
    criterion = LossForPretraining(config.vocab_size)

    # build dataloader
    pretrain_dataset_provider = NvidiaBertDatasetProvider(args)

    logger.info(get_mem_info(prefix='After init model, '))

    best_loss = None
    eval_loss = 0
    train_loss = 0
    timers = get_timers()
    timers('interval_time').start()
    timers('epoch_time').start()
    timers('shard_time').start()

    for epoch in range(start_epoch, args.epoch):

        for shard in range(start_shard, len(os.listdir(args.data_path_prefix))):

            dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
            # pretrain_dataset_provider.prefetch_shard(shard + 1)    # may cause cpu memory overload
            if torch.distributed.get_rank() == 0:
                iterator_data = tqdm(enumerate(dataset_iterator),
                                     total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
                                     colour='cyan',
                                     smoothing=1)
            else:
                iterator_data = enumerate(dataset_iterator)

            model.train()

            for step, batch_data in iterator_data:

                # batch_data = pretrain_dataset_provider.get_batch(batch_index)
                input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
                attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}")
                token_type_ids = batch_data[2].cuda(f"cuda:{torch.cuda.current_device()}")
                mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}")
                # nsp_label = batch_data[5].cuda()

                output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

                loss = criterion(output.logits, mlm_label)
                pretrain_dataset_provider.prefetch_batch()

                optimizer.backward(loss)
                train_loss += loss.float().item()

                # if (step + 1) % args.accumulation_step == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                if global_step % args.log_interval == 0 and global_step != 0 \
                        and torch.distributed.get_rank() == 0:
                    elapsed_time = timers('interval_time').elapsed(reset=False)
                    elapsed_time_per_iteration = elapsed_time / global_step
                    samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(
                        numel, args, config, elapsed_time, global_step, world_size)

                    cur_loss = train_loss / args.log_interval
                    current_lr = lr_scheduler.get_last_lr()[0]
                    log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
                              f'| secs/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}'
                    logger.info(log_str, print_=False)

                    if args.wandb:
                        tensorboard_log = get_tensorboard_writer()
                        tensorboard_log.log_train(
                            {
                                'lr': current_lr,
                                'loss': cur_loss,
                                'ppl': math.exp(cur_loss),
                                'mins_batch': elapsed_time_per_iteration
                            }, global_step)
                    train_loss = 0

            logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins')
            logger.info('*' * 100)

            eval_loss += evaluate(model, args, logger, global_step, criterion)
            save_ckpt(model, optimizer, lr_scheduler,
                      os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
                      shard, global_step)

        eval_loss /= len(os.listdir(args.data_path_prefix))
        logger.info(
            f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | ' +
            f'elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins | ' +
            f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
        logger.info('-' * 100)
        if args.wandb and torch.distributed.get_rank() == 0:
            tensorboard_log = get_tensorboard_writer()
            tensorboard_log.log_eval({
                'all_eval_shard_loss': eval_loss,
            }, epoch)
        start_shard = 0
        eval_loss = 0

    pretrain_dataset_provider.release_shard()

    logger.info('Congratulations, training has finished!!!')


if __name__ == '__main__':
    main()
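# A minimal launch sketch, not part of the original script: colossalai.launch_from_torch
# reads the usual torch.distributed environment variables, so the script can be started
# with torchrun. The file name and flag spellings below are assumptions inferred from the
# args.* attributes used above; check arguments.parse_args for the exact option names.
#
#   torchrun --nproc_per_node 8 run_pretraining.py \
#       --distplan CAI_Gemini \
#       --tokenizer_path /path/to/tokenizer \
#       --data_path_prefix /path/to/dataset/shards \
#       --ckpt_path /path/to/checkpoints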