Browse Source

[chatgpt] fix trainer generate kwargs (#3166)

pull/3174/head
ver217 2 years ago committed by GitHub
parent
commit
1e58d31bb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 20
      applications/ChatGPT/chatgpt/trainer/ppo.py

20
applications/ChatGPT/chatgpt/trainer/ppo.py

@ -63,6 +63,7 @@ class PPOTrainer(Trainer):
**generate_kwargs) -> None:
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
self.actor = actor
@ -73,7 +74,6 @@ class PPOTrainer(Trainer):
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self._set_default_generate_kwargs(generate_kwargs, actor)
def training_step(self, experience: Experience) -> Dict[str, float]:
self.actor.train()
@ -102,11 +102,15 @@ class PPOTrainer(Trainer):
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
origin_model = self.strategy._unwrap_actor(actor)
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy._unwrap_actor(actor)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
return new_kwargs

Loading…
Cancel
Save