2023-06-29 02:48:09 +00:00
|
|
|
from typing import Dict, List, Optional
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
import torch.nn as nn
|
2023-08-02 02:17:36 +00:00
|
|
|
from coati.experience_buffer import NaiveExperienceBuffer
|
2023-03-28 12:25:36 +00:00
|
|
|
from coati.experience_maker import Experience, NaiveExperienceMaker
|
2023-06-13 05:31:56 +00:00
|
|
|
from coati.models.base import Actor, Critic, get_base_model
|
2023-04-26 10:11:49 +00:00
|
|
|
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
|
2023-06-13 05:31:56 +00:00
|
|
|
from coati.models.utils import calc_action_log_probs
|
2023-04-18 08:44:03 +00:00
|
|
|
from torch import Tensor
|
2023-03-28 12:25:36 +00:00
|
|
|
from torch.optim import Optimizer
|
2023-06-29 02:48:09 +00:00
|
|
|
from torch.utils.data import DataLoader, DistributedSampler
|
2023-04-18 08:44:03 +00:00
|
|
|
from tqdm import tqdm
|
2023-03-28 12:25:36 +00:00
|
|
|
|
2023-04-26 10:11:49 +00:00
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
|
2023-06-29 02:48:09 +00:00
|
|
|
from .base import OnPolicyTrainer
|
2023-03-28 12:25:36 +00:00
|
|
|
from .callbacks import Callback
|
2023-06-29 10:11:00 +00:00
|
|
|
from .strategies import GeminiStrategy, Strategy
|
2023-04-26 10:11:49 +00:00
|
|
|
from .utils import is_rank_0, to_device
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
2023-06-29 02:48:09 +00:00
|
|
|
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
    """Return a copy of ``generate_kwargs`` with HuggingFace generation hooks filled in.

    The actor is unwrapped by the strategy and reduced to its base model. For each hook
    that the caller did not supply explicitly and that the base model provides, the
    model's own method is used as the default:

    - ``prepare_inputs_fn``       <- ``prepare_inputs_for_generation``
    - ``update_model_kwargs_fn``  <- ``_update_model_kwargs_for_generation``

    The input dict is never mutated.
    """
    base = get_base_model(strategy.unwrap_model(actor))
    result = dict(generate_kwargs)

    # (kwarg expected by generation, attribute looked up on the huggingface model)
    hook_table = (
        ('prepare_inputs_fn', 'prepare_inputs_for_generation'),
        ('update_model_kwargs_fn', '_update_model_kwargs_for_generation'),
    )
    for kwarg_key, model_attr in hook_table:
        # Caller-supplied values always win; only fill gaps the model can cover.
        if kwarg_key in generate_kwargs or not hasattr(base, model_attr):
            continue
        result[kwarg_key] = getattr(base, model_attr)
    return result
