From c703938fb334b9f8ea212e3070c1ad5e2177c595 Mon Sep 17 00:00:00 2001 From: zaglc Date: Tue, 26 Sep 2023 19:16:16 +0800 Subject: [PATCH] change ckpt name --- internlm/utils/model_checkpoint.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 272fdae..20a7d49 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -292,8 +292,8 @@ def save_model_checkpoint(folder, model): should_save_rank_pair.add((i, i % dp_size)) if (tp_rank, dp_rank) in should_save_rank_pair: - f_zo = f"_zo{dp_rank}" if gpc.config.parallel.use_fsdp else "" - fn = f"model_tp{tp_rank}_pp{pp_rank}{f_zo}.pt" + f_dp = f"_dp{dp_rank}" if gpc.config.parallel.use_fsdp else "" + fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt" fp = os.path.join(folder, fn) llm_save(fp, saved_obj=states) if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size: @@ -328,8 +328,8 @@ def load_model_checkpoint(folder, model): # avoid ckpt misuse between FSDP and no-FSDP test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop() - assert ("_zo" in test_fn and gpc.config.parallel.use_fsdp) or ( - "_zo" not in test_fn and not gpc.config.parallel.use_fsdp + assert ("_dp" in test_fn and gpc.config.parallel.use_fsdp) or ( + "_dp" not in test_fn and not gpc.config.parallel.use_fsdp ), "FSDP model wants to load no-FSDP ckpts or reverse" max_pp, max_tp, max_zo = 0, 0, 0 @@ -356,7 +356,7 @@ def load_model_checkpoint(folder, model): ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards" if gpc.config.parallel.use_fsdp: - should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{dp_rank}.pt" + should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt" else: should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt" fp = os.path.join(folder, should_load_name)