diff --git a/train.py b/train.py
index 0d6a31a..4d30b90 100644
--- a/train.py
+++ b/train.py
@@ -35,6 +35,7 @@ from internlm.utils.common import (
     parse_args,
 )
 from internlm.utils.evaluation import evaluate_on_val_dls
+from internlm.utils.gputest import empty_cache_and_diag
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import CheckpointManager
@@ -72,7 +73,6 @@ def main(args):
     total_steps = gpc.config.data.total_steps
     valid_every = gpc.config.data.valid_every
     label_smoothing = gpc.config.loss.label_smoothing
-    lr = gpc.config.adam.lr
 
     get_tflops_func = partial(
         get_megatron_flops,
@@ -95,21 +95,11 @@ def main(args):
     # initialize customed llm logger
     uniscale_logger = initialize_llm_logger(start_time=current_time)
 
-    # initialize and resume train state
-    train_state = TrainState(gpc.config)
-
     # initialize model
     model = initialize_model()
 
     with open(args.config, "r") as f:
         config_lines = f.readlines()
-    ckpt_manager = CheckpointManager(
-        ckpt_config=gpc.config.ckpt,
-        model=model,
-        model_config=gpc.config.model,
-        model_config_file="".join(config_lines),
-        feishu_address=gpc.config.alert_address,
-    )
 
     # initialize loss function
     criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
@@ -117,15 +107,25 @@ def main(args):
     # initialize the train and validation data loader
     train_dl, dataset_types = get_train_data_loader(num_worker=4)
     val_dls = get_validation_data_loader()
-    train_state.init_batch_sampler(train_dl)
 
-    # Loading model weights must be done before zero is initialized.
-    ckpt_manager.try_load_model(current_time)
+    # initialize and resume train state
+    train_state = TrainState(gpc.config, train_dl.batch_sampler)
 
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
 
+    ckpt_manager = CheckpointManager(
+        ckpt_config=gpc.config.ckpt,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dl=train_dl,
+        model_config=gpc.config.model,
+        model_config_file="".join(config_lines),
+        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
+    )
+
     # Loading other persistent training states.
-    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
+    ckpt_manager.try_resume_training(train_state, current_time)
 
     # initialize customed llm writer
     writer = Writer(
@@ -194,9 +194,7 @@ def main(args):
     with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
         # start iterating the train data and begin training
         for batch_count in range(train_state.batch_count, total_steps):
-            if batch_count % 50 == 0:
-                torch.cuda.empty_cache()
-
+            empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
             start_time = time.time()
             timer("one-batch").start()
 
@@ -238,10 +236,10 @@ def main(args):
                 train_state.step_count += 1
             else:
                 train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-                if -1 in grad_norm_groups and gpc.is_rank_for_log():  # -1 encodes a specific failure case
+                if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                     logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                     send_alert_message(
-                        address=gpc.config.alert_address,
+                        address=gpc.config.monitor.alert.feishu_alert_address,
                         message=f"Warning: skip parameter update at step {batch_count}.",
                     )
 
@@ -302,11 +300,15 @@ if __name__ == "__main__":
     assert hasattr(gpc, "config") and gpc.config is not None
 
     # initialize monitor manager context
-    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
+    with initialize_monitor_manager(
+        job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
+    ):
         try:
             main(args)
         except Exception:
             logger.error(
                 f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
             )
-            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
+            mm.monitor_exception(
+                alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
+            )