fix(ckpt): fix checkpoint reload bug (#282)

1. fix only_load tuple convert bug.
2. fix reload_zero_fp32_buff copy bug
pull/383/head
Guoteng 2023-09-06 04:05:04 +08:00 committed by GitHub
parent 8acf823a04
commit ff181bc5f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -784,4 +784,4 @@ def reload_zero_fp32_buff(optimizer):
optimizer._zero_local_rank, group_id
)
# param_group["params"] is fp32 flatten optimizer states of this zero rank.
param_group["params"][0].copy_(fp16_flat_current_rank.float())
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())

View File

@ -123,7 +123,7 @@ class CheckpointLoadMask:
return content in self.load_set and len(self.load_set) > 1
def only_load(self, content: CheckpointLoadContent):
return set(content) == self.load_set
return set((content,)) == self.load_set
def __str__(self) -> str:
return f"{self.load_set}."