|
|
|
@ -63,6 +63,7 @@ class PPOTrainer(Trainer):
|
|
|
|
|
**generate_kwargs) -> None: |
|
|
|
|
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) |
|
|
|
|
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) |
|
|
|
|
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) |
|
|
|
|
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer, |
|
|
|
|
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs) |
|
|
|
|
self.actor = actor |
|
|
|
@ -73,7 +74,6 @@ class PPOTrainer(Trainer):
|
|
|
|
|
|
|
|
|
|
self.actor_optim = actor_optim |
|
|
|
|
self.critic_optim = critic_optim |
|
|
|
|
self._set_default_generate_kwargs(generate_kwargs, actor) |
|
|
|
|
|
|
|
|
|
def training_step(self, experience: Experience) -> Dict[str, float]: |
|
|
|
|
self.actor.train() |
|
|
|
@ -102,11 +102,15 @@ class PPOTrainer(Trainer):
|
|
|
|
|
|
|
|
|
|
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} |
|
|
|
|
|
|
|
|
|
def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None: |
|
|
|
|
origin_model = self.strategy._unwrap_actor(actor) |
|
|
|
|
# use huggingface models method directly |
|
|
|
|
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): |
|
|
|
|
generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation |
|
|
|
|
|
|
|
|
|
if 'update_model_kwargs_fn' not in generate_kwargs: |
|
|
|
|
generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn |
|
|
|
|
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: |
|
|
|
|
origin_model = strategy._unwrap_actor(actor) |
|
|
|
|
new_kwargs = {**generate_kwargs} |
|
|
|
|
# use huggingface models method directly |
|
|
|
|
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): |
|
|
|
|
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation |
|
|
|
|
|
|
|
|
|
if 'update_model_kwargs_fn' not in generate_kwargs: |
|
|
|
|
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn |
|
|
|
|
|
|
|
|
|
return new_kwargs |
|
|
|
|