From b7ecdba617f06995c9551d47c68a40f172f3520e Mon Sep 17 00:00:00 2001
From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com>
Date: Thu, 9 Nov 2023 21:07:16 +0800
Subject: [PATCH] feat(ckpt): save ckpt when reach total step count (#486)

---
 internlm/utils/model_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 4b3f7d5..d16db0c 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -776,7 +776,7 @@ now step_count is {train_state.step_count}",
         save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
         if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
             save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
-        if train_state.step_count % self.checkpoint_every == 0:
+        if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps:
             save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
         now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
         if save_ckpts is False: