fix save_model indent error in ppo trainer (#3450)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
pull/3219/head
Yuanchen 2023-04-05 09:45:42 +08:00 committed by GitHub
parent ffcdbf0f65
commit b92313903f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 3 additions and 4 deletions

View File

@ -117,6 +117,9 @@ class PPOTrainer(Trainer):
return {'reward': experience.reward.mean().item()}
def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy._unwrap_actor(actor)
@ -129,7 +132,3 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
return new_kwargs
def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)