diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 4b3f7d5..d16db0c 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -776,7 +776,7 @@ now step_count is {train_state.step_count}", save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT - if train_state.step_count % self.checkpoint_every == 0: + if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps: save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: