From 55c7dd513debf4d4f133dd490731a9fb5d9752ba Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Wed, 20 Dec 2023 17:22:21 +0800
Subject: [PATCH] redo

---
 configs/7B_sft.py                   |  1 +
 internlm/initialize/launch.py       | 19 +++++++++----------
 internlm/train/training_internlm.py | 16 ++++++----------
 internlm/utils/model_checkpoint.py  | 15 +++------------
 4 files changed, 19 insertions(+), 32 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 05a2f22..7e81f42 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -41,6 +41,7 @@ ckpt = dict(
     checkpoint_every=CHECKPOINT_EVERY,  # if checkpoint_every is not set, auto checkpoint_every will be executed
     async_upload=True,  # async ckpt upload. (only work for boto3 ckpt)
     async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporarily files during asynchronous upload.
+    auto_save_time=1200,  # time to control oss_snapshot_freq. If not set, 1200 seconds by default.
 )
 
 TRAIN_FOLDER = None  # "/path/to/dataset"
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 425013e..13ba4dc 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -169,10 +169,8 @@ def args_sanity_check():
 
     # Saving checkpoint args.
     if ckpt.enable_save_ckpt:
-        # assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
-        # assert ckpt.checkpoint_every > 0
-        if "checkpoint_every" not in ckpt or ckpt.checkpoint_every <= 0:
-            ckpt.checkpoint_every = "auto"
+        assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
+        assert ckpt.checkpoint_every > 0
         assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!"
 
         if "async_upload" not in ckpt:
@@ -194,13 +192,14 @@ def args_sanity_check():
             if not ckpt.async_upload:
                 ckpt._add_item("async_upload_tmp_folder", None)
 
-        if ckpt.checkpoint_every != "auto":
-            if "oss_snapshot_freq" not in ckpt:
-                ckpt._add_item("oss_snapshot_freq", "auto")
-            else:
-                ckpt.oss_snapshot_freq = "auto"
+        if "oss_snapshot_freq" not in ckpt:
+            ckpt._add_item("oss_snapshot_freq", -1)
         else:
-            ckpt.oss_snapshot_freq = float("inf")
+            ckpt.oss_snapshot_freq = -1
+
+        if "auto_save_time" not in ckpt:
+            ckpt._add_item("auto_save_time", 1200)
+
     else:
         ckpt._add_item("checkpoint_every", float("inf"))
         ckpt._add_item("oss_snapshot_freq", float("inf"))
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 0fec19e..636b842 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -428,7 +428,7 @@ def record_current_batch_training_metrics(
         acc_perplex = metric.get_metric()
 
     # compute auto save_frequency
-    if success_update and (gpc.config.ckpt.checkpoint_every == "auto" or gpc.config.ckpt.oss_snapshot_freq == "auto"):
+    if success_update and gpc.config.ckpt.oss_snapshot_freq <= 0:
         ckpt_statistic = train_state.ckpt_statistic
         ckpt_statistic["total_step"] += 1
 
@@ -443,7 +443,7 @@ def record_current_batch_training_metrics(
         # compute save_frequency
         if gpc.get_global_rank() == 0:
             avg_step_time = ckpt_statistic["sum_time"] / ckpt_statistic["sum_step"]
-            check_time = int(os.getenv("LLM_CKPT_SAVE_TIME", "1200"))
+            check_time = gpc.config.ckpt.auto_save_time
             save_frequency = torch.tensor(
                 [int(10 * -(-check_time // (avg_step_time * 10)))], device=torch.device("cuda")
             )
@@ -455,15 +455,11 @@ def record_current_batch_training_metrics(
         save_frequency = int(save_frequency[0])
 
         # assign save_frequency
-        # when the "checkpoint_every" is "auto", no snapshot will be performed
-        # when the "save_frequency" is less than the "checkpoint_every" passed in, no snapshot will be performed
-        if gpc.config.ckpt.checkpoint_every == "auto":
-            gpc.config.ckpt.checkpoint_every = save_frequency
+        # when the "save_frequency" is larger than the "checkpoint_every" passed in, no snapshot will be performed
+        if save_frequency < gpc.config.ckpt.checkpoint_every:
+            gpc.config.ckpt.oss_snapshot_freq = save_frequency
         else:
-            if save_frequency < gpc.config.ckpt.checkpoint_every:
-                gpc.config.ckpt.oss_snapshot_freq = save_frequency
-            else:
-                gpc.config.ckpt.oss_snapshot_freq = float("inf")
+            gpc.config.ckpt.oss_snapshot_freq = float("inf")
 
     if success_update and gpc.is_rank_for_log():
         lr = optimizer.param_groups[0]["lr"]
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 1734340..26fe016 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -971,11 +971,7 @@ now step_count is {train_state.step_count}",
 
     def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
         save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
-        if (
-            self.oss_snapshot_freq != "auto"
-            and self.oss_snapshot_freq > 1
-            and train_state.step_count % self.oss_snapshot_freq == 0
-        ):
+        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
             save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
         if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps:
             save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
@@ -990,13 +986,8 @@ now step_count is {train_state.step_count}",
         if not self.enable_save_ckpt:
             return False
 
-        if self.checkpoint_every == "auto":
-            if gpc.config.ckpt.checkpoint_every == "auto":
-                return False
-            else:
-                self.checkpoint_every = gpc.config.ckpt.checkpoint_every
-        elif self.oss_snapshot_freq == "auto":
-            if gpc.config.ckpt.oss_snapshot_freq != "auto":
+        if self.oss_snapshot_freq <= 0:
+            if gpc.config.ckpt.oss_snapshot_freq > 0:
                 self.oss_snapshot_freq = gpc.config.ckpt.oss_snapshot_freq
 
         save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state)