diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 4186ef6..2da2755 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -996,7 +996,7 @@ now step_count is {train_state.step_count}", if ( train_state.step_count > 0 and train_state.step_count % self.checkpoint_every == 0 - or train_state.step_count == train_state.total_steps + or train_state.batch_count == (train_state.total_steps - 1) ): save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT