mirror of https://github.com/hpcaitech/ColossalAI
[Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5945/head
parent
2069472e96
commit
ad35a987d3
|
@ -128,6 +128,12 @@ def main() -> None:
|
|||
parser.add_argument("--zero", type=int, default=1)
|
||||
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
||||
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
||||
parser.add_argument(
|
||||
"--skip_save_each_epoch",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="skip saving the model checkpoint after each epoch is completed.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_file, "w") as f:
|
||||
|
@ -370,11 +376,17 @@ def main() -> None:
|
|||
)
|
||||
total_loss.fill_(0.0)
|
||||
pbar.update()
|
||||
|
||||
# Save modeling.
|
||||
|
||||
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
||||
step + 1
|
||||
) == len(dataloader):
|
||||
save_model_condition = (
|
||||
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
|
||||
)
|
||||
|
||||
if not args.skip_save_each_epoch:
|
||||
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
|
||||
|
||||
if save_model_condition:
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
|
||||
if args.use_neft:
|
||||
|
|
Loading…
Reference in New Issue