mirror of https://github.com/InternLM/InternLM
fix(storage): fix try_get_storage_backend (#359)
* fix(storage): fix try_get_storage_backend * fix typo and print infos only in log rank * fix typo and print infos only in log rank --------- Co-authored-by: gaoyang07 <Gary1546308416AL@gmail.com>pull/320/head
parent
a86c4bbbfd
commit
26a7397752
|
@ -78,8 +78,8 @@ class CheckpointLoadMethod:
|
|||
|
||||
@staticmethod
|
||||
def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable):
|
||||
if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC:
|
||||
logger.warning(f"{load_type} has aleady been registed!")
|
||||
if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC and gpc.is_rank_for_log():
|
||||
logger.warning(f"{load_type} has already been registered!")
|
||||
return
|
||||
|
||||
CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})
|
||||
|
@ -87,9 +87,10 @@ class CheckpointLoadMethod:
|
|||
if load_type == CheckpointLoadType.INTERNLM:
|
||||
CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
|
||||
else:
|
||||
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG:
|
||||
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}"
|
||||
f"The registered signature {inspect.signature(load_func)} of the loaded model is not same as: "
|
||||
f"{CheckpointLoadMethod.LOAD_FUNC_SIG}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -370,10 +371,11 @@ def load_optimizer_checkpoint(folder, optim):
|
|||
zero_devide_optim_plan = llm_load(fp_meta)
|
||||
states.update({"zero_devide_optim_plan": zero_devide_optim_plan})
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Read zero optimzer split file '{fp_meta}', for '{e}'"
|
||||
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
|
||||
)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
f"Read zero optimzer split file '{fp_meta}', for '{e}'"
|
||||
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
|
||||
)
|
||||
|
||||
optim.load_state_dict(states)
|
||||
del states
|
||||
|
@ -385,8 +387,8 @@ def load_sampler(ckpt_path: str, sampler):
|
|||
sampler.load_state_dict(sampler_states)
|
||||
if gpc.is_rank_for_log():
|
||||
pstate = copy.deepcopy(sampler_states)
|
||||
pstate.pop("indices")
|
||||
pstate.pop("rng_state")
|
||||
pstate.pop("indices", None)
|
||||
pstate.pop("rng_state", None)
|
||||
logger.info(f"reload sampler_states:{pstate}")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -635,9 +637,12 @@ now step_count is {train_state.step_count}",
|
|||
# Here we only try to find the ckpt folder named after step, ignoring snapshot and other folders.
|
||||
ckpt_list = [int(fn.strip("/")) for fn in ckpt_list if fn.strip("/").isdigit()]
|
||||
if len(ckpt_list) == 0:
|
||||
logger.warning("Not found avaliable normal checkpoint!")
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning("No available normal checkpoint found. Check your checkpoint path.")
|
||||
else:
|
||||
logger.info(f"Found avaliable normal checkpoint: {ckpt_list}!")
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Found available normal checkpoint: {ckpt_list}")
|
||||
|
||||
ckpt_list.sort(reverse=True)
|
||||
for ckpt in ckpt_list:
|
||||
fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
|
||||
|
|
|
@ -166,19 +166,18 @@ def compute_file_md5_by_chunk(file_name: str):
|
|||
|
||||
|
||||
def try_get_storage_backend(path: str):
|
||||
sre = path.split(":", maxsplit=1)
|
||||
if len(sre) == 1:
|
||||
if path.startswith("s3:"):
|
||||
backend = "boto3"
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
|
||||
else:
|
||||
backend = "local"
|
||||
if path.startswith("s3:"):
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
|
||||
return "boto3", path
|
||||
else:
|
||||
sre = path.split(":", maxsplit=1)
|
||||
if len(sre) == 1:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
|
||||
return backend, sre
|
||||
else:
|
||||
return sre[0], sre[1] # (backend_prefix, splited_path)
|
||||
return "local", sre[0]
|
||||
else:
|
||||
return sre[0], sre[1] # (backend_prefix, splited_path)
|
||||
|
||||
|
||||
class Boto3Client(StorageClient):
|
||||
|
@ -502,7 +501,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
or "HTTP_PROXY" in os.environ
|
||||
or "HTTPS_PROXY" in os.environ
|
||||
):
|
||||
if not self.has_warning:
|
||||
if not self.has_warning and gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
|
||||
the proxy may make boto3 unavailable or affect performance."
|
||||
|
|
|
@ -87,3 +87,24 @@ def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylin
|
|||
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
|
||||
load_obj = llm_load(save_fn, map_location="cpu")
|
||||
assert 0 == ((load_obj != tobj).sum())
|
||||
|
||||
|
||||
internlm_ckpt_path = [
|
||||
("local:/mnt/ckpt/", "local", "/mnt/ckpt/"),
|
||||
("local:./ckpt/", "local", "./ckpt/"),
|
||||
("boto3:s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
|
||||
("boto3:oss_bucket/", "boto3", "oss_bucket/"),
|
||||
("/mnt/ckpt/", "local", "/mnt/ckpt/"),
|
||||
("./ckpt/", "local", "./ckpt/"),
|
||||
("s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ckpt_path", internlm_ckpt_path)
|
||||
def test_try_get_storage_backend(ckpt_path):
|
||||
from internlm.utils.storage_manager import try_get_storage_backend
|
||||
|
||||
ipath, a_prefix, a_cut_path = ckpt_path
|
||||
b_prefix, b_cut_path = try_get_storage_backend(ipath)
|
||||
assert a_prefix == b_prefix, f"{a_prefix} == {b_prefix}"
|
||||
assert a_cut_path == b_cut_path, f"{a_cut_path} == {b_cut_path}"
|
||||
|
|
Loading…
Reference in New Issue