From ff181bc5f852fe46327ed65a69d445abd432e24e Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Wed, 6 Sep 2023 04:05:04 +0800 Subject: [PATCH] fix(ckpt): fix checkpoint reload bug (#282) 1. fix only_load tuple convert bug. 2. fix reload_zero_fp32_buff copy bug --- internlm/solver/optimizer/hybrid_zero_optim.py | 2 +- internlm/utils/model_checkpoint.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 2fa1364..0e4343a 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -784,4 +784,4 @@ def reload_zero_fp32_buff(optimizer): optimizer._zero_local_rank, group_id ) # param_group["params"] is fp32 flatten optimizer states of this zero rank. - param_group["params"][0].copy_(fp16_flat_current_rank.float()) + param_group["params"][0].data.copy_(fp16_flat_current_rank.float()) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 87a1fb4..21d76d1 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -123,7 +123,7 @@ class CheckpointLoadMask: return content in self.load_set and len(self.load_set) > 1 def only_load(self, content: CheckpointLoadContent): - return set(content) == self.load_set + return set((content,)) == self.load_set def __str__(self) -> str: return f"{self.load_set}."