fix(storage): fix try_get_storage_backend (#359)

* fix(storage): fix try_get_storage_backend

* fix typo and print info only on the log rank

* fix typo and print info only on the log rank

---------

Co-authored-by: gaoyang07 <Gary1546308416AL@gmail.com>
pull/320/head
Guoteng 2023-09-25 15:16:25 +08:00 committed by GitHub
parent a86c4bbbfd
commit 26a7397752
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 24 deletions

View File

@ -78,8 +78,8 @@ class CheckpointLoadMethod:
@staticmethod
def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable):
if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC:
logger.warning(f"{load_type} has aleady been registed!")
if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC and gpc.is_rank_for_log():
logger.warning(f"{load_type} has already been registered!")
return
CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})
@ -87,9 +87,10 @@ class CheckpointLoadMethod:
if load_type == CheckpointLoadType.INTERNLM:
CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
else:
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG:
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log():
logger.warning(
f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}"
f"The registered signature {inspect.signature(load_func)} of the loaded model is not same as: "
f"{CheckpointLoadMethod.LOAD_FUNC_SIG}"
)
@staticmethod
@ -370,10 +371,11 @@ def load_optimizer_checkpoint(folder, optim):
zero_devide_optim_plan = llm_load(fp_meta)
states.update({"zero_devide_optim_plan": zero_devide_optim_plan})
except Exception as e:
logger.warning(
f"Read zero optimzer split file '{fp_meta}', for '{e}'"
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
)
if gpc.is_rank_for_log():
logger.warning(
f"Read zero optimzer split file '{fp_meta}', for '{e}'"
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
)
optim.load_state_dict(states)
del states
@ -385,8 +387,8 @@ def load_sampler(ckpt_path: str, sampler):
sampler.load_state_dict(sampler_states)
if gpc.is_rank_for_log():
pstate = copy.deepcopy(sampler_states)
pstate.pop("indices")
pstate.pop("rng_state")
pstate.pop("indices", None)
pstate.pop("rng_state", None)
logger.info(f"reload sampler_states:{pstate}")
torch.cuda.empty_cache()
@ -635,9 +637,12 @@ now step_count is {train_state.step_count}",
# Here we only try to find the ckpt folder named after step, ignoring snapshot and other folders.
ckpt_list = [int(fn.strip("/")) for fn in ckpt_list if fn.strip("/").isdigit()]
if len(ckpt_list) == 0:
logger.warning("Not found avaliable normal checkpoint!")
if gpc.is_rank_for_log():
logger.warning("No available normal checkpoint found. Check your checkpoint path.")
else:
logger.info(f"Found avaliable normal checkpoint: {ckpt_list}!")
if gpc.is_rank_for_log():
logger.info(f"Found available normal checkpoint: {ckpt_list}")
ckpt_list.sort(reverse=True)
for ckpt in ckpt_list:
fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))

View File

@ -166,19 +166,18 @@ def compute_file_md5_by_chunk(file_name: str):
def try_get_storage_backend(path: str):
sre = path.split(":", maxsplit=1)
if len(sre) == 1:
if path.startswith("s3:"):
backend = "boto3"
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
else:
backend = "local"
if path.startswith("s3:"):
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
return "boto3", path
else:
sre = path.split(":", maxsplit=1)
if len(sre) == 1:
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
return backend, sre
else:
return sre[0], sre[1] # (backend_prefix, splited_path)
return "local", sre[0]
else:
return sre[0], sre[1] # (backend_prefix, splited_path)
class Boto3Client(StorageClient):
@ -502,7 +501,7 @@ class StorageManager(metaclass=SingletonMeta):
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
if not self.has_warning:
if not self.has_warning and gpc.is_rank_for_log():
logger.warning(
"HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \
the proxy may make boto3 unavailable or affect performance."

View File

@ -87,3 +87,24 @@ def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylin
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
load_obj = llm_load(save_fn, map_location="cpu")
assert 0 == ((load_obj != tobj).sum())
# Parametrization cases for test_try_get_storage_backend. Each tuple is
# (input path, expected backend prefix, expected path returned with it).
internlm_ckpt_path = [
    # Paths that carry an explicit "<backend>:" prefix.
    ("local:/mnt/ckpt/", "local", "/mnt/ckpt/"),
    ("local:./ckpt/", "local", "./ckpt/"),
    ("boto3:s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
    ("boto3:oss_bucket/", "boto3", "oss_bucket/"),
    # Paths without a backend prefix; the path is expected back unchanged.
    ("/mnt/ckpt/", "local", "/mnt/ckpt/"),
    ("./ckpt/", "local", "./ckpt/"),
    ("s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
]
@pytest.mark.parametrize("ckpt_path", internlm_ckpt_path)
def test_try_get_storage_backend(ckpt_path):
    """Check that try_get_storage_backend returns the expected (backend, path) pair.

    Args:
        ckpt_path: A tuple of (input path, expected backend prefix,
            expected path returned alongside the prefix).
    """
    # Imported inside the test body, as in the original.
    from internlm.utils.storage_manager import try_get_storage_backend

    ipath, a_prefix, a_cut_path = ckpt_path
    b_prefix, b_cut_path = try_get_storage_backend(ipath)

    # The messages are only shown on failure, so they must describe the
    # mismatch (expected vs. actual) and name the input that produced it.
    assert a_prefix == b_prefix, f"{ipath!r}: expected backend {a_prefix!r}, got {b_prefix!r}"
    assert a_cut_path == b_cut_path, f"{ipath!r}: expected path {a_cut_path!r}, got {b_cut_path!r}"