mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] fix trainer generate kwargs (#3166)
parent c474fda282
commit 1e58d31bb7
@@ -63,6 +63,7 @@ class PPOTrainer(Trainer):
                  **generate_kwargs) -> None:
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+        generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
         super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
                          sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
         self.actor = actor
@@ -73,7 +74,6 @@ class PPOTrainer(Trainer):
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim
-        self._set_default_generate_kwargs(generate_kwargs, actor)
 
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
 
@@ -102,11 +102,15 @@ class PPOTrainer(Trainer):
 
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
-    def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
-        origin_model = self.strategy._unwrap_actor(actor)
-        # use huggingface models method directly
-        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-            generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
-
-        if 'update_model_kwargs_fn' not in generate_kwargs:
-            generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
+
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
+    origin_model = strategy._unwrap_actor(actor)
+    new_kwargs = {**generate_kwargs}
+    # use huggingface models method directly
+    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
+        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+
+    if 'update_model_kwargs_fn' not in generate_kwargs:
+        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
+
+    return new_kwargs
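For context on the fix: before this commit, PPOTrainer.__init__ filled in the default prepare_inputs_fn and update_model_kwargs_fn only after super().__init__ had already received **generate_kwargs. Since **-unpacking repacks the arguments into a fresh dict inside the callee, mutating the trainer's dict afterwards never reached the base Trainer. The commit instead resolves the defaults in a module-level helper first, and passes the completed dict up. Below is a minimal, self-contained sketch of that pattern; DummyStrategy, DummyActor, and the placeholder update_model_kwargs_fn are illustrative stand-ins, not the real chatgpt classes.

# Minimal sketch: resolve default generate kwargs in a pure function,
# then hand the completed dict to the consumer.

def update_model_kwargs_fn(outputs, **model_kwargs):
    # placeholder for the real default helper
    return model_kwargs

class DummyActor:
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}

class DummyStrategy:
    def _unwrap_actor(self, actor):
        return actor

def _set_default_generate_kwargs(strategy, generate_kwargs: dict, actor) -> dict:
    origin_model = strategy._unwrap_actor(actor)
    new_kwargs = {**generate_kwargs}    # copy instead of mutating the input
    # prefer the HuggingFace-style method on the unwrapped model, if present
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
    if 'update_model_kwargs_fn' not in generate_kwargs:
        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
    return new_kwargs

# Usage: defaults are resolved before anything consumes the kwargs.
kwargs = _set_default_generate_kwargs(DummyStrategy(), {'max_length': 128}, DummyActor())
assert 'prepare_inputs_fn' in kwargs and 'update_model_kwargs_fn' in kwargs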