diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 2fa1364..0e4343a 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -784,4 +784,4 @@ def reload_zero_fp32_buff(optimizer):
                     optimizer._zero_local_rank, group_id
                 )
                 # param_group["params"] is fp32 flatten optimizer states of this zero rank.
-                param_group["params"][0].copy_(fp16_flat_current_rank.float())
+                param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 87a1fb4..21d76d1 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -123,7 +123,7 @@ class CheckpointLoadMask:
         return content in self.load_set and len(self.load_set) > 1
 
     def only_load(self, content: CheckpointLoadContent):
-        return set(content) == self.load_set
+        return set((content,)) == self.load_set
 
     def __str__(self) -> str:
         return f"{self.load_set}."
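
Note on the first hunk: `param_group["params"][0]` is the flattened fp32 master copy the optimizer keeps for this zero rank, and it is a leaf tensor with `requires_grad=True`. Calling `.copy_()` on it directly is an in-place autograd operation on a leaf and raises a `RuntimeError`; routing the copy through `.data` bypasses autograd tracking. A minimal standalone sketch of the failure mode (the variable names are illustrative stand-ins, not InternLM code):

```python
import torch

# Stand-in for param_group["params"][0]: a leaf tensor that requires grad.
fp32_flat = torch.nn.Parameter(torch.zeros(4))
# Stand-in for the fp16 flat buffer already refreshed by load_model_checkpoint.
fp16_flat_current_rank = torch.ones(4, dtype=torch.float16)

try:
    # In-place op on a leaf tensor that requires grad is rejected by autograd.
    fp32_flat.copy_(fp16_flat_current_rank.float())
except RuntimeError as exc:
    print(f"plain copy_ fails: {exc}")

# Copying through .data skips autograd tracking and refreshes the buffer in place.
fp32_flat.data.copy_(fp16_flat_current_rank.float())
print(fp32_flat)  # tensor([1., 1., 1., 1.], requires_grad=True)
```

Wrapping the plain `copy_` in `with torch.no_grad():` would be an equivalent fix; the diff uses the `.data` idiom instead.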
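
Note on the second hunk: the `CheckpointLoadContent` members are presumably plain string constants (e.g. `"model"`), so `set(content)` iterates the string and yields a set of its individual characters, which can never equal a `load_set` holding the whole word. `set((content,))` builds the intended one-element set. A quick sketch of the difference, under that string assumption:

```python
content = "model"
load_set = {"model"}

print(set(content))                 # set of five characters, e.g. {'m', 'o', 'd', 'e', 'l'}
print(set(content) == load_set)     # False, even when only "model" is being loaded

print(set((content,)))              # {'model'}: the intended singleton set
print(set((content,)) == load_set)  # True
```

A set literal, `{content} == self.load_set`, would express the same fix more directly.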