2023-02-14 14:17:25 +00:00
|
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
from chatgpt.experience_maker import Experience, NaiveExperienceMaker
|
|
|
|
from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
|
2023-02-15 05:59:58 +00:00
|
|
|
from chatgpt.nn.generation_utils import update_model_kwargs_fn
|
2023-02-14 14:17:25 +00:00
|
|
|
from chatgpt.replay_buffer import NaiveReplayBuffer
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
|
|
|
|
from .base import Trainer
|
|
|
|
from .callbacks import Callback
|
|
|
|
from .strategies import Strategy
|
|
|
|
|
|
|
|
|
|
|
|
class PPOTrainer(Trainer):
    """
    Trainer for PPO algorithm.

    Args:
        strategy (Strategy): the strategy to use for training
        actor (Actor): the actor model in ppo algorithm
        critic (Critic): the critic model in ppo algorithm
        reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
        initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
        actor_optim (Optimizer): the optimizer to use for actor model
        critic_optim (Optimizer): the optimizer to use for critic model
        kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 1): the number of epochs of training process
        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], optional): the callbacks to call during training process;
            a new empty list is used when omitted
        generate_kwargs (dict, optional): the kwargs to use while model generating.
            ``prepare_inputs_fn`` and ``update_model_kwargs_fn`` are filled in with
            defaults when not given (see ``_set_default_generate_kwargs``).
    """

    def __init__(self,
                 strategy: Strategy,
                 actor: Actor,
                 critic: Critic,
                 reward_model: nn.Module,
                 initial_model: Actor,
                 actor_optim: Optimizer,
                 critic_optim: Optimizer,
                 kl_coef: float = 0.1,
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 buffer_cpu_offload: bool = True,
                 eps_clip: float = 0.2,
                 value_clip: float = 0.4,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 tokenizer: Optional[Callable[[Any], dict]] = None,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs) -> None:
        # Use a fresh list per trainer; a `[]` default is created once and shared by all instances.
        callbacks = [] if callbacks is None else callbacks

        experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
        replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)

        # Fill in generation defaults BEFORE calling the base constructor: `**generate_kwargs` hands
        # the callee a new dict, so mutating the local dict afterwards would never reach it.
        # `self.strategy` is presumably assigned by the base constructor, so pass `strategy` explicitly.
        self._set_default_generate_kwargs(generate_kwargs, actor, strategy)

        super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
                         sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
        self.actor = actor
        self.critic = critic

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)

        self.actor_optim = actor_optim
        self.critic_optim = critic_optim

    def training_step(self, experience: Experience) -> Dict[str, float]:
        """Run one PPO update on a single batch of experience.

        The actor is updated first: the policy loss compares log-probs recomputed by
        ``self.actor`` against the stored ``experience.action_log_probs``, weighted by
        ``experience.advantages``. The critic is updated second: the value loss compares
        values recomputed by ``self.critic`` against the stored ``experience.values`` and
        ``experience.reward``. Each model gets its own backward pass, optimizer step and
        gradient reset, routed through ``self.strategy``.

        Args:
            experience (Experience): a batch produced by the experience maker.

        Returns:
            Dict[str, float]: ``{'actor_loss': ..., 'critic_loss': ...}`` as Python floats.
        """
        self.actor.train()
        self.critic.train()

        # assumes action_mask is (batch, num_actions) -- its second dim is the number of
        # generated action tokens the actor should score; TODO confirm against Experience
        num_actions = experience.action_mask.size(1)
        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

    def _set_default_generate_kwargs(self,
                                     generate_kwargs: dict,
                                     actor: Actor,
                                     strategy: Optional[Strategy] = None) -> None:
        """Fill in missing generation hooks in ``generate_kwargs`` (mutates it in place).

        - ``prepare_inputs_fn``: set to the unwrapped model's ``prepare_inputs_for_generation``
          method, only if the model has one (e.g. huggingface models).
        - ``update_model_kwargs_fn``: set to the shared ``update_model_kwargs_fn`` helper.

        Keys already present in ``generate_kwargs`` are left untouched.

        Args:
            generate_kwargs (dict): generation kwargs to update in place.
            actor (Actor): the actor whose underlying model may supply ``prepare_inputs_for_generation``.
            strategy (Strategy, optional): strategy used to unwrap ``actor``; defaults to
                ``self.strategy``. Pass it explicitly when calling before the base constructor has run.
        """
        if strategy is None:
            strategy = self.strategy
        origin_model = strategy._unwrap_actor(actor)

        # use huggingface models method directly
        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
            generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation

        if 'update_model_kwargs_fn' not in generate_kwargs:
            generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn