#!/usr/bin/env python # -*- encoding: utf-8 -*- import socket import time import traceback from functools import partial from typing import Iterable import torch import torch.distributed as dist from torch import nn from torch.utils.data import DataLoader import internlm from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.naive_amp import NaiveAMPModel from internlm.core.trainer import TrainState from internlm.data.batch_sampler import StaticBatchSampler from internlm.data.collaters import packed_collate_fn from internlm.data.dummy_dataset import RandomDataset from internlm.data.packed_dataset import ( PackedDataset, PackedDatasetWithoutCuSeqlen, get_packed_dataset_without_short_length, ) from internlm.data.utils import DATASET_TYPE_IDS_MAP from internlm.model.loss import FlashGPTLMLoss from internlm.solver.beta2_scheduler import Beta2Scheduler from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer from internlm.utils.common import ( BatchSkipper, get_master_node, get_megatron_flops, get_process_rank, launch_time, parse_args, ) from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.model_checkpoint import ( load_context, load_model_checkpoint, load_optimizer_checkpoint, load_sampler, load_scheduler, save_checkpoint, ) from internlm.utils.parallel import ( is_no_pp_or_last_stage, sync_model_param, sync_model_param_within_tp, ) from internlm.utils.registry import MODEL_INITIALIZER # global llm logger logger = get_logger(__file__) def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024): """ Initialize distributed environment for distributed training. Args: config (str): Config file path. launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default. master_port (str): The master port for distributed training. 8888 by default. seed (int, optional): Specified random seed for every process. 1024 by default. """ torch.cuda.empty_cache() if launcher == "torch": internlm.launch_from_torch(config=config, seed=seed) elif launcher == "slurm": internlm.launch_from_slurm( config=config, host=get_master_node(), port=master_port, seed=seed, ) else: assert launcher in ["slurm", "torch"], "launcher only support slurm or torch" def initialize_model(): """ Initialize model. Returns: The neural network model to be trained or evaluated. """ assert ( not hasattr(gpc.config.parallel, "pipeline") or gpc.config.parallel.pipeline == 1 ), "Pipeline parallelism is not supported for now." model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model)) model = NaiveAMPModel( model=model, output_to_fp32=is_no_pp_or_last_stage(), dtype=gpc.config.model.get("dtype", torch.half), sync_buffer=False, ) # This sync is very important, cause the model weights kept in optimizer are copied # from the origin parameters in the memory, so we should make sure the dp sync # does not influence the model weights in optimizer be different with the origin parameters. sync_model_param(model, parallel_mode=ParallelMode.DATA) # This function is needed to make sure parameters that are not splitted by tensor parallelism are # the same across tensor parallelism. sync_model_param_within_tp(model) return model def get_train_data_loader(num_worker: int = 0): """ Generate and return the training data loader. Returns: A tuple of (train_dl, dataset_types). """ # Get the dataset types dataset_types = None dataset_types = list(DATASET_TYPE_IDS_MAP.keys()) data_cfg = gpc.config.data # Get the sample weight dictionary train_folder = data_cfg.train_folder if not train_folder: train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len) if data_cfg.pack_sample_into_one: train_ds = PackedDatasetWithoutCuSeqlen( train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length ) else: train_ds = PackedDataset( train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length ) else: train_ds = get_packed_dataset_without_short_length( folder=data_cfg.train_folder, packed_length=data_cfg.packed_length, max_length_per_sample=data_cfg.seq_len, show_progress=dist.get_rank() == 0, min_length=data_cfg.min_length, min_length_dict=data_cfg.get("min_length_dict", {}), pack_into_one_sample=data_cfg.pack_sample_into_one, ) # partition already completed # assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)) if isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)): datasets = [train_ds] else: datasets = train_ds.datasets # Create the training dataset sampler train_sampler = StaticBatchSampler( datasets, batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size, micro_bsz=data_cfg.micro_bsz, seed=1024, drop_last=True, data_rank=gpc.get_local_rank(ParallelMode.DATA), data_world_size=gpc.get_world_size(ParallelMode.DATA), ) train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length) # Create the training data loader train_dl = DataLoader( dataset=train_ds, batch_sampler=train_sampler, num_workers=num_worker, pin_memory=True, collate_fn=train_collate_fn, persistent_workers=True, ) return train_dl, dataset_types def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState): """ Load and return the new batch data based on training data loader. Args: train_dl (torch.utils.data.DataLoader): Dataloader for training. train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). train_state (TrainState): Current training state. Returns: A batch data and the updated train_iter. """ timer("batch-gen").start() try: batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor) next(train_state.batch_sampler_iter) except StopIteration: train_iter = iter(train_dl) batch = next(train_iter) train_state.batch_sampler_iter = iter(train_state.batch_sampler) next(train_state.batch_sampler_iter) train_state.num_consumed_samples_in_epoch = 0 timer("batch-gen").stop() batch[0].pop("type_ids", None) return batch, train_iter def initialize_optimizer(model: nn.Module): """ Initialize optimizer. Args: model (torch.nn.Module): Your model instance to be trained or evaluated. Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler). """ adam_cfg = gpc.config.adam naive_optimizer = torch.optim.AdamW( params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}], lr=adam_cfg.lr, betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2), eps=adam_cfg.adam_eps, ) optimizer = HybridZeroOptimizer( naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer ) beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler) return optimizer, beta2_scheduler, lr_scheduler def record_current_batch_training_metrics( get_tflops_func, logger, success_update, batch_count, batch, train_state, optimizer, beta2_scheduler, trainer, start_time, loss, grad_norm, ): """ Print some training metrics of current batch. """ if success_update in (0, True): train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) if success_update and gpc.is_rank_for_log(): lr = optimizer.param_groups[0]["lr"] if hasattr(trainer.engine.optimizer, "grad_scaler"): scaler = trainer.engine.optimizer.grad_scaler._scale.item() elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"): scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item() num_tokens_in_batch = batch[1].nelement() num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]]) max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) tk_per_gpu = 0 tk_per_gpu = round( num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL) / (time.time() - start_time), 2, ) tflops = get_tflops_func((time.time() - start_time)) infos = { "tflops": tflops, "step": batch_count, "loss": loss.item(), "tgs (tokens/gpu/second)": tk_per_gpu, "lr": lr, "loss_scale": scaler, "grad_norm": grad_norm, } infos["micro_num"] = len(batch[1]) infos["num_consumed_tokens"] = train_state.num_consumed_tokens infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples infos["largest_length"] = max_length_in_batch # the longest input infos["largest_batch"] = max_samples_in_batch # the batch with the most samples infos["smallest_batch"] = min_samples_in_batch infos["adam_beta2"] = beta2_scheduler.get_beta2() line = "" for k, v in infos.items(): line += f"{k}={v}," fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2) line += f"fwd_bwd_time={fwd_bwd_time}" logger.info(line) def main(args): # initialize distributed environment initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) assert hasattr(gpc, "config") and gpc.config is not None # init setting skip_batches = gpc.config.data.skip_batches total_steps = gpc.config.data.total_steps load_optimizer = gpc.config.ckpt.load_optimizer label_smoothing = gpc.config.loss.label_smoothing lr = gpc.config.adam.lr # ckpt setting save_ckpt_folder = gpc.config.ckpt.save_ckpt_folder enable_save_ckpt = gpc.config.ckpt.enable_ckpt checkpoint_every = gpc.config.ckpt.checkpoint_every load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None) load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None) get_tflops_func = partial( get_megatron_flops, checkpoint=gpc.config.model.checkpoint, seq_len=gpc.config.SEQ_LEN, hidden_size=gpc.config.model.hidden_size, num_layers=gpc.config.model.num_layers, vocab_size=gpc.config.model.vocab_size, global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA), global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), mlp_ratio=gpc.config.MLP_RATIO, ) # get and broadcast current time current_time = launch_time() objs = [current_time] dist.broadcast_object_list(objs, src=0) current_time = objs[0] model_load_path = None if load_resume_ckpt_folder is not None: logger.info( f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:" f"{socket.gethostname()}===========" ) model_load_path = load_resume_ckpt_folder elif load_model_only_folder is not None: logger.info( f"===========SFT training from `{load_model_only_folder}` {current_time} on host:" f"{socket.gethostname()}===========" ) model_load_path = load_model_only_folder else: logger.info( f"===========New Run {current_time} on host:{socket.gethostname()}," f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" ) # initialize and resume train state train_state = TrainState(gpc.config) # initialize model model = initialize_model() # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) # initialize the train data loader train_dl, _ = get_train_data_loader(num_worker=4) train_state.init_batch_sampler(train_dl) # Loading model weights must be done before zero is initialized. if model_load_path is not None: load_model_checkpoint(folder=model_load_path, model=model) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) # Loading other persistent training states. if load_resume_ckpt_folder is not None: # load lr scheduler states. load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state) # load training states. load_context(load_resume_ckpt_folder, train_dl, train_state) # load dataloader sampler states. load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler) # load optimzier states. if load_optimizer: load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer) # initialize trainer trainer, train_dl, _, _ = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, ) # initialize the batch skipper batch_skipper = BatchSkipper(skip_batches) trainer.train() # transfer the train data loader into train data iterator train_iter = iter(train_dl) # start iterating the train data and begin training for batch_count in range(train_state.batch_count, total_steps): if batch_count % 50 == 0: torch.cuda.empty_cache() start_time = time.time() timer("one-batch").start() # load batch data batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) # record the consumed samples in training train_state.batch_count = batch_count train_state.num_consumed_samples_in_epoch += len(batch[1]) if batch_skipper(batch_count): # skip this batch if gpc.is_rank_for_log(): logger.info(f"Skip batch count:`{batch_count}`...") timer("one-batch").stop() continue # zero the grads of parameters trainer.zero_grad() # do forward and backward timer("fwd-bwd").start() _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False) timer("fwd-bwd").stop() assert loss is not None # update parameters, and returns (success_update, grad_norm) trainer_result = trainer.step() assert trainer_result is not None success_update, grad_norm = trainer_result if success_update: # update parameters successfully train_state.step_count += 1 else: train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. if grad_norm == -99.0 and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case logger.warning(f"Warning: skip parameter update at step {batch_count}.") # calculate and record the training metrics, eg. loss, accuracy and so on. record_current_batch_training_metrics( get_tflops_func=get_tflops_func, logger=logger, success_update=success_update, batch_count=batch_count, batch=batch, train_state=train_state, optimizer=optimizer, beta2_scheduler=beta2_scheduler, trainer=trainer, start_time=start_time, loss=loss, grad_norm=grad_norm, ) timer("one-batch").stop() # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" # # save batch sampler that tracks the true consumed samples if enable_save_ckpt and train_state.step_count % checkpoint_every == 0: save_checkpoint( folder=save_ckpt_folder, model=model, optimizer=optimizer, scheduler=lr_scheduler, train_state=train_state, model_config=gpc.config.model, ) # wait for all checkpoint uploads to be completed dist.barrier() if __name__ == "__main__": args = parse_args() try: main(args) except Exception: print(f"Raise exception from {socket.gethostname()} with proc id: {get_process_rank()}") traceback.print_exc()