add assert

pull/542/head
lijiaxing 2023-12-14 15:01:04 +08:00
parent 6ad1afd2c4
commit 7d67795f7e
1 changed files with 5 additions and 0 deletions

View File

@ -912,6 +912,11 @@ class CheckpointManager:
and "ckpt_type" in self.load_ckpt_info and "ckpt_type" in self.load_ckpt_info
), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')" ), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"
if self.load_ckpt_info["content"] != ("model",):
assert (
self.load_ckpt_info["ckpt_type"] == "internlm"
), "Only 'internlm' ckpt supports loading states other than 'model' !"
# replace load_ckpt # replace load_ckpt
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"]) self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"]) self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"])