From 73afb6359489295eafa95378341b9da866fe6097 Mon Sep 17 00:00:00 2001 From: Dr-Corgi Date: Thu, 6 Apr 2023 11:19:14 +0800 Subject: [PATCH] [chat]fix save_model(#3377) The function save_model should be a part of PPOTrainer. --- applications/Chat/coati/trainer/ppo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index 6b99855be..5c7c71d20 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -116,6 +116,9 @@ class PPOTrainer(Trainer): self.critic_optim.zero_grad() return {'reward': experience.reward.mean().item()} + + def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)