diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index b9326de..4186ef6 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -980,12 +980,26 @@ now step_count is {train_state.step_count}", return now_break, now_save_ckpt, save_type - def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool): + def is_now_to_save_ckpt(self, train_state, force=False) -> (bool, CheckpointSaveType, bool): + """The function is used to determine whether to save ckpt now.""" save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False - if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: + if force: + return True, CheckpointSaveType.NORMAL_CHECKPOINT, False + + if ( + self.oss_snapshot_freq > 1 + and train_state.step_count > 0 + and train_state.step_count % self.oss_snapshot_freq == 0 + ): save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT - if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps: + + if ( + train_state.step_count > 0 + and train_state.step_count % self.checkpoint_every == 0 + or train_state.step_count == train_state.total_steps + ): save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT + now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: save_ckpts = singal_save_ckpts