
Add option for saving checkpoint
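Adds a `save_prefixencoder` flag to the patched `Trainer`. `ptuning/main.py` sets it to `model_args.pre_seq_len is not None`, so when P-tuning is active the checkpoint-saving path keeps only the trainable prefix-encoder weights, and otherwise saves the whole model.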

Branch: pull/621/head
Author: rainatam, 2 years ago
Commit: 5fb705cd5b

2 changed files:
  ptuning/main.py (1 change)
  ptuning/trainer.py (19 changes)

ptuning/main.py

@@ -354,6 +354,7 @@ def main():
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
+        save_prefixencoder=model_args.pre_seq_len is not None
     )

     # Training

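The flag is derived from `pre_seq_len`: when it is set, the model is being P-tuned and only the prefix encoder is trainable, so saving the full backbone in every checkpoint would waste space. The actual filtering lands in trainer.py below; as a standalone sketch of the same idea (the helper name is mine, not part of the patch):

    import torch

    def trainable_state_dict(model: torch.nn.Module) -> dict:
        # Keep only parameters with requires_grad=True (the prefix encoder
        # under P-tuning); frozen backbone weights are dropped.
        state_dict = model.state_dict()
        return {
            name: state_dict[name]
            for name, param in model.named_parameters()
            if param.requires_grad
        }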
ptuning/trainer.py

@@ -317,7 +317,9 @@ class Trainer:
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        save_prefixencoder: bool = False,
     ):
+        self.save_prefixencoder = save_prefixencoder
         if args is None:
             output_dir = "tmp_trainer"
             logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
@@ -2825,12 +2827,17 @@ class Trainer:
                 state_dict = self.model.state_dict()
                 torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            state_dict = self.model.state_dict()
-            filtered_state_dict = {}
-            for k, v in self.model.named_parameters():
-                if v.requires_grad:
-                    filtered_state_dict[k] = state_dict[k]
-            self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
+            if self.save_prefixencoder:
+                print("Saving PrefixEncoder")
+                state_dict = self.model.state_dict()
+                filtered_state_dict = {}
+                for k, v in self.model.named_parameters():
+                    if v.requires_grad:
+                        filtered_state_dict[k] = state_dict[k]
+                self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
+            else:
+                print("Saving the whole model")
+                self.model.save_pretrained(output_dir, state_dict=state_dict)

         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
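A checkpoint written this way contains only the prefix-encoder tensors, so it cannot be loaded back with the default strict key matching. A minimal reload sketch, assuming the default `pytorch_model.bin` weights file written by `save_pretrained` (the helper itself is not part of this patch):

    import os
    import torch

    def load_prefix_encoder(model: torch.nn.Module, checkpoint_dir: str) -> None:
        # The filtered checkpoint holds only the prefix-encoder weights.
        state_dict = torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin"))
        # strict=False: every frozen backbone weight is "missing" by design.
        model.load_state_dict(state_dict, strict=False)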
