fix(model): add ckpt_type constraint when loading ckpts (#542)

* support hf llama

* support hf llama

* support hf llama

* support hf llama

* importerror

* importerror

* modeling

* modeling

* fix bug

* add assert
pull/565/head
jiaxingli 2023-12-20 16:43:27 +08:00 committed by GitHub
parent a58bf853db
commit d418eba094
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 0 deletions

View File

@ -912,6 +912,11 @@ class CheckpointManager:
and "ckpt_type" in self.load_ckpt_info
), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"
if self.load_ckpt_info["content"] != ("model",):
assert (
self.load_ckpt_info["ckpt_type"] == "internlm"
), "Only 'internlm' ckpt supports loading states other than 'model' !"
# replace load_ckpt
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"])