mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			310 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			310 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
| #!/usr/bin/env python
 | |
| # -*- encoding: utf-8 -*-
 | |
| 
 | |
| import socket
 | |
| import time
 | |
| import traceback
 | |
| from functools import partial
 | |
| 
 | |
| import numpy as np
 | |
| import torch
 | |
| import torch.distributed as dist
 | |
| 
 | |
| import internlm
 | |
| from internlm.core.context import ParallelMode
 | |
| from internlm.core.context import global_context as gpc
 | |
| from internlm.core.scheduler import SchedulerMetricHook
 | |
| from internlm.core.trainer import TrainState
 | |
| from internlm.initialize import initialize_distributed_env
 | |
| from internlm.model.loss import FlashGPTLMLoss
 | |
| from internlm.model.metrics import AccPerplex
 | |
| from internlm.monitor import initialize_monitor_manager, send_alert_message
 | |
| from internlm.monitor.monitor import monitor_manager as mm
 | |
| from internlm.train import (
 | |
|     get_train_data_loader,
 | |
|     get_validation_data_loader,
 | |
|     initialize_llm_profile,
 | |
|     initialize_model,
 | |
|     initialize_optimizer,
 | |
|     load_new_batch,
 | |
|     record_current_batch_training_metrics,
 | |
| )
 | |
| from internlm.utils.common import (
 | |
|     BatchSkipper,
 | |
|     get_megatron_flops,
 | |
|     launch_time,
 | |
|     parse_args,
 | |
| )
 | |
| from internlm.utils.evaluation import evaluate_on_val_dls
 | |
| from internlm.utils.logger import get_logger, initialize_uniscale_logger
 | |
| from internlm.utils.megatron_timers import megatron_timer as timer
 | |
| from internlm.utils.model_checkpoint import CheckpointManager
 | |
| from internlm.utils.parallel import get_parallel_log_file_name
 | |
| from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
 | |
| from internlm.utils.writer import Writer
 | |
| 
 | |
# Module-level logger used by this script; initialize_llm_logger() may rebind it
# to the uniscale logger once the job's launch time is known.
logger = get_logger(__file__)


def initialize_llm_logger(start_time: str):
    """
    Set up the customized uniscale logger for the current training job.

    When ``initialize_uniscale_logger`` returns a logger, it replaces the
    module-level ``logger`` so subsequent log calls in this module use it.

    Args:
        start_time (str): The launch time of the current training job.

    Returns:
        Whatever ``initialize_uniscale_logger`` returned: the uniscale logger
        instance, or ``None`` when it returned ``None``.
    """
    global logger

    # Keep the original evaluation order: job name first, then the log file name.
    job_name = gpc.config.JOB_NAME
    log_file = get_parallel_log_file_name()
    new_logger = initialize_uniscale_logger(
        job_name=job_name,
        launch_time=start_time,
        file_name=log_file,
    )

    if new_logger is None:
        return None

    logger = new_logger
    return new_logger
 | |
| 
 | |
| 
 | |
