From 81f51fd0ff3a016ea1d3a4e75e0a7ba4cbac9b82 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 15 Dec 2023 18:21:17 +0800 Subject: [PATCH] auto save --- configs/7B_sft.py | 3 +-- internlm/core/trainer.py | 7 +++++ internlm/initialize/launch.py | 15 ++++++++--- internlm/train/training_internlm.py | 40 +++++++++++++++++++++++++++++ internlm/utils/model_checkpoint.py | 15 ++++++++++- 5 files changed, 73 insertions(+), 7 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index c0a9bc8..05a2f22 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -38,10 +38,9 @@ ckpt = dict( # If you want to initialize your model weights from another model, you must set `auto_resume` to False. # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. auto_resume=True, - checkpoint_every=CHECKPOINT_EVERY, + checkpoint_every=CHECKPOINT_EVERY, # if checkpoint_every is not set, auto checkpoint_every will be executed async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. ) TRAIN_FOLDER = None # "/path/to/dataset" diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index b189031..1350f7c 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -78,6 +78,13 @@ class TrainState: "last_tgs_50": 0, } + # ckpt statistic + self.ckpt_statistic = { + "total_step": 0, + "sum_step": 0, + "sum_time": 0, + } + def init_batch_sampler(self, batch_sampler): """ Args: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 491e2b0..425013e 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -169,8 +169,10 @@ def args_sanity_check(): # Saving checkpoint args. if ckpt.enable_save_ckpt: - assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!" - assert ckpt.checkpoint_every > 0 + # assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!" + # assert ckpt.checkpoint_every > 0 + if "checkpoint_every" not in ckpt or ckpt.checkpoint_every <= 0: + ckpt.checkpoint_every = "auto" assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!" if "async_upload" not in ckpt: @@ -192,8 +194,13 @@ def args_sanity_check(): if not ckpt.async_upload: ckpt._add_item("async_upload_tmp_folder", None) - if "oss_snapshot_freq" not in ckpt: - ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable. + if ckpt.checkpoint_every != "auto": + if "oss_snapshot_freq" not in ckpt: + ckpt._add_item("oss_snapshot_freq", "auto") + else: + ckpt.oss_snapshot_freq = "auto" + else: + ckpt.oss_snapshot_freq = float("inf") else: ckpt._add_item("checkpoint_every", float("inf")) ckpt._add_item("oss_snapshot_freq", float("inf")) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 474bfd2..c190fa6 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -427,6 +427,44 @@ def record_current_batch_training_metrics( if gpc.is_no_pp_or_last_stage(): acc_perplex = metric.get_metric() + # compute auto save_frequency + if success_update and (gpc.config.ckpt.checkpoint_every == "auto" or gpc.config.ckpt.oss_snapshot_freq == "auto"): + ckpt_statistic = train_state.ckpt_statistic + ckpt_statistic["total_step"] += 1 + + # only global rank 0 need to compute save_frequency + if gpc.get_global_rank() == 0 and ckpt_statistic["total_step"] < 10 and batch_count >= 5: + ckpt_statistic["sum_step"] += 1 + ckpt_statistic["sum_time"] += time.time() - start_time + + # broadcast and assign save_frequency at the 10th step from start + elif ckpt_statistic["total_step"] == 10: + + # compute save_frequency + if gpc.get_global_rank() == 0: + avg_step_time = ckpt_statistic["sum_time"] / ckpt_statistic["sum_step"] + check_time = int(os.getenv("LLM_CKPT_SAVE_TIME", "1200")) + save_frequency = torch.tensor( + [int(10 * -(-check_time // (avg_step_time * 10)))], device=torch.device("cuda") + ) + else: + save_frequency = torch.tensor([-1], device=torch.device("cuda")) + + ranks = gpc.get_ranks_in_group(ParallelMode.GLOBAL) + dist.broadcast(save_frequency, src=ranks[0], group=gpc.get_group(ParallelMode.GLOBAL)) + save_frequency = int(save_frequency[0]) + + # assign save_frequency + # when the "checkpoint_every" is "auto", no snapshot will be implemented + # when the "save_frequency" is less than the "checkpoint_every" passed in, no snapshot will be implemented + if gpc.config.ckpt.checkpoint_every == "auto": + gpc.config.ckpt.checkpoint_every = save_frequency + else: + if save_frequency < gpc.config.ckpt.checkpoint_every: + gpc.config.ckpt.oss_snapshot_freq = save_frequency + else: + gpc.config.ckpt.oss_snapshot_freq = float("inf") + if success_update and gpc.is_rank_for_log(): lr = optimizer.param_groups[0]["lr"] if hasattr(trainer.engine.optimizer, "grad_scaler"): @@ -440,6 +478,7 @@ def record_current_batch_training_metrics( max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) time_cost = time.time() - start_time + tk_per_gpu = round( num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL), 4, @@ -507,6 +546,7 @@ def record_current_batch_training_metrics( "loss_scale": scaler, "grad_norm": grad_norm, } + if moe_loss is not None: infos["moe_loss"] = moe_loss.item() diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 234944c..1734340 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -971,7 +971,11 @@ now step_count is {train_state.step_count}", def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool): save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False - if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: + if ( + self.oss_snapshot_freq != "auto" + and self.oss_snapshot_freq > 1 + and train_state.step_count % self.oss_snapshot_freq == 0 + ): save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps: save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT @@ -986,6 +990,15 @@ now step_count is {train_state.step_count}", if not self.enable_save_ckpt: return False + if self.checkpoint_every == "auto": + if gpc.config.ckpt.checkpoint_every == "auto": + return False + else: + self.checkpoint_every = gpc.config.ckpt.checkpoint_every + elif self.oss_snapshot_freq == "auto": + if gpc.config.ckpt.oss_snapshot_freq != "auto": + self.oss_snapshot_freq = gpc.config.ckpt.oss_snapshot_freq + save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state) if save_ckpts: