diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md
index 81009da9d..8783ea61e 100755
--- a/applications/ColossalChat/README.md
+++ b/applications/ColossalChat/README.md
@@ -23,6 +23,8 @@
   - [Open QA](#open-qa)
   - [Limitation for LLaMA-finetuned models](#limitation)
   - [Limitation of dataset](#limitation)
+- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization-dpo)
+- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
 - [FAQ](#faq)
   - [How to save/load checkpoint](#faq)
   - [How to train with limited resources](#faq)
@@ -262,12 +264,8 @@ experience buffer size = train_batch_size * accumulation_steps * num_tp_group
 ```
 
 
-## Alternative Option For RLHF: Direct Preference Optimization
-
-For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO.
-
-## Alternative Option For RLHF: Simple Preference Optimization
-Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but it abandons the use of the reference model, which makes the training more efficient. It also adds a reward shaping term called target reward margin to enhance training stability. It also use length normalization to better align with the inference process.
+## Alternative Option For RLHF: Direct Preference Optimization (DPO)
+For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in this [paper](https://arxiv.org/abs/2305.18290), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO. Read this [README](./examples/README.md) for more information.
 
 ### DPO Training Stage1 - Supervised Instructs Tuning
 
@@ -280,6 +278,12 @@ For DPO training, you only need the preference dataset. Please follow the instru
 #### Step 2: Training
 You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More detais can be found in [example guideline](./examples/README.md).
 
+## Alternative Option For RLHF: Simple Preference Optimization (SimPO)
+Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but abandons the reference model, which makes training more efficient. It also adds a reward shaping term called the target reward margin to enhance training stability, and uses length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information.
+
+## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
+Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference-model-free alignment method that uses a mixture of an SFT loss and a reinforcement learning loss computed from an odds-ratio-based implicit reward to make training more efficient and stable. Read this [README](./examples/README.md) for more information.
+
 ### Inference Quantization and Serving - After Training
 
 We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
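Review note: the ORPO objective described in the README hunk above (an SFT loss mixed with an odds-ratio term) can be sketched in plain Python on scalar inputs. This is a minimal sketch that follows the formulation in the ORPO paper under stated assumptions, not the code added in this PR: the names `log_odds` and `orpo_loss` are hypothetical, the inputs are assumed to be length-averaged log-probabilities (strictly negative), and the default `lam=0.1` simply mirrors the trainer default shown in this diff.

```python
import math


def log_odds(avg_logp: float) -> float:
    """Return log(p / (1 - p)) for p = exp(avg_logp); requires avg_logp < 0."""
    # log1p(-exp(x)) evaluates log(1 - p) without forming 1 - p explicitly.
    return avg_logp - math.log1p(-math.exp(avg_logp))


def orpo_loss(chosen_avg_logp: float, reject_avg_logp: float, chosen_nll: float, lam: float = 0.1) -> float:
    """Hypothetical scalar ORPO objective: SFT loss plus a weighted odds-ratio penalty."""
    log_odds_ratio = log_odds(chosen_avg_logp) - log_odds(reject_avg_logp)
    # -log(sigmoid(z)) is written as log(1 + exp(-z)).
    odds_ratio_term = math.log1p(math.exp(-log_odds_ratio))
    return chosen_nll + lam * odds_ratio_term


# The penalty shrinks when the chosen response is more likely than the rejected one.
preferred = orpo_loss(chosen_avg_logp=-0.5, reject_avg_logp=-2.0, chosen_nll=1.0)
reversed_pair = orpo_loss(chosen_avg_logp=-2.0, reject_avg_logp=-0.5, chosen_nll=1.0)
print(preferred < reversed_pair)
```

Writing the penalty as `chosen_nll + lam * (-log sigmoid(z))` is algebraically the same as the `chosen_nll - odds_ratio_loss * self.lam` form in the trainer below, where `odds_ratio_loss` holds `log sigmoid(z)`.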
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index fd5c82efc..06c2d688b 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -179,3 +179,28 @@ class LogExpLoss(nn.Module):
     def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
         loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
         return loss
+
+
+class OddsRatioLoss(nn.Module):
+    """
+    Odds Ratio Loss in ORPO
+    Details: https://arxiv.org/pdf/2403.07691
+    """
+
+    def forward(
+        self,
+        chosen_logp: torch.Tensor,
+        reject_logp: torch.Tensor,
+        chosen_loss_mask: torch.Tensor,
+        reject_loss_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        chosen_logp = chosen_logp.to(dtype=torch.float32)
+        reject_logp = reject_logp.to(dtype=torch.float32)
+        chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001)
+        chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
+        reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
+        reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
+        # print("chosen_odds_masked", chosen_odds_masked[0], "reject_odds_masked", reject_odds_masked[0])
+        log_odds_ratio = chosen_odds_masked - reject_odds_masked
+        ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
+        return ratio.to(dtype=torch.bfloat16), log_odds_ratio
diff --git a/applications/ColossalChat/coati/trainer/__init__.py b/applications/ColossalChat/coati/trainer/__init__.py
index 2eff8ca76..6ce159678 100755
--- a/applications/ColossalChat/coati/trainer/__init__.py
+++ b/applications/ColossalChat/coati/trainer/__init__.py
@@ -1,7 +1,8 @@
 from .base import OLTrainer, SLTrainer
 from .dpo import DPOTrainer
+from .orpo import ORPOTrainer
 from .ppo import PPOTrainer
 from .rm import RewardModelTrainer
 from .sft import SFTTrainer
 
-__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"]
+__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer", "ORPOTrainer"]
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index 97552fa7a..c095cc35c 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -134,7 +134,6 @@ class DPOTrainer(SLTrainer):
                 batch["reject_attention_mask"],
                 batch["reject_loss_mask"],
             )
-            reject_loss_mask[:, -1] = False
             batch_size = chosen_input_ids.size()[0]
 
             actor_all_logits = self.model(
diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py
new file mode 100644
index 000000000..aa94e0acb
--- /dev/null
+++ b/applications/ColossalChat/coati/trainer/orpo.py
@@ -0,0 +1,339 @@
+"""
+ORPO trainer
+"""
+
+from typing import Any, Optional
+
+import torch
+from coati.models.loss import OddsRatioLoss
+from coati.models.utils import calc_masked_log_probs
+from coati.trainer.utils import all_reduce_mean
+from coati.utils import AccumulativeMeanMeter, save_checkpoint
+from torch.nn import CrossEntropyLoss
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import trange
+from transformers import PreTrainedTokenizerBase
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+from colossalai.utils import get_current_device
+
+from .base import SLTrainer
+from .utils import is_rank_0, to_device
+
+
+class ORPOTrainer(SLTrainer):
+    """
+    Trainer for ORPO algorithm.
+
+    Args:
+        actor (Actor): the actor model in ORPO algorithm
+        booster (Booster): the booster to use for training
+        actor_optim (Optimizer): the optimizer to use for actor model
+        actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
+        tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
+        max_epochs (int, defaults to 1): the max number of epochs to train
+        lam (float, defaults to 0.1): the lambda parameter in ORPO loss
+        accumulation_steps (int): the number of steps to accumulate gradients
+        start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
+        save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during training
+        save_dir (str): the directory to save checkpoints
+        coordinator (DistCoordinator): the coordinator to use for distributed logging
+    """
+
+    def __init__(
+        self,
+        actor: Any,
+        booster: Booster,
+        actor_optim: Optimizer,
+        actor_lr_scheduler: _LRScheduler,
+        tokenizer: PreTrainedTokenizerBase,
+        max_epochs: int = 1,
+        lam: float = 0.1,
+        accumulation_steps: int = 1,
+        start_epoch: int = 0,
+        save_interval: int = 0,
+        save_dir: str = None,
+        coordinator: DistCoordinator = None,
+    ) -> None:
+        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
+        self.actor_scheduler = actor_lr_scheduler
+        self.tokenizer = tokenizer
+        self.odds_ratio_loss_fn = OddsRatioLoss()
+        self.sft_loss_fn = CrossEntropyLoss()
+        self.save_interval = save_interval
+        self.coordinator = coordinator
+        self.save_dir = save_dir
+        self.num_train_step = 0
+        self.lam = lam
+        self.accumulation_steps = accumulation_steps
+        self.device = get_current_device()
+        self.accumulative_meter = AccumulativeMeanMeter()
+
+    def _before_fit(
+        self,
+        train_preference_dataloader: DataLoader = None,
+        eval_preference_dataloader: DataLoader = None,
+        log_dir: Optional[str] = None,
+        use_wandb: bool = False,
+    ):
+        """
+        Args:
+            train_preference_dataloader (DataLoader): the dataloader to use for preference training data
+            eval_preference_dataloader (DataLoader): the dataloader to use for preference evaluation data
+        """
+        self.train_dataloader = train_preference_dataloader
+        self.eval_dataloader = eval_preference_dataloader
+        self.writer = None
+        if use_wandb and is_rank_0():
+            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+            import wandb
+
+            self.wandb_run = wandb.init(project="Coati-orpo", sync_tensorboard=True)
+        if log_dir is not None and is_rank_0():
+            import os
+            import time
+
+            from torch.utils.tensorboard import SummaryWriter
+
+            log_dir = os.path.join(log_dir, "orpo")
+            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            self.writer = SummaryWriter(log_dir=log_dir)
+
+    def _train(self, epoch: int):
+        """
+        Args:
+            epoch int: the number of current epoch
+        """
+        self.model.train()
+        self.accumulative_meter.reset()
+        step_bar = trange(
+            len(self.train_dataloader) // self.accumulation_steps,
+            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+            disable=not is_rank_0(),
+        )
+        for i, batch in enumerate(self.train_dataloader):
+            batch = to_device(batch, self.device)
+            (
+                chosen_input_ids,
+                chosen_attention_mask,
+                chosen_loss_mask,
+                reject_input_ids,
+                reject_attention_mask,
+                reject_loss_mask,
+            ) = (
+                batch["chosen_input_ids"],
+                batch["chosen_attention_mask"],
+                batch["chosen_loss_mask"],
+                batch["reject_input_ids"],
+                batch["reject_attention_mask"],
+                batch["reject_loss_mask"],
+            )
+            batch_size = chosen_input_ids.size()[0]
+            actor_out = self.model(
+                input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+                attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+            )
+            torch.autograd.set_detect_anomaly(True)
+            actor_all_logits = actor_out["logits"].to(torch.float32)
+            actor_chosen_logits = actor_all_logits[:batch_size]
+            actor_reject_logits = actor_all_logits[batch_size:]
+            logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
+
+            logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
+            chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
+            label_chosen = chosen_input_ids[:, 1:].contiguous()
+            label_chosen_masked = (
+                label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
+            )
+            # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
+            chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
+            odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
+                logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
+            )
+            loss = chosen_nll - odds_ratio_loss * self.lam
+            step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
+
+            self.booster.backward(loss=loss, optimizer=self.optimizer)
+            if i % self.accumulation_steps == self.accumulation_steps - 1:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                self.actor_scheduler.step()
+
+            chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
+            rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
+            reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
+
+            # sync
+            loss_mean = all_reduce_mean(tensor=loss)
+            chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
+            rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
+            reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
+            self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+            self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
+            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+            self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
+            self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
+
+            if i % self.accumulation_steps == self.accumulation_steps - 1:
+                self.num_train_step += 1
+                step_bar.update()
+                # logging
+                if self.writer and is_rank_0():
+                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+                    self.writer.add_scalar(
+                        "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
+                    )
+                    self.writer.add_scalar(
+                        "train/rejected_rewards",
+                        self.accumulative_meter.get("rejected_rewards"),
+                        self.num_train_step,
+                    )
+                    self.writer.add_scalar(
+                        "train/margin",
+                        self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
+                        self.num_train_step,
+                    )
+                    self.writer.add_scalar(
+                        "train/accuracy",
+                        self.accumulative_meter.get("accuracy"),
+                        self.num_train_step,
+                    )
+                    self.writer.add_scalar(
+                        "train/log_odds_ratio",
+                        self.accumulative_meter.get("log_odds_ratio"),
+                        self.num_train_step,
+                    )
+                self.accumulative_meter.reset()
+
+                if self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
+                    # save checkpoint
+                    self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
+                    save_checkpoint(
+                        save_dir=self.save_dir,
+                        booster=self.booster,
+                        model=self.model,
+                        optimizer=self.optimizer,
+                        lr_scheduler=self.actor_scheduler,
+                        epoch=epoch,
+                        step=i + 1,
+                        batch_size=batch_size,
+                        coordinator=self.coordinator,
+                    )
+                    self.coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {i + 1} at folder {self.save_dir}"
+                    )
+
+        step_bar.close()
+
+    def _eval(self, epoch: int):
+        """
+        Args:
+            epoch int: the number of current epoch
+        """
+        if self.eval_dataloader is None:
+            self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
+            return
+        self.model.eval()
+        self.coordinator.print_on_master("\nStart evaluation...")
+
+        step_bar = trange(
+            len(self.eval_dataloader),
+            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+            disable=not is_rank_0(),
+        )
+
+        self.accumulative_meter.reset()
+
+        with torch.no_grad():
+            for i, batch in enumerate(self.eval_dataloader):
+                batch = to_device(batch, self.device)
+                (
+                    chosen_input_ids,
+                    chosen_attention_mask,
+                    chosen_loss_mask,
+                    reject_input_ids,
+                    reject_attention_mask,
+                    reject_loss_mask,
+                ) = (
+                    batch["chosen_input_ids"],
+                    batch["chosen_attention_mask"],
+                    batch["chosen_loss_mask"],
+                    batch["reject_input_ids"],
+                    batch["reject_attention_mask"],
+                    batch["reject_loss_mask"],
+                )
+                batch_size = chosen_input_ids.size()[0]
+                actor_out = self.model(
+                    input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+                    labels=torch.cat([chosen_input_ids, reject_input_ids]),
+                    attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+                )
+                actor_all_logits = actor_out["logits"].to(torch.float32)
+                chosen_nll = torch.mean(actor_out["loss"][:batch_size]).to(dtype=torch.bfloat16)
+                actor_chosen_logits = actor_all_logits[:batch_size]
+                actor_reject_logits = actor_all_logits[batch_size:]
+                logprob_actor_chosen = calc_masked_log_probs(
+                    actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
+                )
+
+                logprob_actor_reject = calc_masked_log_probs(
+                    actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
+                )
+                odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
+                    logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
+                )
+                loss = chosen_nll - odds_ratio_loss * self.lam
+
+                chosen_rewards = torch.mean(logprob_actor_chosen)
+                rejected_rewards = torch.mean(logprob_actor_reject)
+                reward_accuracies = (log_odds_ratio > 0).float().mean()
+
+                # sync
+                loss_mean = all_reduce_mean(tensor=loss)
+                chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
+                rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
+                reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
+                self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+                self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
+                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+                self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
+                self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
+
+                # logging
+                if self.writer and is_rank_0():
+                    self.writer.add_scalar("eval/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+                    self.writer.add_scalar(
+                        "eval/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
+                    )
+                    self.writer.add_scalar(
+                        "eval/rejected_rewards",
+                        self.accumulative_meter.get("rejected_rewards"),
+                        self.num_train_step,
+                    )
+                    self.writer.add_scalar(
+                        "eval/margin",
+                        self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
+                        self.num_train_step,
+                    )
+                    self.writer.add_scalar(
+                        "eval/accuracy",
+                        self.accumulative_meter.get("accuracy"),
+                        self.num_train_step,
+                    )
+                    self.writer.add_scalar(
+                        "eval/log_odds_ratio",
+                        self.accumulative_meter.get("log_odds_ratio"),
+                        self.num_train_step,
+                    )
+                step_bar.update()
+
+        msg = "Evaluation Result:\n"
+        for tag in ["loss", "chosen_rewards", "rejected_rewards", "log_odds_ratio", "accuracy"]:
+            msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+        self.coordinator.print_on_master(msg)
+        step_bar.close()
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index 1a7ddd5a0..8b1f0d2b0 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -735,13 +735,22 @@ You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to star
 ### Alternative Option For RLHF: Simple Preference Optimization
 We support the method introduced in the paper [SimPO: Simple Preference Optimization
-with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which is a reference model free aligment method that add length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate too much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. Simply set the flag to disable the use of the reference model, set the reward target margin and enable length normalization in the DPO training script.
+with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method doesn't deviate too much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO in alignment, use the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) script and set `loss_type` to `simpo_loss`; you can optionally also set the temperature (`beta`) and the reward target margin (`gamma`).
 #### SimPO Result
+
+
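Review note: the SimPO paragraph in the examples README hunk above names three ingredients: no reference model, length normalization, and a target reward margin. The following is a minimal scalar sketch of how they combine, following the loss in the SimPO paper rather than the `simpo_loss` code path in this repository. The function name `simpo_loss`, its argument names, and the default values for `beta` and `gamma` are all hypothetical choices for illustration.

```python
import math


def simpo_loss(
    chosen_logp_sum: float,
    chosen_len: int,
    reject_logp_sum: float,
    reject_len: int,
    beta: float = 2.0,   # temperature; illustrative default, not taken from the repo
    gamma: float = 0.5,  # target reward margin; illustrative default, not taken from the repo
) -> float:
    """Hypothetical scalar SimPO loss for one preference pair."""
    # Length-normalized implicit rewards: no reference model log-probabilities are involved.
    chosen_reward = beta * chosen_logp_sum / chosen_len
    reject_reward = beta * reject_logp_sum / reject_len
    # The chosen reward must beat the rejected reward by at least gamma to drive the loss down.
    margin = chosen_reward - reject_reward - gamma
    # -log(sigmoid(margin)) written as log(1 + exp(-margin)).
    return math.log1p(math.exp(-margin))


# Two responses with the same average token log-probability get the same reward,
# regardless of their lengths, so only the margin term separates them.
same_average = simpo_loss(chosen_logp_sum=-8.0, chosen_len=8, reject_logp_sum=-4.0, reject_len=4, gamma=0.0)
print(same_average)
```

Dividing by the response length is what keeps a long chosen response from being penalized just for accumulating more negative log-probability than a short rejected one.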