mirror of https://github.com/InternLM/InternLM
change ckpt name
parent
83bd11f2b2
commit
c703938fb3
|
@ -292,8 +292,8 @@ def save_model_checkpoint(folder, model):
|
||||||
should_save_rank_pair.add((i, i % dp_size))
|
should_save_rank_pair.add((i, i % dp_size))
|
||||||
|
|
||||||
if (tp_rank, dp_rank) in should_save_rank_pair:
|
if (tp_rank, dp_rank) in should_save_rank_pair:
|
||||||
f_zo = f"_zo{dp_rank}" if gpc.config.parallel.use_fsdp else ""
|
f_dp = f"_dp{dp_rank}" if gpc.config.parallel.use_fsdp else ""
|
||||||
fn = f"model_tp{tp_rank}_pp{pp_rank}{f_zo}.pt"
|
fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt"
|
||||||
fp = os.path.join(folder, fn)
|
fp = os.path.join(folder, fn)
|
||||||
llm_save(fp, saved_obj=states)
|
llm_save(fp, saved_obj=states)
|
||||||
if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
|
if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
|
||||||
|
@ -328,8 +328,8 @@ def load_model_checkpoint(folder, model):
|
||||||
|
|
||||||
# avoid ckpt misuse between FSDP and no-FSDP
|
# avoid ckpt misuse between FSDP and no-FSDP
|
||||||
test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
|
test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
|
||||||
assert ("_zo" in test_fn and gpc.config.parallel.use_fsdp) or (
|
assert ("_dp" in test_fn and gpc.config.parallel.use_fsdp) or (
|
||||||
"_zo" not in test_fn and not gpc.config.parallel.use_fsdp
|
"_dp" not in test_fn and not gpc.config.parallel.use_fsdp
|
||||||
), "FSDP model wants to load no-FSDP ckpts or reverse"
|
), "FSDP model wants to load no-FSDP ckpts or reverse"
|
||||||
|
|
||||||
max_pp, max_tp, max_zo = 0, 0, 0
|
max_pp, max_tp, max_zo = 0, 0, 0
|
||||||
|
@ -356,7 +356,7 @@ def load_model_checkpoint(folder, model):
|
||||||
), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"
|
), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"
|
||||||
|
|
||||||
if gpc.config.parallel.use_fsdp:
|
if gpc.config.parallel.use_fsdp:
|
||||||
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_zo{dp_rank}.pt"
|
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
|
||||||
else:
|
else:
|
||||||
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
|
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
|
||||||
fp = os.path.join(folder, should_load_name)
|
fp = os.path.join(folder, should_load_name)
|
||||||
|
|
Loading…
Reference in New Issue