def main(args):
    """
    Build all training components from ``gpc.config`` and run the training loop.

    The model, checkpoint manager, loss, data loaders, optimizer/schedulers, writer,
    metric and trainer are created in order. The loop then iterates from the resumed
    ``train_state.batch_count`` up to ``gpc.config.data.total_steps``. Each iteration
    runs forward/backward, the optimizer step and metric recording. It also runs
    periodic validation and checkpoint saving.

    Args:
        args: Parsed command-line arguments.
            ``args.config`` is the path of the config file; its text is handed to the
            checkpoint manager and the writer.
            ``args.profiling`` enables the simple memory profiler and is forwarded
            to ``initialize_llm_profile``.
    """
    # init setting
    skip_batches = gpc.config.data.skip_batches
    total_steps = gpc.config.data.total_steps
    valid_every = gpc.config.data.valid_every
    label_smoothing = gpc.config.loss.label_smoothing
    lr = gpc.config.adam.lr

    # Bind the static model/batch shape into the TFLOPS estimator used when recording metrics.
    get_tflops_func = partial(
        get_megatron_flops,
        checkpoint=gpc.config.model.checkpoint,
        seq_len=gpc.config.SEQ_LEN,
        hidden_size=gpc.config.model.hidden_size,
        num_layers=gpc.config.model.num_layers,
        vocab_size=gpc.config.model.vocab_size,
        global_batch_size=gpc.config.data.micro_bsz
        * gpc.config.data.micro_num
        * gpc.get_world_size(ParallelMode.DATA),
        global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
        mlp_ratio=gpc.config.MLP_RATIO,
    )

    # Get the current time and broadcast rank 0's value so every rank uses the same timestamp.
    current_time = launch_time()
    objs = [current_time]
    dist.broadcast_object_list(objs, src=0)
    current_time = objs[0]

    # initialize customized llm logger
    uniscale_logger = initialize_llm_logger(start_time=current_time)

    # initialize and resume train state
    train_state = TrainState(gpc.config)

    # initialize model
    model = initialize_model()

    # Read the config text explicitly as UTF-8 so the result does not depend on the
    # process locale (config files may contain non-ASCII comments).
    with open(args.config, "r", encoding="utf-8") as f:
        config_lines = f.readlines()
    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.alert_address,
    )

    # initialize loss function
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)

    # initialize the train and validation data loader
    train_dl, dataset_types = get_train_data_loader(num_worker=4)
    val_dls = get_validation_data_loader()
    train_state.init_batch_sampler(train_dl)

    # Loading model weights must be done before zero is initialized.
    ckpt_manager.try_load_model(current_time)

    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

    # Loading other persistent training states.
    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)

    # initialize customized llm writer
    writer = Writer(
        job_name=gpc.config.JOB_NAME,
        launch_time=current_time,
        file_name=get_parallel_log_file_name(),
        tensorboard_folder=gpc.config.tensorboard_folder,
        resume_tb_folder=train_state.resume_tb_folder,  # resume from ckpt.
        step_count=train_state.step_count,  # resume from ckpt.
        config=config_lines,
        logger=logger,
        enable_tb=gpc.config.enable_tb,
    )

    # initialize metric for calculating accuracy and perplexity
    metric = AccPerplex(
        device=torch.cuda.current_device(),
        tp_pg=gpc.get_group(ParallelMode.TENSOR),
        dp_pg=gpc.get_group(ParallelMode.DATA),
        dataset_types=dataset_types,
    )

    # initialize trainer
    # The metric hook is skipped only for interleaved pipeline parallelism
    # (num_chunks > 1) with "interleaved_overlap" enabled.
    scheduler_hooks = [
        SchedulerMetricHook(
            metric=metric,
            skip=(
                gpc.is_using_pp()
                and hasattr(gpc.config.model, "num_chunks")
                and gpc.config.model.num_chunks > 1
                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
            ),
        ),
    ]

    trainer, train_dl, _, _ = internlm.initialize_trainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dl,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        scheduler_hooks=scheduler_hooks,
    )

    # initialize simple memory profiler
    if args.profiling:
        memory_profiler = SimpleMemoryProfiler(
            model,
            optimizer.optim,
            log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
        )
    else:
        memory_profiler = None

    # initialize the batch skipper
    batch_skipper = BatchSkipper(skip_batches)

    trainer.train()

    # transfer the train data loader into train data iterator
    train_iter = iter(train_dl)

    with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
        # start iterating the train data and begin training
        for batch_count in range(train_state.batch_count, total_steps):
            # Periodically release cached CUDA allocator blocks.
            if batch_count % 50 == 0:
                torch.cuda.empty_cache()

            start_time = time.time()
            timer("one-batch").start()

            # load batch data
            batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)

            # record the consumed samples in training
            train_state.batch_count = batch_count
            train_state.num_consumed_samples_in_epoch += len(batch[1])
            if batch_skipper(batch_count):  # skip this batch
                if gpc.is_rank_for_log():
                    logger.info(f"Skip batch count:`{batch_count}`...")
                timer("one-batch").stop()
                continue

            # zero the grads of parameters
            trainer.zero_grad()
            # process data: hand the per-sample type ids (if any) to the metric
            if batch[0].get("type_ids", None) is not None:
                metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))

            # do forward and backward
            timer("fwd-bwd").start()

            _, _, loss = trainer.execute_schedule(
                batch, forward_only=False, return_loss=True, return_output_label=False
            )
            timer("fwd-bwd").stop()

            # update parameters, and returns (success_update, grad_norm)
            trainer_result = trainer.step()
            assert trainer_result is not None

            success_update, grad_norm_groups = trainer_result
            if success_update:  # update parameters successfully
                train_state.step_count += 1
            else:
                train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
                if -1 in grad_norm_groups and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                    logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                    send_alert_message(
                        address=gpc.config.alert_address,
                        message=f"Warning: skip parameter update at step {batch_count}.",
                    )

            # calculate and record the training metrics, eg. loss, accuracy and so on.
            record_current_batch_training_metrics(
                get_tflops_func=get_tflops_func,
                logger=logger,
                writer=writer,
                success_update=success_update,
                batch_count=batch_count,
                batch=batch,
                train_state=train_state,
                optimizer=optimizer,
                beta2_scheduler=beta2_scheduler,
                trainer=trainer,
                start_time=start_time,
                loss=loss,
                grad_norm=np.array(grad_norm_groups),
                metric=metric,
                update_panel=uniscale_logger is not None,
            )

            timer("one-batch").stop()

            # Evaluate on the validation data loaders only right after a successful update.
            # A skipped (inf/nan) update leaves step_count unchanged. Without this guard the
            # same step would be evaluated again (and step 0 if the first update fails).
            if success_update and valid_every > 0 and train_state.step_count % valid_every == 0:
                evaluate_on_val_dls(
                    trainer=trainer,
                    val_dls=val_dls,
                    writer=writer,
                    logger=logger,
                    step_count=train_state.step_count,
                    update_panel=uniscale_logger is not None,
                )

            # Checkpoint the training states at specific steps, as determined by the "checkpoint_every"
            # setting. The saved state includes the batch sampler that tracks the truly consumed samples.
            now_break = ckpt_manager.try_save_checkpoint(train_state)
            if now_break:
                break

            if memory_profiler is not None:
                memory_profiler.step()

            # The llm profiler is stepped once every two batches.
            if batch_count % 2 == 0:
                prof.step()

    ckpt_manager.wait_async_upload_finish()
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    import sys

    args = parse_args()
    hostname = socket.gethostname()

    # initialize distributed environment
    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
    assert hasattr(gpc, "config") and gpc.config is not None

    failed = False
    # initialize monitor manager context
    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
        try:
            main(args)
        except Exception:
            failed = True
            # Format the traceback once and reuse it for both the log line and the alert.
            exc_text = traceback.format_exc()
            logger.error(
                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{exc_text}",
            )
            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=exc_text)

    # The exception is handled inside the monitor context so the context exits normally.
    # The process must still report the failure to its launcher instead of exiting with status 0.
    if failed:
        sys.exit(1)
 |