fix(storage): fix try_get_storage_backend (#359)

* fix(storage): fix try_get_storage_backend

* fix typo and print infos only in log rank

* fix typo and print infos only in log rank

---------

Co-authored-by: gaoyang07 <Gary1546308416AL@gmail.com>
pull/320/head
Guoteng 2023-09-25 15:16:25 +08:00 committed by GitHub
parent a86c4bbbfd
commit 26a7397752
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 24 deletions

View File

@ -78,8 +78,8 @@ class CheckpointLoadMethod:
@staticmethod @staticmethod
def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable): def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable):
if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC: if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC and gpc.is_rank_for_log():
logger.warning(f"{load_type} has aleady been registed!") logger.warning(f"{load_type} has already been registered!")
return return
CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func}) CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})
@ -87,9 +87,10 @@ class CheckpointLoadMethod:
if load_type == CheckpointLoadType.INTERNLM: if load_type == CheckpointLoadType.INTERNLM:
CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func) CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
else: else:
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG: if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log():
logger.warning( logger.warning(
f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}" f"The registered signature {inspect.signature(load_func)} of the loaded model is not same as: "
f"{CheckpointLoadMethod.LOAD_FUNC_SIG}"
) )
@staticmethod @staticmethod
@ -370,10 +371,11 @@ def load_optimizer_checkpoint(folder, optim):
zero_devide_optim_plan = llm_load(fp_meta) zero_devide_optim_plan = llm_load(fp_meta)
states.update({"zero_devide_optim_plan": zero_devide_optim_plan}) states.update({"zero_devide_optim_plan": zero_devide_optim_plan})
except Exception as e: except Exception as e:
logger.warning( if gpc.is_rank_for_log():
f"Read zero optimzer split file '{fp_meta}', for '{e}'" logger.warning(
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer." f"Read zero optimzer split file '{fp_meta}', for '{e}'"
) f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
)
optim.load_state_dict(states) optim.load_state_dict(states)
del states del states
@ -385,8 +387,8 @@ def load_sampler(ckpt_path: str, sampler):
sampler.load_state_dict(sampler_states) sampler.load_state_dict(sampler_states)
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
pstate = copy.deepcopy(sampler_states) pstate = copy.deepcopy(sampler_states)
pstate.pop("indices") pstate.pop("indices", None)
pstate.pop("rng_state") pstate.pop("rng_state", None)
logger.info(f"reload sampler_states:{pstate}") logger.info(f"reload sampler_states:{pstate}")
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -635,9 +637,12 @@ now step_count is {train_state.step_count}",
# Here we only try to find the ckpt folder named after step, ignoring snapshot and other folders. # Here we only try to find the ckpt folder named after step, ignoring snapshot and other folders.
ckpt_list = [int(fn.strip("/")) for fn in ckpt_list if fn.strip("/").isdigit()] ckpt_list = [int(fn.strip("/")) for fn in ckpt_list if fn.strip("/").isdigit()]
if len(ckpt_list) == 0: if len(ckpt_list) == 0:
logger.warning("Not found avaliable normal checkpoint!") if gpc.is_rank_for_log():
logger.warning("No available normal checkpoint found. Check your checkpoint path.")
else: else:
logger.info(f"Found avaliable normal checkpoint: {ckpt_list}!") if gpc.is_rank_for_log():
logger.info(f"Found available normal checkpoint: {ckpt_list}")
ckpt_list.sort(reverse=True) ckpt_list.sort(reverse=True)
for ckpt in ckpt_list: for ckpt in ckpt_list:
fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt))) fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))

View File

@ -166,19 +166,18 @@ def compute_file_md5_by_chunk(file_name: str):
def try_get_storage_backend(path: str): def try_get_storage_backend(path: str):
sre = path.split(":", maxsplit=1) if path.startswith("s3:"):
if len(sre) == 1: if gpc.is_rank_for_log():
if path.startswith("s3:"): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
backend = "boto3" return "boto3", path
if gpc.is_rank_for_log(): else:
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.") sre = path.split(":", maxsplit=1)
else: if len(sre) == 1:
backend = "local"
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.") logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
return backend, sre return "local", sre[0]
else: else:
return sre[0], sre[1] # (backend_prefix, splited_path) return sre[0], sre[1] # (backend_prefix, splited_path)
class Boto3Client(StorageClient): class Boto3Client(StorageClient):
@ -502,7 +501,7 @@ class StorageManager(metaclass=SingletonMeta):
or "HTTP_PROXY" in os.environ or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ or "HTTPS_PROXY" in os.environ
): ):
if not self.has_warning: if not self.has_warning and gpc.is_rank_for_log():
logger.warning( logger.warning(
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \ "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
the proxy may make boto3 unavailable or affect performance." the proxy may make boto3 unavailable or affect performance."

View File

@ -87,3 +87,24 @@ def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylin
assert get_fns(ckpt_config.save_folder)[0] == "test.pt" assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
load_obj = llm_load(save_fn, map_location="cpu") load_obj = llm_load(save_fn, map_location="cpu")
assert 0 == ((load_obj != tobj).sum()) assert 0 == ((load_obj != tobj).sum())
# Each case is (input path, expected backend prefix, expected path returned with it).
# Paths that spell out their backend as a leading "<backend>:" prefix.
_prefixed_ckpt_cases = [
    ("local:/mnt/ckpt/", "local", "/mnt/ckpt/"),
    ("local:./ckpt/", "local", "./ckpt/"),
    ("boto3:s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
    ("boto3:oss_bucket/", "boto3", "oss_bucket/"),
]
# Paths given without a backend prefix; the expected path equals the input path.
_bare_ckpt_cases = [
    ("/mnt/ckpt/", "local", "/mnt/ckpt/"),
    ("./ckpt/", "local", "./ckpt/"),
    ("s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
]
# Order is kept identical to the original single list (prefixed cases first).
internlm_ckpt_path = _prefixed_ckpt_cases + _bare_ckpt_cases
@pytest.mark.parametrize("ckpt_path", internlm_ckpt_path)
def test_try_get_storage_backend(ckpt_path):
    """Check that try_get_storage_backend returns the expected (backend, path) pair.

    Args:
        ckpt_path: A 3-tuple of (input path, expected backend prefix,
            expected path returned alongside the backend).
    """
    # Imported inside the test body, as in the original.
    from internlm.utils.storage_manager import try_get_storage_backend

    raw_path, expected_backend, expected_path = ckpt_path
    backend, returned_path = try_get_storage_backend(raw_path)

    # These messages are only shown on failure, so they describe the mismatch
    # instead of printing "a == b" for two values that are not equal.
    assert expected_backend == backend, f"{raw_path!r}: expected backend {expected_backend!r}, got {backend!r}"
    assert expected_path == returned_path, f"{raw_path!r}: expected path {expected_path!r}, got {returned_path!r}"