[Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5945/head
zhurunhua 2024-07-26 11:15:20 +08:00 committed by GitHub
parent 2069472e96
commit ad35a987d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 15 additions and 3 deletions

View File

@ -128,6 +128,12 @@ def main() -> None:
parser.add_argument("--zero", type=int, default=1)
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
parser.add_argument(
"--skip_save_each_epoch",
action="store_true",
default=False,
help="skip saving the model checkpoint after each epoch is completed.",
)
args = parser.parse_args()
with open(args.config_file, "w") as f:
@ -370,11 +376,17 @@ def main() -> None:
)
total_loss.fill_(0.0)
pbar.update()
# Save modeling.
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
step + 1
) == len(dataloader):
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if save_model_condition:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft: