From 2a5250ffcb358023fb362761ad13b75ea72e9c0a Mon Sep 17 00:00:00 2001 From: rainatam Date: Mon, 10 Apr 2023 18:32:40 +0800 Subject: [PATCH] Update trainer --- ptuning/trainer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ptuning/trainer.py b/ptuning/trainer.py index c49944f..bbaa9db 100644 --- a/ptuning/trainer.py +++ b/ptuning/trainer.py @@ -2825,12 +2825,11 @@ class Trainer: state_dict = self.model.state_dict() torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - # state_dict = self.model.state_dict() - # filtered_state_dict = {} - # for k, v in self.model.named_parameters(): - # if v.requires_grad: - # filtered_state_dict[k] = state_dict[k] - # print(filtered_state_dict.keys()) + state_dict = self.model.state_dict() + filtered_state_dict = {} + for k, v in self.model.named_parameters(): + if v.requires_grad: + filtered_state_dict[k] = state_dict[k] self.model.save_pretrained(output_dir, state_dict=state_dict) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir)