From 26a73977521ad91b71a1c656b035fe84382e98ef Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Mon, 25 Sep 2023 15:16:25 +0800 Subject: [PATCH] fix(storage): fix try_get_storage_backend (#359) * fix(storage): fix try_get_storage_backend * fix typo and print infos only in log rank * fix typo and print infos only in log rank --------- Co-authored-by: gaoyang07 --- internlm/utils/model_checkpoint.py | 29 ++++++++++++++---------- internlm/utils/storage_manager.py | 23 +++++++++---------- tests/test_utils/test_storage_manager.py | 21 +++++++++++++++++ 3 files changed, 49 insertions(+), 24 deletions(-) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index dad2fc6..64cbd3e 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -78,8 +78,8 @@ class CheckpointLoadMethod: @staticmethod def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable): - if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC: - logger.warning(f"{load_type} has aleady been registed!") + if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC and gpc.is_rank_for_log(): + logger.warning(f"{load_type} has already been registered!") return CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func}) @@ -87,9 +87,10 @@ class CheckpointLoadMethod: if load_type == CheckpointLoadType.INTERNLM: CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func) else: - if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG: + if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log(): logger.warning( - f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}" + f"The registered signature {inspect.signature(load_func)} of the loaded model is not same as: " + f"{CheckpointLoadMethod.LOAD_FUNC_SIG}" ) @staticmethod @@ -370,10 +371,11 @@ def load_optimizer_checkpoint(folder, optim): zero_devide_optim_plan = llm_load(fp_meta) states.update({"zero_devide_optim_plan": zero_devide_optim_plan}) except Exception as e: - logger.warning( - f"Read zero optimzer split file '{fp_meta}', for '{e}'" - f"Please check whether loading ckpts are saved with the HybridZeroOptimizer." - ) + if gpc.is_rank_for_log(): + logger.warning( + f"Read zero optimzer split file '{fp_meta}', for '{e}'" + f"Please check whether loading ckpts are saved with the HybridZeroOptimizer." + ) optim.load_state_dict(states) del states @@ -385,8 +387,8 @@ def load_sampler(ckpt_path: str, sampler): sampler.load_state_dict(sampler_states) if gpc.is_rank_for_log(): pstate = copy.deepcopy(sampler_states) - pstate.pop("indices") - pstate.pop("rng_state") + pstate.pop("indices", None) + pstate.pop("rng_state", None) logger.info(f"reload sampler_states:{pstate}") torch.cuda.empty_cache() @@ -635,9 +637,12 @@ now step_count is {train_state.step_count}", # Here we only try to find the ckpt folder named after step, ignoring snapshot and other folders. ckpt_list = [int(fn.strip("/")) for fn in ckpt_list if fn.strip("/").isdigit()] if len(ckpt_list) == 0: - logger.warning("Not found avaliable normal checkpoint!") + if gpc.is_rank_for_log(): + logger.warning("No available normal checkpoint found. Check your checkpoint path.") else: - logger.info(f"Found avaliable normal checkpoint: {ckpt_list}!") + if gpc.is_rank_for_log(): + logger.info(f"Found available normal checkpoint: {ckpt_list}") + ckpt_list.sort(reverse=True) for ckpt in ckpt_list: fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt))) diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 36bd105..a3f9122 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -166,19 +166,18 @@ def compute_file_md5_by_chunk(file_name: str): def try_get_storage_backend(path: str): - sre = path.split(":", maxsplit=1) - if len(sre) == 1: - if path.startswith("s3:"): - backend = "boto3" - if gpc.is_rank_for_log(): - logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.") - else: - backend = "local" + if path.startswith("s3:"): + if gpc.is_rank_for_log(): + logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.") + return "boto3", path + else: + sre = path.split(":", maxsplit=1) + if len(sre) == 1: if gpc.is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.") - return backend, sre - else: - return sre[0], sre[1] # (backend_prefix, splited_path) + return "local", sre[0] + else: + return sre[0], sre[1] # (backend_prefix, splited_path) class Boto3Client(StorageClient): @@ -502,7 +501,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning: + if not self.has_warning and gpc.is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \ the proxy may make boto3 unavailable or affect performance." diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py index 32f905b..e5f60c4 100644 --- a/tests/test_utils/test_storage_manager.py +++ b/tests/test_utils/test_storage_manager.py @@ -87,3 +87,24 @@ def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylin assert get_fns(ckpt_config.save_folder)[0] == "test.pt" load_obj = llm_load(save_fn, map_location="cpu") assert 0 == ((load_obj != tobj).sum()) + + +internlm_ckpt_path = [ + ("local:/mnt/ckpt/", "local", "/mnt/ckpt/"), + ("local:./ckpt/", "local", "./ckpt/"), + ("boto3:s3://oss_bucket/", "boto3", "s3://oss_bucket/"), + ("boto3:oss_bucket/", "boto3", "oss_bucket/"), + ("/mnt/ckpt/", "local", "/mnt/ckpt/"), + ("./ckpt/", "local", "./ckpt/"), + ("s3://oss_bucket/", "boto3", "s3://oss_bucket/"), +] + + +@pytest.mark.parametrize("ckpt_path", internlm_ckpt_path) +def test_try_get_storage_backend(ckpt_path): + from internlm.utils.storage_manager import try_get_storage_backend + + ipath, a_prefix, a_cut_path = ckpt_path + b_prefix, b_cut_path = try_get_storage_backend(ipath) + assert a_prefix == b_prefix, f"{a_prefix} == {b_prefix}" + assert a_cut_path == b_cut_path, f"{a_cut_path} == {b_cut_path}"