ColossalAI/applications/ColossalChat/coati/trainer/orpo.py

315 lines
14 KiB
Python
Raw Normal View History

[FP8] rebase main (#5963) * add SimPO * fix dataloader * remove debug code * add orpo * fix style * fix colossalai, transformers version * fix colossalai, transformers version * fix colossalai, transformers version * fix torch colossalai version * update transformers version * [shardformer] DeepseekMoE support (#5871) * [Feature] deepseek moe expert parallel implement * [misc] fix typo, remove redundant file (#5867) * [misc] fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] deepseek support & unit test * [misc] remove debug code & useless print * [misc] fix typos (#5872) * [Feature] remove modeling file, use auto config. (#5884) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [Deepseek] remove redundant code (#5888) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [Feature/deepseek] resolve comment. (#5889) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [misc] mv module replacement into if branch * [misc] add some warning message and modify some code in unit test * [misc] fix typos --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838) * Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support * [HotFix] CI,import,requirements-test for #5838 (#5892) * [Hot Fix] CI,import,requirements-test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Enable PP + SP for llama (#5868) * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use a one cross entropy func for all shardformer models --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897) * add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint * fix style * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix eval * hotfix citation * [zero] support all-gather overlap (#5898) * [zero] support all-gather overlap * [zero] add overlap all-gather flag * [misc] fix typo * [zero] update api * fix orpo cross entropy loss * [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446) * Remove unnecessary calls to deepcopy * Build DimSpec's difference dict only once This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough. * Fix documentation of DimSpec's difference method * [ShardFormer] fix qwen2 sp (#5903) * [compatibility] support torch 2.2 (#5875) * Support Pytorch 2.2.2 * keep build_on_pr file and update .compatibility * fix object_to_tensor usage when torch>=2.3.0 (#5820) * [misc] support torch2.3 (#5893) * [misc] support torch2.3 * [devops] update compatibility ci * [devops] update compatibility ci * [devops] add debug * [devops] add debug * [devops] add debug * [devops] add debug * [devops] remove debug * [devops] remove debug * [release] update version (#5912) * [plugin] support all-gather overlap for hybrid parallel (#5919) * [plugin] fixed all-gather overlap support for hybrid parallel * add kto * fix style, add kto data sample * [Examples] Add lazy init to OPT and GPT examples (#5924) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [ColossalChat] Hotfix for ColossalChat (#5910) * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * add ignore and tiny llama * fix path issue * run style * fix issue * update bash * fix ddp issue * add Qwen 1.5 32B * refactor tokenization * [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931) * cannot access local variable 'default_conversation' where it is not associated with a value set default value for 'default_conversation' * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix test data * refactor evaluation * remove real data path * remove real data path * Add n_fused as an input from native_module (#5894) * [FIX BUG] convert env param to int in (#5934) * [Hotfix] Fix ZeRO typo #5936 Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941) * Add a switch to control whether the model checkpoint needs to be saved after each epoch ends * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix style * fix style * fix style * [shardformer] hotfix attn mask (#5945) * [shardformer] hotfix attn mask (#5947) * [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895) * Distrifusion Support source * comp comm overlap optimization * sd3 benchmark * pixart distrifusion bug fix * sd3 bug fix and benchmark * generation bug fix * naming fix * add docstring, fix counter and shape error * add reference * readme and requirement * [zero] hotfix update master params (#5951) * [release] update version (#5952) * [Chat] Fix lora (#5946) * fix merging * remove filepath * fix style * Update README.md (#5958) * [hotfix] Remove unused plan section (#5957) * remove readme * fix readme * update * [test] add mixtral for sequence classification * [test] add mixtral transformer test * [moe] fix plugin * [test] mixtra pp shard test * [chore] handle non member group * [zero] solve hang * [test] pass mixtral shardformer test * [moe] implement transit between non moe tp and ep * [zero] solve hang * [misc] solve booster hang by rename the variable * solve hang when parallel mode = pp + dp * [moe] implement submesh initialization * [moe] add mixtral dp grad scaling when not all experts are activated * [chore] manually revert unintended commit * [chore] trivial fix * [chore] arg pass & remove drop token * [test] add mixtral modelling test * [moe] implement tp * [moe] test deepseek * [moe] clean legacy code * [Feature] MoE Ulysses Support (#5918) * moe sp support * moe sp bug solve * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [chore] minor fix * [moe] init moe plugin comm setting with sp * moe sp + ep bug fix * [moe] finalize test (no pp) * [moe] full test for deepseek and mixtral (pp + sp to fix) * [chore] minor fix after rebase * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [chore] solve moe ckpt test failure and some other arg pass failure * [moe] remove ops * [test] fix test: test_zero1_2 * [bug] fix: somehow logger hangs the program * [moe] deepseek moe sp support * [test] add check * [deepseek] replace attn (a workaround for bug in transformers) * [misc] skip redunant test * [misc] remove debug/print code * [moe] refactor mesh assignment * Revert "[moe] implement submesh initialization" This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582. * [chore] change moe_pg_mesh to private * [misc] remove incompatible test config * [misc] fix ci failure: change default value to false in moe plugin * [misc] remove useless condition * [chore] docstring * [moe] remove force_overlap_comm flag and add warning instead * [doc] add MoeHybridParallelPlugin docstring * [moe] solve dp axis issue * [chore] remove redundant test case, print string & reduce test tokens * [feat] Dist Loader for Eval (#5950) * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support auto distributed data loader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tp error * remove unused parameters * remove unused * update inference * update docs * update inference --------- Co-authored-by: Michelle <qianranma8@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [lora] lora support hybrid parallel plugin (#5956) * lora support hybrid plugin * fix * fix * fix * fix * fp8 operators for compressed communication cast_to_fp8, cast_from_fp8, all_reduce_fp8 * fix scaling algorithm in FP8 casting * support fp8 communication in pipeline parallelism * add fp8_communication flag in the script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * shardformer fp8 * fix rebase * remove all to all * fix shardformer fp8 communication training degradation * [fp8] support all-gather flat tensor (#5932) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update low_level_optim.py --------- Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: Haze188 <haze188@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: zhurunhua <1281592874@qq.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: Gao, Ruiyuan <905370712@qq.com> Co-authored-by: hxwang <wang1570@e.ntu.edu.sg> Co-authored-by: Michelle <qianranma8@gmail.com> Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 08:29:37 +00:00
"""
Orpo trainer
"""
import os
from typing import Any, Optional
import torch
from coati.models.loss import OddsRatioLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import SLTrainer
from .utils import is_rank_0, to_device
class ORPOTrainer(SLTrainer):
"""
Trainer for ORPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
booster (Strategy): the strategy to use for training
actor_optim (Optimizer): the optimizer to use for actor model
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
max_epochs (int, defaults to 1): the max number of epochs to train
lam (float, defaults to 0.1): the lambda parameter in ORPO loss
accumulation_steps (int): the number of steps to accumulate gradients
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during trainning
save_dir (str): the directory to save checkpoints
coordinator (DistCoordinator): the coordinator to use for distributed logging
"""
def __init__(
self,
actor: Any,
booster: Booster,
actor_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
lam: float = 0.1,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.odds_ratio_loss_fn = OddsRatioLoss()
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
self.num_train_step = 0
self.lam = lam
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
self,
train_preference_dataloader: DataLoader = None,
eval_preference_dataloader: DataLoader = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-orpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "orpo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _train(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
self.model.train()
self.accumulative_meter.reset()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
)
torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
chosen_nll = actor_out["loss"]
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
)
loss = chosen_nll - odds_ratio_loss * self.lam
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.writer.add_scalar(
"train/log_odds_ratio",
self.accumulative_meter.get("log_odds_ratio"),
self.num_train_step,
)
self.accumulative_meter.reset()
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.coordinator.print_on_master("\nStart evaluation...")
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
self.accumulative_meter.reset()
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["chosen_loss_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
)
torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
)
chosen_nll = actor_out["loss"]
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
)
loss = chosen_nll - odds_ratio_loss * self.lam
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "log_odds_ratio", "accuracy"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()