ColossalAI/applications/Chat/coati/trainer/ppo.py

203 lines
9.1 KiB
Python
Raw Normal View History

from typing import Dict, List, Optional
2023-03-28 12:25:36 +00:00
[chat] fix bugs and add unit tests (#4213) * style: rename replay buffer Experience replay is typically for off policy algorithms. Use this name in PPO maybe misleading. * fix: fix wrong zero2 default arg * test: update experience tests * style: rename zero_pad fn * fix: defer init in CycledDataLoader * test: add benchmark test * style: rename internal fn of generation * style: rename internal fn of lora * fix: remove unused loss fn * fix: remove unused utils fn * refactor: remove generate_with_actor fn * fix: fix type annotation * test: add models tests * fix: skip llama due to long execution time * style: modify dataset * style: apply formatter * perf: update reward dataset * fix: fix wrong IGNORE_INDEX in sft dataset * fix: remove DataCollatorForSupervisedDataset * test: add dataset tests * style: apply formatter * style: rename test_ci to test_train * feat: add llama in inference * test: add inference tests * test: change test scripts directory * fix: update ci * fix: fix typo * fix: skip llama due to oom * fix: fix file mod * style: apply formatter * refactor: remove duplicated llama_gptq * style: apply formatter * to: update rm test * feat: add tokenizer arg * feat: add download model script * test: update train tests * fix: modify gemini load and save pretrained * test: update checkpoint io test * to: modify nproc_per_node * fix: do not remove existing dir * fix: modify save path * test: add random choice * fix: fix sft path * fix: enlarge nproc_per_node to avoid oom * fix: add num_retry * fix: make lora config of rm and critic consistent * fix: add warning about lora weights * fix: skip some gpt2 tests * fix: remove grad ckpt in rm and critic due to errors * refactor: directly use Actor in train_sft * test: add more arguments * fix: disable grad ckpt when using lora * fix: fix save_pretrained and related tests * test: enable zero2 tests * revert: remove useless fn * style: polish code * test: modify test args
2023-08-02 02:17:36 +00:00
from coati.experience_buffer import NaiveExperienceBuffer
2023-03-28 12:25:36 +00:00
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
2023-03-28 12:25:36 +00:00
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
2023-03-28 12:25:36 +00:00
from colossalai.utils import get_current_device
from .base import OnPolicyTrainer
2023-03-28 12:25:36 +00:00
from .callbacks import Callback
from .strategies import GeminiStrategy, Strategy
from .utils import CycledDataLoader, is_rank_0, to_device
2023-03-28 12:25:36 +00:00
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapped_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
return new_kwargs
class PPOTrainer(OnPolicyTrainer):
2023-03-28 12:25:36 +00:00
"""
Trainer for PPO algorithm.
Args:
strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
2023-04-20 09:22:15 +00:00
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
2023-03-28 12:25:36 +00:00
actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
buffer_limit (int, defaults to 0): the max_size limitation of buffer
buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
2023-03-28 12:25:36 +00:00
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
2023-03-28 12:25:36 +00:00
value_clip (float, defaults to 0.4): the clip coefficient of value loss
sample_buffer (bool, defaults to False): whether to sample from buffer
2023-03-28 12:25:36 +00:00
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
2023-03-28 12:25:36 +00:00
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(
self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: RewardModel,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(strategy, GeminiStrategy):
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
[chat] fix bugs and add unit tests (#4213) * style: rename replay buffer Experience replay is typically for off policy algorithms. Use this name in PPO maybe misleading. * fix: fix wrong zero2 default arg * test: update experience tests * style: rename zero_pad fn * fix: defer init in CycledDataLoader * test: add benchmark test * style: rename internal fn of generation * style: rename internal fn of lora * fix: remove unused loss fn * fix: remove unused utils fn * refactor: remove generate_with_actor fn * fix: fix type annotation * test: add models tests * fix: skip llama due to long execution time * style: modify dataset * style: apply formatter * perf: update reward dataset * fix: fix wrong IGNORE_INDEX in sft dataset * fix: remove DataCollatorForSupervisedDataset * test: add dataset tests * style: apply formatter * style: rename test_ci to test_train * feat: add llama in inference * test: add inference tests * test: change test scripts directory * fix: update ci * fix: fix typo * fix: skip llama due to oom * fix: fix file mod * style: apply formatter * refactor: remove duplicated llama_gptq * style: apply formatter * to: update rm test * feat: add tokenizer arg * feat: add download model script * test: update train tests * fix: modify gemini load and save pretrained * test: update checkpoint io test * to: modify nproc_per_node * fix: do not remove existing dir * fix: modify save path * test: add random choice * fix: fix sft path * fix: enlarge nproc_per_node to avoid oom * fix: add num_retry * fix: make lora config of rm and critic consistent * fix: add warning about lora weights * fix: skip some gpt2 tests * fix: remove grad ckpt in rm and critic due to errors * refactor: directly use Actor in train_sft * test: add more arguments * fix: disable grad ckpt when using lora * fix: fix save_pretrained and related tests * test: enable zero2 tests * revert: remove useless fn * style: polish code * test: modify test args
2023-08-02 02:17:36 +00:00
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
2023-03-28 12:25:36 +00:00
self.actor = actor
self.critic = critic
self.tokenizer = tokenizer
2023-03-28 12:25:36 +00:00
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
self.vf_coef = vf_coef
self.ptx_loss_fn = GPTLMLoss()
2023-03-28 12:25:36 +00:00
self.ptx_coef = ptx_coef
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.offload_inference_models = offload_inference_models
self.device = get_current_device()
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-ppo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "ppo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _make_experience(self, collect_step: int) -> Experience:
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
def _training_step(self, experience: Experience):
2023-03-28 12:25:36 +00:00
self.actor.train()
self.critic.train()
# policy loss
num_actions = experience.action_log_probs.size(1)
actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
actor_loss = (1 - self.ptx_coef) * actor_loss
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
2023-03-28 12:25:36 +00:00
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
2023-03-28 12:25:36 +00:00
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
# value loss
values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
critic_loss = critic_loss * self.vf_coef
2023-03-28 12:25:36 +00:00
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
def _learn(self, update_step: int):
if self.offload_inference_models:
self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to("cpu")
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
[chat] fix bugs and add unit tests (#4213) * style: rename replay buffer Experience replay is typically for off policy algorithms. Use this name in PPO maybe misleading. * fix: fix wrong zero2 default arg * test: update experience tests * style: rename zero_pad fn * fix: defer init in CycledDataLoader * test: add benchmark test * style: rename internal fn of generation * style: rename internal fn of lora * fix: remove unused loss fn * fix: remove unused utils fn * refactor: remove generate_with_actor fn * fix: fix type annotation * test: add models tests * fix: skip llama due to long execution time * style: modify dataset * style: apply formatter * perf: update reward dataset * fix: fix wrong IGNORE_INDEX in sft dataset * fix: remove DataCollatorForSupervisedDataset * test: add dataset tests * style: apply formatter * style: rename test_ci to test_train * feat: add llama in inference * test: add inference tests * test: change test scripts directory * fix: update ci * fix: fix typo * fix: skip llama due to oom * fix: fix file mod * style: apply formatter * refactor: remove duplicated llama_gptq * style: apply formatter * to: update rm test * feat: add tokenizer arg * feat: add download model script * test: update train tests * fix: modify gemini load and save pretrained * test: update checkpoint io test * to: modify nproc_per_node * fix: do not remove existing dir * fix: modify save path * test: add random choice * fix: fix sft path * fix: enlarge nproc_per_node to avoid oom * fix: add num_retry * fix: make lora config of rm and critic consistent * fix: add warning about lora weights * fix: skip some gpt2 tests * fix: remove grad ckpt in rm and critic due to errors * refactor: directly use Actor in train_sft * test: add more arguments * fix: disable grad ckpt when using lora * fix: fix save_pretrained and related tests * test: enable zero2 tests * revert: remove useless fn * style: polish code * test: modify test args
2023-08-02 02:17:36 +00:00
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)