diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 234944c..9b64e3b 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -912,6 +912,11 @@ class CheckpointManager: and "ckpt_type" in self.load_ckpt_info ), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')" + if self.load_ckpt_info["content"] != ("model",): + assert ( + self.load_ckpt_info["ckpt_type"] == "internlm" + ), "Only 'internlm' ckpt supports loading states other than 'model' !" + # replace load_ckpt self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"]) self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"])