mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] fix trainer generate kwargs (#3166)
parent
c474fda282
commit
1e58d31bb7
|
@ -63,6 +63,7 @@ class PPOTrainer(Trainer):
|
|||
**generate_kwargs) -> None:
|
||||
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
|
||||
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
|
||||
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
|
||||
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
|
||||
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
|
||||
self.actor = actor
|
||||
|
@ -73,7 +74,6 @@ class PPOTrainer(Trainer):
|
|||
|
||||
self.actor_optim = actor_optim
|
||||
self.critic_optim = critic_optim
|
||||
self._set_default_generate_kwargs(generate_kwargs, actor)
|
||||
|
||||
def training_step(self, experience: Experience) -> Dict[str, float]:
|
||||
self.actor.train()
|
||||
|
@ -102,11 +102,15 @@ class PPOTrainer(Trainer):
|
|||
|
||||
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
|
||||
|
||||
def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
|
||||
origin_model = self.strategy._unwrap_actor(actor)
|
||||
# use huggingface models method directly
|
||||
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
|
||||
generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
|
||||
|
||||
if 'update_model_kwargs_fn' not in generate_kwargs:
|
||||
generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
|
||||
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
|
||||
origin_model = strategy._unwrap_actor(actor)
|
||||
new_kwargs = {**generate_kwargs}
|
||||
# use huggingface models method directly
|
||||
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
|
||||
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
|
||||
|
||||
if 'update_model_kwargs_fn' not in generate_kwargs:
|
||||
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
|
||||
|
||||
return new_kwargs
|
||||
|
|
Loading…
Reference in New Issue