diff --git a/ptuning/main.py b/ptuning/main.py
index ecce8c2..2aa5ac3 100644
--- a/ptuning/main.py
+++ b/ptuning/main.py
@@ -354,6 +354,7 @@ def main():
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
+        save_prefixencoder=model_args.pre_seq_len is not None
     )
 
     # Training
diff --git a/ptuning/trainer.py b/ptuning/trainer.py
index 5a9a27b..63101bc 100644
--- a/ptuning/trainer.py
+++ b/ptuning/trainer.py
@@ -317,7 +317,9 @@ class Trainer:
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        save_prefixencoder: bool = False,
     ):
+        self.save_prefixencoder = save_prefixencoder
         if args is None:
             output_dir = "tmp_trainer"
             logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
@@ -2825,12 +2827,17 @@ class Trainer:
             state_dict = self.model.state_dict()
             torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            state_dict = self.model.state_dict()
-            filtered_state_dict = {}
-            for k, v in self.model.named_parameters():
-                if v.requires_grad:
-                    filtered_state_dict[k] = state_dict[k]
-            self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
+            if self.save_prefixencoder:
+                print("Saving PrefixEncoder")
+                state_dict = self.model.state_dict()
+                filtered_state_dict = {}
+                for k, v in self.model.named_parameters():
+                    if v.requires_grad:
+                        filtered_state_dict[k] = state_dict[k]
+                self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
+            else:
+                print("Saving the whole model")
+                self.model.save_pretrained(output_dir, state_dict=state_dict)
 
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
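
Net effect of the change: when `pre_seq_len` is set (p-tuning), the checkpoint written by `save_pretrained` contains only the trainable prefix-encoder parameters (the `requires_grad` filter), while a run without `pre_seq_len` falls through to the new `else` branch and saves the whole model. Below is a minimal loading sketch for such a prefix-only checkpoint; it assumes a ChatGLM-style model whose trainable prefix weights are named under `transformer.prefix_encoder.`, and the checkpoint path, model name, and `pre_seq_len` value here are placeholders, not values taken from this diff.

```python
import os

import torch
from transformers import AutoConfig, AutoModel

CHECKPOINT_PATH = "output/checkpoint-3000"  # hypothetical p-tuning output dir

# Rebuild the base model with the same pre_seq_len used during training,
# so the PrefixEncoder module exists and has matching shapes.
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)

# The checkpoint saved by the filtered-state-dict branch holds only the
# trainable parameters, keyed by their full names in the wrapping model.
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))

# Strip the "transformer.prefix_encoder." prefix so the keys match the
# PrefixEncoder submodule's own state dict, then load it in place.
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
```

The practical benefit of the filtered save is checkpoint size: a p-tuning run only writes the small prefix-embedding weights instead of the full multi-gigabyte model, and full fine-tuning is unaffected because `save_prefixencoder` defaults to `False`.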