|
|
|
|
|
|
|
|
|
|
|
|
class PPOTrainer(OnPolicyTrainer):
    """
    Trainer for PPO algorithm.

    Args:
        strategy (Strategy): the strategy to use for training
        actor (Actor): the actor model in ppo algorithm
        critic (Critic): the critic model in ppo algorithm
        reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
        initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
        actor_optim (Optimizer): the optimizer to use for actor model
        critic_optim (Optimizer): the optimizer to use for critic model
        kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
        ptx_coef (float, defaults to 0.9): the coefficient of ptx loss; the actor loss becomes
            ``ptx_coef * ptx_loss + (1 - ptx_coef) * policy_loss``, and 0 disables the ptx term
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        vf_coef (float, defaults to 1.0): the coefficient of value loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        sample_buffer (bool, defaults to False): whether to sample from buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
        callbacks (List[Callback], optional): the callbacks to call during training process; defaults to no callbacks
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(self,
                 strategy: Strategy,
                 actor: Actor,
                 critic: Critic,
                 reward_model: nn.Module,
                 initial_model: Actor,
                 actor_optim: Optimizer,
                 critic_optim: Optimizer,
                 kl_coef: float = 0.1,
                 ptx_coef: float = 0.9,
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 buffer_cpu_offload: bool = True,
                 eps_clip: float = 0.2,
                 vf_coef: float = 1.0,
                 value_clip: float = 0.4,
                 sample_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 offload_inference_models: bool = True,
                 callbacks: Optional[List[Callback]] = None,
                 **generate_kwargs
                 ) -> None:
        if isinstance(strategy, GeminiStrategy):
            # With offloading on, _make_experience/_learn move the inference models with
            # manual .to(...) calls, which GeminiPlugin does not support.
            assert not offload_inference_models, \
                "GeminiPlugin is not compatible with manual model.to('cpu')"

        # Build a fresh list per trainer: a `[]` default would be created once and shared
        # by every instance that relies on the default.
        if callbacks is None:
            callbacks = []

        data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        super().__init__(
            strategy, data_buffer,
            sample_buffer, dataloader_pin_memory,
            callbacks
        )

        self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
        self.offload_inference_models = offload_inference_models

        self.actor = actor
        self.critic = critic

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.vf_coef = vf_coef
        self.ptx_loss_fn = GPTLMLoss()
        self.ptx_coef = ptx_coef
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim

        self.device = get_current_device()

    def _make_experience(self, collect_step: int) -> Experience:
        """Build one batch of experience from the next prompt batch.

        ``collect_step`` is not used by this implementation.

        Raises:
            ValueError: if the prompt batch is neither a ``Tensor`` nor a ``dict``.
        """
        prompts = self.prompt_dataloader.next()
        if self.offload_inference_models:
            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
            # Move the inference-only models back to the compute device; _learn parks them on CPU.
            self.experience_maker.initial_model.to(self.device)
            self.experience_maker.reward_model.to(self.device)
        if isinstance(prompts, Tensor):
            return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
        elif isinstance(prompts, dict):
            # dict batches are expanded into keyword arguments alongside generate_kwargs
            return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(prompts)}"')

    def _training_step(self, experience: Experience) -> Dict[str, float]:
        """Run one optimizer step for the actor and one for the critic on ``experience``.

        Returns:
            Dict[str, float]: logging metrics; currently only the batch mean of ``experience.reward``.
        """
        self.actor.train()
        self.critic.train()

        # policy loss
        num_actions = experience.action_mask.size(1)
        actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
        action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
        actor_loss = self.actor_loss_fn(action_log_probs,
                                        experience.action_log_probs,
                                        experience.advantages,
                                        action_mask=experience.action_mask)

        # ptx loss: GPTLMLoss on a batch from the pretraining dataloader, blended into the actor loss
        if self.ptx_coef != 0:
            batch = self.pretrain_dataloader.next()
            batch = to_device(batch, self.device)
            # raw logits from the actor; the loss function consumes them directly
            ptx_logits = self.actor(batch['input_ids'],
                                    attention_mask=batch['attention_mask'])['logits']
            ptx_loss = self.ptx_loss_fn(ptx_logits, batch['labels'])
            actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        # value loss
        values = self.critic(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values,
                                          experience.values,
                                          experience.reward,
                                          action_mask=experience.action_mask)
        critic_loss = critic_loss * self.vf_coef
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

        return {'reward': experience.reward.mean().item()}

    def _learn(self, update_step: int):
        """Train on collected experience for one update step.

        With ``sample_buffer`` set, a single batch sampled from the buffer is trained on;
        otherwise one full pass is made over ``self.dataloader``.
        """
        if self.offload_inference_models:
            # The initial and reward models are used only by the experience maker, so they are
            # parked on CPU while training; _make_experience moves them back.
            self.experience_maker.initial_model.to('cpu')
            self.experience_maker.reward_model.to('cpu')

        # buffer may be empty at first, we should rebuild at each training
        if self.sample_buffer:
            experience = self.data_buffer.sample()
            self._on_learn_batch_start()
            experience.to_device(self.device)
            metrics = self._training_step(experience)
            self._on_learn_batch_end(metrics, experience)
        else:
            if isinstance(self.dataloader.sampler, DistributedSampler):
                # vary the distributed shuffle order from one update step to the next
                self.dataloader.sampler.set_epoch(update_step)
            pbar = tqdm(
                self.dataloader,
                desc=f'Train epoch [{update_step + 1}]',
                disable=not is_rank_0()
            )
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                metrics = self._training_step(experience)
                self._on_learn_batch_end(metrics, experience)
                pbar.set_postfix(metrics)
|