mirror of https://github.com/hpcaitech/ColossalAI
add orpo
parent f3de5a025c
commit c8d1b4a968
@ -23,6 +23,8 @@
- [Open QA](#open-qa)
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@ -262,12 +264,8 @@ experience buffer size
= train_batch_size * accumulation_steps * num_tp_group
```
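For a quick sense of the sizing, the arithmetic works out as in this small sketch (the values are hypothetical, for illustration only):

```python
# The experience buffer must hold train_batch_size * accumulation_steps * num_tp_group samples.
train_batch_size, accumulation_steps, num_tp_group = 4, 2, 4
experience_buffer_size = train_batch_size * accumulation_steps * num_tp_group  # 32
```
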
## Alternative Option For RLHF: Direct Preference Optimization

For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. As detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers a low-cost way to perform RLHF and usually requires fewer computation resources than PPO.

## Alternative Option For RLHF: Simple Preference Optimization

Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but abandons the use of the reference model, which makes training more efficient. It also adds a reward shaping term called the target reward margin to enhance training stability, and uses length normalization to better align with the inference process.

## Alternative Option For RLHF: Direct Preference Optimization (DPO)

For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. As detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers a low-cost way to perform RLHF and usually requires fewer computation resources than PPO. Read this [README](./examples/README.md) for more information.
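
For intuition, the core DPO objective can be sketched as follows (a minimal sketch with illustrative names, assuming per-sequence log-probabilities are already computed; this is not this repository's implementation):

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_reject_logp,
             ref_chosen_logp, ref_reject_logp, beta=0.1):
    # DPO maximizes the margin between how much more the policy prefers
    # the chosen response than the reference model does.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    reject_logratio = policy_reject_logp - ref_reject_logp
    return -F.logsigmoid(beta * (chosen_logratio - reject_logratio)).mean()
```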

### DPO Training Stage1 - Supervised Instructs Tuning

@ -280,6 +278,12 @@ For DPO training, you only need the preference dataset. Please follow the instru
#### Step 2: Training

You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More details can be found in the [example guideline](./examples/README.md).

## Alternative Option For RLHF: Simple Preference Optimization (SimPO)

Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but abandons the use of the reference model, which makes training more efficient. It also adds a reward shaping term called the target reward margin to enhance training stability, and uses length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information.
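
A minimal SimPO-style loss might look like the following (an illustrative sketch, not this repository's implementation; `gamma` is the target reward margin and the rewards are length-normalized log-probabilities):

```python
import torch.nn.functional as F

def simpo_loss(chosen_logp_sum, reject_logp_sum,
               chosen_len, reject_len, beta=2.0, gamma=1.4):
    # Reference-free: rewards are length-normalized log-probabilities,
    # and a target margin gamma is subtracted from the preference logits.
    chosen_reward = beta * chosen_logp_sum / chosen_len
    reject_reward = beta * reject_logp_sum / reject_len
    return -F.logsigmoid(chosen_reward - reject_reward - gamma).mean()
```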

## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)

Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference-model-free alignment method that mixes the SFT loss with a reinforcement-learning-style loss computed from an odds-ratio-based implicit reward, which makes training more efficient and stable. Read this [README](./examples/README.md) for more information.
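
Conceptually, the combined objective looks roughly like this (a sketch with assumed names; `lam` mirrors the lambda hyperparameter used by the trainer added in this commit):

```python
import torch.nn.functional as F

def orpo_objective(sft_nll, chosen_log_odds, reject_log_odds, lam=0.1):
    # SFT negative log-likelihood on the chosen response, minus lam times
    # the log-sigmoid of the log odds ratio between chosen and rejected.
    log_odds_ratio = chosen_log_odds - reject_log_odds
    return sft_nll - lam * F.logsigmoid(log_odds_ratio)
```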

### Inference Quantization and Serving - After Training

We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.

@ -179,3 +179,28 @@ class LogExpLoss(nn.Module):
    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
        return loss


class OddsRatioLoss(nn.Module):
    """
    Odds Ratio Loss in ORPO
    Details: https://arxiv.org/pdf/2403.07691
    """

    def forward(
        self,
        chosen_logp: torch.Tensor,
        reject_logp: torch.Tensor,
        chosen_loss_mask: torch.Tensor,
        reject_loss_mask: torch.Tensor,
    ):
        # Compute in fp32 for numerical stability.
        chosen_logp = chosen_logp.to(dtype=torch.float32)
        reject_logp = reject_logp.to(dtype=torch.float32)
        # log(odds) = logp - log(1 - p); the 1.0001 constant guards against
        # log(0) when exp(logp) is numerically 1.
        chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001)
        chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
        reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
        reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
        log_odds_ratio = chosen_odds_masked - reject_odds_masked
        # Returns (log-sigmoid of the log odds ratio, the log odds ratio itself).
        ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
        return ratio.to(dtype=torch.bfloat16), log_odds_ratio
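
A quick usage sketch for this loss, assuming the `OddsRatioLoss` class above is in scope (illustrative shapes only; the masks select the response tokens to average over):

```python
import torch

loss_fn = OddsRatioLoss()
# Per-token log-probabilities for the chosen and rejected responses.
chosen_logp = torch.full((2, 8), -0.5)
reject_logp = torch.full((2, 8), -2.0)
mask = torch.ones(2, 8)
odds_ratio_loss, log_odds_ratio = loss_fn(chosen_logp, reject_logp, mask, mask)
# The trainer then combines it as: loss = chosen_nll - lam * odds_ratio_loss
```
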
@ -1,7 +1,8 @@
from .base import OLTrainer, SLTrainer
from .dpo import DPOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer

__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"]
__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer", "ORPOTrainer"]

@ -134,7 +134,6 @@ class DPOTrainer(SLTrainer):
                batch["reject_attention_mask"],
                batch["reject_loss_mask"],
            )
            reject_loss_mask[:, -1] = False
            batch_size = chosen_input_ids.size()[0]

            actor_all_logits = self.model(

@ -0,0 +1,339 @@
"""
ORPO trainer
"""

from typing import Any, Optional

import torch
from coati.models.loss import OddsRatioLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.nn import CrossEntropyLoss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device

from .base import SLTrainer
from .utils import is_rank_0, to_device


class ORPOTrainer(SLTrainer):
    """
    Trainer for the ORPO algorithm.

    Args:
        actor (Actor): the actor model to align
        booster (Booster): the booster to use for training
        actor_optim (Optimizer): the optimizer to use for the actor model
        actor_lr_scheduler (_LRScheduler): the lr scheduler to use for the actor model
        tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
        max_epochs (int, defaults to 1): the max number of epochs to train
        lam (float, defaults to 0.1): the lambda parameter in the ORPO loss
        accumulation_steps (int): the number of steps to accumulate gradients
        start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
        save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training
        save_dir (str): the directory to save checkpoints
        coordinator (DistCoordinator): the coordinator to use for distributed logging
    """

    def __init__(
        self,
        actor: Any,
        booster: Booster,
        actor_optim: Optimizer,
        actor_lr_scheduler: _LRScheduler,
        tokenizer: PreTrainedTokenizerBase,
        max_epochs: int = 1,
        lam: float = 0.1,
        accumulation_steps: int = 1,
        start_epoch: int = 0,
        save_interval: int = 0,
        save_dir: str = None,
        coordinator: DistCoordinator = None,
    ) -> None:
        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
        self.actor_scheduler = actor_lr_scheduler
        self.tokenizer = tokenizer
        self.odds_ratio_loss_fn = OddsRatioLoss()
        self.sft_loss_fn = CrossEntropyLoss()
        self.save_interval = save_interval
        self.coordinator = coordinator
        self.save_dir = save_dir
        self.num_train_step = 0
        self.lam = lam
        self.accumulation_steps = accumulation_steps
        self.device = get_current_device()
        self.accumulative_meter = AccumulativeMeanMeter()

    def _before_fit(
        self,
        train_preference_dataloader: DataLoader = None,
        eval_preference_dataloader: DataLoader = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            train_preference_dataloader (DataLoader): the dataloader to use for preference training data
            eval_preference_dataloader (DataLoader): the dataloader to use for preference evaluation data
        """
        self.train_dataloader = train_preference_dataloader
        self.eval_dataloader = eval_preference_dataloader
        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            self.wandb_run = wandb.init(project="Coati-orpo", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "orpo")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

    def _train(self, epoch: int):
        """
        Args:
            epoch (int): the number of the current epoch
        """
        self.model.train()
        self.accumulative_meter.reset()
        step_bar = trange(
            len(self.train_dataloader) // self.accumulation_steps,
            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
            disable=not is_rank_0(),
        )
        for i, batch in enumerate(self.train_dataloader):
            batch = to_device(batch, self.device)
            (
                chosen_input_ids,
                chosen_attention_mask,
                chosen_loss_mask,
                reject_input_ids,
                reject_attention_mask,
                reject_loss_mask,
            ) = (
                batch["chosen_input_ids"],
                batch["chosen_attention_mask"],
                batch["chosen_loss_mask"],
                batch["reject_input_ids"],
                batch["reject_attention_mask"],
                batch["reject_loss_mask"],
            )
            batch_size = chosen_input_ids.size()[0]
            # Run chosen and rejected sequences through the actor in one forward pass.
            actor_out = self.model(
                input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
            )
            # torch.autograd.set_detect_anomaly(True)  # debugging aid only; very slow if left enabled
            actor_all_logits = actor_out["logits"].to(torch.float32)
            actor_chosen_logits = actor_all_logits[:batch_size]
            actor_reject_logits = actor_all_logits[batch_size:]
            logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
            logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
            # SFT (NLL) loss on the chosen response only; prompt/padding positions are masked to -100.
            chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
            label_chosen = chosen_input_ids[:, 1:].contiguous()
            label_chosen_masked = (
                label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
            )
            chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
            odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
                logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
            )
            # ORPO objective: SFT loss minus the weighted log-sigmoid of the log odds ratio.
            loss = chosen_nll - odds_ratio_loss * self.lam
            step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")

            self.booster.backward(loss=loss, optimizer=self.optimizer)
            if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.actor_scheduler.step()

            # Implicit rewards: average per-token log-probabilities of each response.
            chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
            rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
            reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)

            # sync
            loss_mean = all_reduce_mean(tensor=loss)
            chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
            rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
            reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
            self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
            self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
            self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
            self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())

            if i % self.accumulation_steps == self.accumulation_steps - 1:
                self.num_train_step += 1
                step_bar.update()
                # logging
                if self.writer and is_rank_0():
                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
                    self.writer.add_scalar(
                        "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
                    )
                    self.writer.add_scalar(
                        "train/rejected_rewards",
                        self.accumulative_meter.get("rejected_rewards"),
                        self.num_train_step,
                    )
                    self.writer.add_scalar(
                        "train/margin",
                        self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
                        self.num_train_step,
                    )
                    self.writer.add_scalar(
                        "train/accuracy",
                        self.accumulative_meter.get("accuracy"),
                        self.num_train_step,
                    )
                    self.writer.add_scalar(
                        "train/log_odds_ratio",
                        self.accumulative_meter.get("log_odds_ratio"),
                        self.num_train_step,
                    )
                self.accumulative_meter.reset()

                # Guard against save_interval == 0, which is documented as "never save".
                if self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
                    # save checkpoint
                    self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
                    save_checkpoint(
                        save_dir=self.save_dir,
                        booster=self.booster,
                        model=self.model,
                        optimizer=self.optimizer,
                        lr_scheduler=self.actor_scheduler,
                        epoch=epoch,
                        step=i + 1,
                        batch_size=batch_size,
                        coordinator=self.coordinator,
                    )
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {i + 1} at folder {self.save_dir}"
                    )

        step_bar.close()

    def _eval(self, epoch: int):
        """
        Args:
            epoch (int): the number of the current epoch
        """
        if self.eval_dataloader is None:
            self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
            return
        self.model.eval()
        self.coordinator.print_on_master("\nStart evaluation...")

        step_bar = trange(
            len(self.eval_dataloader),
            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
            disable=not is_rank_0(),
        )

        self.accumulative_meter.reset()

        with torch.no_grad():
            for i, batch in enumerate(self.eval_dataloader):
                batch = to_device(batch, self.device)
                (
                    chosen_input_ids,
                    chosen_attention_mask,
                    chosen_loss_mask,
                    reject_input_ids,
                    reject_attention_mask,
                    reject_loss_mask,
                ) = (
                    batch["chosen_input_ids"],
                    batch["chosen_attention_mask"],
                    batch["chosen_loss_mask"],
                    batch["reject_input_ids"],
                    batch["reject_attention_mask"],
                    batch["reject_loss_mask"],
                )
                batch_size = chosen_input_ids.size()[0]
                actor_out = self.model(
                    input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                    labels=torch.cat([chosen_input_ids, reject_input_ids]),
                    attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
                )
                actor_all_logits = actor_out["logits"].to(torch.float32)
                # NOTE: this assumes the wrapped model returns an unreduced, per-sample loss here.
                chosen_nll = torch.mean(actor_out["loss"][:batch_size]).to(dtype=torch.bfloat16)
                actor_chosen_logits = actor_all_logits[:batch_size]
                actor_reject_logits = actor_all_logits[batch_size:]
                logprob_actor_chosen = calc_masked_log_probs(
                    actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
                )
                logprob_actor_reject = calc_masked_log_probs(
                    actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
                )

                odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
                    logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
                )

                loss = chosen_nll - odds_ratio_loss * self.lam

                # Keep these as tensors so they can be all-reduced across ranks.
                chosen_rewards = torch.mean(logprob_actor_chosen)
                rejected_rewards = torch.mean(logprob_actor_reject)
                reward_accuracies = (log_odds_ratio > 0).float().mean()

                # sync
                loss_mean = all_reduce_mean(tensor=loss)
                chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
                rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
                reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
                self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
                self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
                self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
                self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())

                # logging
                if self.writer and is_rank_0():
                    self.writer.add_scalar("eval/loss", self.accumulative_meter.get("loss"), self.num_train_step)
                    self.writer.add_scalar("eval/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
                    self.writer.add_scalar(
                        "eval/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
                    )
                    self.writer.add_scalar(
                        "eval/rejected_rewards",
                        self.accumulative_meter.get("rejected_rewards"),
                        self.num_train_step,
                    )
                    self.writer.add_scalar(
                        "eval/margin",
                        self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
                        self.num_train_step,
                    )
                    self.writer.add_scalar(
                        "eval/accuracy",
                        self.accumulative_meter.get("accuracy"),
                        self.num_train_step,
                    )
                    self.writer.add_scalar(
                        "eval/log_odds_ratio",
                        self.accumulative_meter.get("log_odds_ratio"),
                        self.num_train_step,
                    )
                step_bar.update()

        msg = "Evaluation Result:\n"
        for tag in ["loss", "chosen_rewards", "rejected_rewards", "log_odds_ratio", "accuracy"]:
            msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
        self.coordinator.print_on_master(msg)
        step_bar.close()

@ -735,13 +735,22 @@ You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to star
### Alternative Option For RLHF: Simple Preference Optimization

We support the method introduced in the paper [SimPO: Simple Preference Optimization
with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), which is a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method does not deviate too much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. Simply set the flag to disable the use of the reference model, set the reward target margin, and enable length normalization in the DPO training script.
with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), which is a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. As the method does not deviate too much from DPO, we add support for length normalization and SimPO reward shaping in our DPO implementation. To use SimPO for alignment, use the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) script, set `loss_type` to `simpo_loss`, and optionally set values for the temperature (`beta`) and the reward target margin (`gamma`).

#### SimPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/SimPO_margin.png">
</p>


### Alternative Option For RLHF: Odds Ratio Preference Optimization
We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), which is a reference-model-free alignment method that mixes the SFT loss with a reinforcement learning loss using the odds ratio as the implicit reward to enhance training stability and efficiency. To use ORPO for alignment, use the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script; you can optionally set the value of `lambda`, which determines how strongly the reinforcement learning loss affects training.

#### ORPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ORPO_margin.png">
</p>

## Hardware Requirements
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use an H800 GPU with 80GB VRAM.
| PPO | tp=8 | tp=4 |

@ -5,7 +5,7 @@ rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow

python prepare_dataset.py --type sft \
    --data_input_dirs /PATH/TO/PREFERENCE/DATASET \
    --data_input_dirs /PATH/TO/SFT/DATASET \
    --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
    --tokenizer_dir "" \
    --data_cache_dir $SAVE_DIR/cache \

@ -299,6 +299,7 @@ if __name__ == "__main__":
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--sp", type=int, default=1)
    parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss")
    parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
    parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
    parser.add_argument("--length_normalization", default=False, action="store_true")

@ -341,6 +342,12 @@ if __name__ == "__main__":
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")
    args = parser.parse_args()

    # foolproof hyperparameter setup
    if args.loss_type == "simpo_loss":
        args.length_normalization = True
        args.gamma = args.gamma if args.gamma > 0 else 1.4

    os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)

@ -14,7 +14,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
# export CUDA_VISIBLE_DEVICES=6

PROJECT_NAME="dpo"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints

@ -0,0 +1,326 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext

import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module, disable_dropout
from coati.trainer import ORPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam

logger = get_dist_logger()


def train(args):
    # check lora compatibility
    if "gemini" in args.plugin and args.lora_rank > 0:
        raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
        raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        """
        Default torch ddp plugin without any acceleration, for debugging purposes.
        """
        plugin = TorchDDPPlugin(find_unused_parameters=True)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="static",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=True,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            zero_stage=args.zero_stage,
            enable_flash_attention=args.use_flash_attn,
            enable_sequence_parallelism=args.enable_sequence_parallelism,
            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
            parallel_output=False,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # Temp Fix: Disable lazy init due to version conflict
    # init_ctx = (
    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    # )

    init_ctx = nullcontext()
    with init_ctx:
        if args.use_flash_attn:
            model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
            coordinator.print_on_master(msg="Flash-attention enabled successfully")
        else:
            model = AutoModelForCausalLM.from_pretrained(args.pretrain)
        disable_dropout(model)
        if args.lora_rank > 0:
            model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

    if args.grad_checkpoint and args.lora_rank == 0:
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    elif args.lora_rank > 0:
        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")

    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        logger.warning(
            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior in training; please consider setting it manually."
        )

    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    # configure optimizer
    optim = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # configure dataset
    coordinator.print_on_master(f"Load dataset: {args.dataset}")
    mode_map = {"train": "train", "valid": "validation", "test": "test"}
    train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
    data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)

    train_dataloader = plugin.prepare_dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
        distributed_sampler_cls=StatefulDistributedSampler,
    )

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    if args.warmup_steps is None:
        args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optim,
        total_steps=args.max_epochs * num_update_steps_per_epoch,
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optim,
        lr_scheduler=lr_scheduler,
        dataloader=train_dataloader,
    )
    torch.set_default_dtype(torch.float)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    sampler_start_idx = 0
    start_step = 0
    if args.checkpoint_path is not None:
        if "modeling" in args.checkpoint_path:
            coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
            booster.load_model(model, args.checkpoint_path)
        else:
            coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
            start_epoch, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.checkpoint_path,
                booster=booster,
                model=model,
                optimizer=optim,
                lr_scheduler=lr_scheduler,
            )
            assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
            train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)

            coordinator.print_on_master(
                f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
            )
            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

        coordinator.print_on_master(
            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
        )

    trainer = ORPOTrainer(
        actor=model,
        booster=booster,
        actor_optim=optim,
        actor_lr_scheduler=lr_scheduler,
        tokenizer=tokenizer,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
        start_epoch=start_epoch,
        save_interval=args.save_interval,
        save_dir=args.save_dir,
        coordinator=coordinator,
        lam=args.lam,
    )

    trainer.fit(
        train_preference_dataloader=train_dataloader,
        eval_preference_dataloader=None,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    if args.lora_rank > 0 and args.merge_lora_weights:
        from coati.models.lora import LORA_MANAGER

        # NOTE: set model to eval to merge LoRA weights
        LORA_MANAGER.merge_weights = True
        model.eval()
    # save model checkpoint after fitting on only rank0
    coordinator.print_on_master("Start saving final model checkpoint")
    booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
    coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default="output")
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
|
@ -0,0 +1,63 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8

PROJECT_NAME="orpo"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path

declare -a dataset=(
    /Your/Preference/Data/arrow/part-00000
    /Your/Preference/Data/arrow/part-00001
    /Your/Preference/Data/arrow/part-00002
    /Your/Preference/Data/arrow/part-00003
    /Your/Preference/Data/arrow/part-00004
    /Your/Preference/Data/arrow/part-00005
    /Your/Preference/Data/arrow/part-00006
    /Your/Preference/Data/arrow/part-00007
    /Your/Preference/Data/arrow/part-00008
    /Your/Preference/Data/arrow/part-00009
)

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"

colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --checkpoint_path $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --dataset ${dataset[@]} \
    --plugin "zero2" \
    --save_interval 1000 \
    --save_dir $SAVE_DIR \
    --config_file $CONFIG_FILE \
    --max_epochs 3 \
    --accumulation_steps 1 \
    --batch_size 16 \
    --lr 8e-6 \
    --lam 0.5 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --max_length 1024 \
    --weight_decay 0.01 \
    --warmup_steps 60 \
    --grad_checkpoint \
    --use_wandb

@ -13,8 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}


# export CUDA_VISIBLE_DEVICES=4,5,6
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="sft"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints