ckpt_api
wangbluo 2024-11-25 11:58:39 +08:00
parent b83143ee72
commit 82c88c1e0d
3 changed files with 2 additions and 8 deletions

View File

@ -135,9 +135,7 @@ def exam_state_dict(
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)

View File

@ -87,9 +87,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)

View File

@ -47,8 +47,6 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
if not shard and not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
if not shard and use_async: