mirror of https://github.com/InternLM/InternLM
add assert
parent
6ad1afd2c4
commit
7d67795f7e
|
@ -912,6 +912,11 @@ class CheckpointManager:
|
||||||
and "ckpt_type" in self.load_ckpt_info
|
and "ckpt_type" in self.load_ckpt_info
|
||||||
), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"
|
), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"
|
||||||
|
|
||||||
|
if self.load_ckpt_info["content"] != ("model",):
|
||||||
|
assert (
|
||||||
|
self.load_ckpt_info["ckpt_type"] == "internlm"
|
||||||
|
), "Only 'internlm' ckpt supports loading states other than 'model' !"
|
||||||
|
|
||||||
# replace load_ckpt
|
# replace load_ckpt
|
||||||
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
|
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
|
||||||
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"])
|
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"])
|
||||||
|
|
Loading…
Reference in New Issue