mirror of https://github.com/InternLM/InternLM
41 lines
1.6 KiB
Python
41 lines
1.6 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from internlm.initialize.launch import get_config_value
|
|
from internlm.utils.logger import get_logger
|
|
|
|
# Module-level logger, created by passing this file's path (`__file__`) to the
# project's `get_logger` helper. It is not referenced by the functions visible
# in this part of the file.
logger = get_logger(__file__)
|
|
|
|
|
|
def auto_resume_sanity_check(ckpt_config):
    """Return the auto-resume flag implied by ``load_given_ckpt``.

    Args:
        ckpt_config: Checkpoint configuration, read through ``get_config_value``.

    Returns:
        ``True`` when ``load_given_ckpt`` is not set (the default), otherwise
        the logical negation of its value.
    """
    given_ckpt_flag = get_config_value(ckpt_config, "load_given_ckpt", None)
    # An unset option falls back to the default, which is True.
    return True if given_ckpt_flag is None else not given_ckpt_flag
|
|
|
|
|
|
def ckpt_info_sanity_check(ckpt_config):
    """Build the checkpoint-loading description from the checkpoint config.

    Reads ``load_ckpt_folder``, ``load_model_only_folder`` and
    ``load_optimizer`` from ``ckpt_config`` via ``get_config_value``.

    Args:
        ckpt_config: Checkpoint configuration, read through ``get_config_value``.

    Returns:
        ``None`` when neither ``load_model_only_folder`` nor
        ``load_ckpt_folder`` is set. Otherwise a dict with the keys:

        * ``path``: the folder to load from;
        * ``content``: ``("model",)`` for ``load_model_only_folder``;
          for ``load_ckpt_folder`` it is ``("model", "sampler", "optimizer")``
          when ``load_optimizer`` is truthy (default) and
          ``("model", "sampler")`` otherwise;
        * ``ckpt_type``: always ``"internlm"``.

    Raises:
        AssertionError: If both ``load_model_only_folder`` and
            ``load_ckpt_folder`` are set, or if ``load_ckpt_folder`` is
            neither a ``str`` nor ``None``.
    """
    load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
    load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)

    if load_model_only_folder is not None:
        # The two folder options are mutually exclusive.
        assert load_ckpt_folder is None, (
            "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, "
            "and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
        )
        return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")

    load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)

    if load_ckpt_folder is None:
        # Nothing configured: there is no checkpoint to load.
        return None

    if isinstance(load_ckpt_folder, str):
        if load_optimizer:
            content = ("model", "sampler", "optimizer")
        else:
            content = ("model", "sampler")
        return dict(path=load_ckpt_folder, content=content, ckpt_type="internlm")

    # The original `assert "<message>"` was always true (a non-empty string is
    # truthy), so an unsupported type was silently accepted and `None` was
    # returned. Raise explicitly so the failure is reported, even under `-O`.
    raise AssertionError(
        f"Unsupport data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'"
    )
|