mirror of https://github.com/hpcaitech/ColossalAI
commit
dd9e1cdafe
|
@ -23,6 +23,8 @@
|
|||
- [Open QA](#open-qa)
|
||||
- [Limitation for LLaMA-finetuned models](#limitation)
|
||||
- [Limitation of dataset](#limitation)
|
||||
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
|
||||
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization)
|
||||
- [FAQ](#faq)
|
||||
- [How to save/load checkpoint](#faq)
|
||||
- [How to train with limited resources](#faq)
|
||||
|
@ -262,9 +264,8 @@ experience buffer size
|
|||
= train_batch_size * accumulation_steps * num_tp_group
|
||||
```
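For example, with `train_batch_size = 4`, `accumulation_steps = 2` and `num_tp_group = 2`, each experience buffer holds 4 * 2 * 2 = 16 samples.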
|
||||
|
||||
## Alternative Option For RLHF: Direct Preference Optimization
|
||||
|
||||
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the paper (available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290)), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
|
||||
## Alternative Option For RLHF: Direct Preference Optimization (DPO)
|
||||
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. As detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO. Read this [README](./examples/README.md) for more information.
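For intuition, here is a minimal sketch of the DPO objective (illustrative only, not the exact Coati implementation); `beta` is the temperature and the log-probabilities are summed over response tokens:

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # log-ratio of the policy model vs. the frozen reference model
    pi_logratios = policy_chosen_logp - policy_rejected_logp
    ref_logratios = ref_chosen_logp - ref_rejected_logp
    # maximize the margin between chosen and rejected responses
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()
```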
|
||||
|
||||
### DPO Training Stage1 - Supervised Instructs Tuning
|
||||
|
||||
|
@ -277,6 +278,12 @@ For DPO training, you only need the preference dataset. Please follow the instru
|
|||
#### Step 2: Training
|
||||
You can run [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More details can be found in the [example guideline](./examples/README.md).
|
||||
|
||||
## Alternative Option For RLHF: Simple Preference Optimization (SimPO)
|
||||
Simple Preference Optimization (SimPO), from this [paper](https://arxiv.org/pdf/2405.14734), is similar to DPO but abandons the reference model, which makes training more efficient. It also adds a reward-shaping term called the target reward margin to enhance training stability, and uses length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information.
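A minimal sketch of the SimPO objective for intuition (illustrative only, not the exact implementation in this repo): the reference model is dropped, the log-probabilities are length-normalized, and a target reward margin `gamma` is subtracted:

```python
import torch.nn.functional as F

def simpo_loss(chosen_logp_sum, rejected_logp_sum, chosen_len, rejected_len,
               beta=2.0, gamma=1.4):
    # length-normalized average log-probabilities; no reference model is needed
    margin = chosen_logp_sum / chosen_len - rejected_logp_sum / rejected_len
    # gamma is the target reward margin from the SimPO paper
    return -F.logsigmoid(beta * margin - gamma).mean()
```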
|
||||
|
||||
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
|
||||
Odds Ratio Preference Optimization (ORPO), from this [paper](https://arxiv.org/pdf/2403.07691), is a reference-model-free alignment method that uses a mixture of the SFT loss and a reinforcement learning loss computed from an odds-ratio-based implicit reward, which makes training more efficient and stable. Read this [README](./examples/README.md) for more information.
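A minimal sketch of the ORPO objective for intuition (illustrative only; the assumed inputs are per-token-averaged log-probabilities, and `lam` weights the odds-ratio term as in the `--lam` argument of the benchmark script):

```python
import torch
import torch.nn.functional as F

def orpo_loss(chosen_nll, chosen_logp, rejected_logp, lam=0.1):
    # log-odds of each response: log(p / (1 - p)) = log p - log(1 - p)
    log_odds_chosen = chosen_logp - torch.log1p(-torch.exp(chosen_logp))
    log_odds_rejected = rejected_logp - torch.log1p(-torch.exp(rejected_logp))
    odds_ratio_term = -F.logsigmoid(log_odds_chosen - log_odds_rejected)
    # SFT loss on the chosen response plus the weighted odds-ratio penalty
    return (chosen_nll + lam * odds_ratio_term).mean()
```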
|
||||
|
||||
### Inference Quantization and Serving - After Training
|
||||
|
||||
We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
|
||||
|
@ -522,7 +529,7 @@ Coati is developed by ColossalAI Team:
|
|||
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
|
||||
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
|
||||
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.
|
||||
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored version with updated acceleration framework, LoRA, DPO and PPO.
|
||||
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with the updated acceleration framework; added support for DPO and SimPO.
|
||||
|
||||
The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
|
||||
- [Zangwei Zheng](https://github.com/zhengzangw)
|
||||
|
|
|
@ -0,0 +1,340 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler
|
||||
from coati.models import convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import DPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from dummy_dataset import DummyLLMDataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purposes only
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
ref_booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(model)
|
||||
if not args.disable_reference_model:
|
||||
if args.use_flash_attn:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
else:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(ref_model)
|
||||
else:
|
||||
ref_model = None
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers don't allow setting pad_token manually, e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = DummyLLMDataset(
|
||||
["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
|
||||
args.max_length,
|
||||
args.dataset_size,
|
||||
)
|
||||
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
if ref_model is not None:
|
||||
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
actor=model,
|
||||
ref_model=ref_model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=None,
|
||||
save_dir=None,
|
||||
coordinator=coordinator,
|
||||
beta=args.beta,
|
||||
gamma=args.gamma,
|
||||
length_normalization=args.length_normalization,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=None,
|
||||
log_dir=None,
|
||||
use_wandb=False,
|
||||
)
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss")
|
||||
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
|
||||
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
|
||||
parser.add_argument("--length_normalization", default=False, action="store_true")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--dataset_size", type=int, default=500)
|
||||
parser.add_argument(
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# foolproof hyperparameter setup
|
||||
if args.loss_type == "simpo_loss":
|
||||
args.length_normalization = True
|
||||
args.gamma = args.gamma if args.gamma > 0 else 1.4
|
||||
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
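# Pick the n GPUs with the lowest current memory usage and expose only those via CUDA_VISIBLE_DEVICES.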
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
PROJECT_NAME="dpo"
|
||||
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 4 --master_port 31313 benchmark_dpo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--config_file $CONFIG_FILE \
|
||||
--plugin "zero2_cpu" \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 1 \
|
||||
--batch_size 8 \
|
||||
--lr 1e-6 \
|
||||
--beta 0.1 \
|
||||
--gamma 0.6 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--max_length 2048 \
|
||||
--dataset_size 640 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 60 \
|
||||
--disable_reference_model \
|
||||
--length_normalization \
|
||||
--grad_checkpoint \
|
||||
--use_flash_attn
|
|
@ -0,0 +1,315 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler
|
||||
from coati.models import convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import ORPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from dummy_dataset import DummyLLMDataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purposes only
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(model)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers don't allow setting pad_token manually, e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = DummyLLMDataset(
|
||||
["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
|
||||
args.max_length,
|
||||
args.dataset_size,
|
||||
)
|
||||
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = ORPOTrainer(
|
||||
actor=model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=None,
|
||||
save_dir=None,
|
||||
coordinator=coordinator,
|
||||
lam=args.lam,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=None,
|
||||
log_dir=None,
|
||||
use_wandb=False,
|
||||
)
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--dataset_size", type=int, default=500)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
PROJECT_NAME="dpo"
|
||||
PARENT_CONFIG_FILE="./benchmark_config" # Path prefix used when saving the training config log
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 2 --master_port 31313 benchmark_orpo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--plugin "zero2" \
|
||||
--config_file $CONFIG_FILE \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 1 \
|
||||
--batch_size 4 \
|
||||
--lr 8e-6 \
|
||||
--lam 0.5 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--max_length 2048 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 60 \
|
||||
--dataset_size 160 \
|
||||
--grad_checkpoint \
|
||||
--use_flash_attn
|
|
@ -0,0 +1,315 @@
|
|||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler
|
||||
from coati.models import convert_to_lora_module
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from dummy_dataset import DummyLLMDataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purposes only
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
microbatch_size=args.batch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
if args.grad_checkpoint:
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
|
||||
)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers don't allow setting pad_token manually, e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
tokenizer.padding_side = "right"
|
||||
|
||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_len, args.dataset_size)
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
math.ceil(args.max_epochs * num_update_steps_per_epoch)
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
# Flash attention will be disabled because it does NOT support fp32.
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
booster=booster,
|
||||
optim=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=None,
|
||||
save_dir=None,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=None,
|
||||
log_dir=None,
|
||||
use_wandb=False,
|
||||
)
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--max_len", type=int, default=512)
|
||||
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
parser.add_argument("--dataset_size", type=int, default=500)
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
|
@ -0,0 +1,41 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
# export CUDA_VISIBLE_DEVICES=3,4
|
||||
PROJECT_NAME="sft"
|
||||
PARENT_CONFIG_FILE="./benchmark_config" # Path prefix used when saving the training config log
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
|
||||
colossalai run --nproc_per_node 4 --master_port 31312 benchmark_sft.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--config_file $CONFIG_FILE \
|
||||
--plugin zero2 \
|
||||
--batch_size 8 \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 1 \
|
||||
--lr 5e-5 \
|
||||
--lora_rank 32 \
|
||||
--max_len 2048 \
|
||||
--dataset_size 640 \
|
||||
--grad_checkpoint \
|
||||
--use_flash_attn
|
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class DummyLLMDataset(Dataset):
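"""Dummy dataset used only for benchmarking: every item is the same dict of all-ones tensors of length seq_len for the given keys."""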
|
||||
def __init__(self, keys, seq_len, size=500):
|
||||
self.keys = keys
|
||||
self.seq_len = seq_len
|
||||
self.data = self._generate_data()
|
||||
self.size = size
|
||||
|
||||
def _generate_data(self):
|
||||
data = {}
|
||||
for key in self.keys:
|
||||
data[key] = torch.ones(self.seq_len, dtype=torch.long)
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {key: self.data[key] for key in self.keys}
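# Hypothetical standalone usage sketch (not part of the original file): this mirrors how the
# benchmark scripts above instantiate the dataset for DPO-style preference data.
if __name__ == "__main__":
    demo = DummyLLMDataset(
        ["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
        seq_len=16,
        size=4,
    )
    print(len(demo))  # 4
    print(demo[0]["chosen_input_ids"].shape)  # torch.Size([16])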
|
|
@ -73,9 +73,12 @@ def supervised_tokenize_sft(
|
|||
lo, hi = 0, len(turns)
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
if max_length - 1 < len(
|
||||
tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
|
||||
):
|
||||
prompt = template.get_prompt(2 * turns[mid] - 1)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
if max_length - 1 < len(tokenized):
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid + 1
|
||||
|
@ -114,6 +117,7 @@ def supervised_tokenize_sft(
|
|||
to_truncate_len += 1
|
||||
else:
|
||||
break
|
||||
to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
|
||||
tokenized = tokenized[: len(tokenized) - to_truncate_len]
|
||||
labels = labels[: len(labels) - to_truncate_len]
|
||||
|
||||
|
@ -356,48 +360,24 @@ def tokenize_rlhf(
|
|||
rejected_loss_mask,
|
||||
rejected_label_decode,
|
||||
) = (None, None, None, None, None, None)
|
||||
if (
|
||||
len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0])
|
||||
<= max_length - 1
|
||||
and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0])
|
||||
<= max_length - 1
|
||||
):
|
||||
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
|
||||
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
|
||||
chosen_data_packed["input_ids"],
|
||||
chosen_data_packed["loss_mask"],
|
||||
chosen_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
rejected_data_packed = apply_rlhf_data_format(
|
||||
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
|
||||
)
|
||||
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
|
||||
rejected_data_packed["input_ids"],
|
||||
rejected_data_packed["loss_mask"],
|
||||
rejected_data_packed["label_decode"],
|
||||
)
|
||||
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
|
||||
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
|
||||
chosen_data_packed["input_ids"],
|
||||
chosen_data_packed["loss_mask"],
|
||||
chosen_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
|
||||
if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask):
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
chosen_label_decode=None,
|
||||
rejected_input_ids=None,
|
||||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
rejected_data_packed = apply_rlhf_data_format(
|
||||
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
|
||||
)
|
||||
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
|
||||
rejected_data_packed["input_ids"],
|
||||
rejected_data_packed["loss_mask"],
|
||||
rejected_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
return {
|
||||
"chosen_input_ids": chosen_input_ids,
|
||||
"chosen_loss_mask": chosen_loss_mask,
|
||||
"chosen_label_decode": chosen_label_decode,
|
||||
"rejected_input_ids": rejected_input_ids,
|
||||
"rejected_loss_mask": rejected_loss_mask,
|
||||
"rejected_label_decode": rejected_label_decode,
|
||||
}
|
||||
else:
|
||||
if len(chosen_input_ids) > max_length or len(rejected_input_ids) > max_length:
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
|
@ -406,3 +386,22 @@ def tokenize_rlhf(
|
|||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
|
||||
if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0:
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
chosen_label_decode=None,
|
||||
rejected_input_ids=None,
|
||||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
|
||||
return {
|
||||
"chosen_input_ids": chosen_input_ids,
|
||||
"chosen_loss_mask": chosen_loss_mask,
|
||||
"chosen_label_decode": chosen_label_decode,
|
||||
"rejected_input_ids": rejected_input_ids,
|
||||
"rejected_loss_mask": rejected_loss_mask,
|
||||
"rejected_label_decode": rejected_label_decode,
|
||||
}
|
||||
|
|
|
@ -89,11 +89,22 @@ class DpoLoss(nn.Module):
|
|||
"""
|
||||
Dpo loss
|
||||
Details: https://arxiv.org/pdf/2305.18290.pdf
|
||||
|
||||
SimPO loss:
|
||||
Details: https://arxiv.org/pdf/2405.14734.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, beta: float = 0.1):
|
||||
def __init__(self, beta: float = 0.1, gamma: float = 0.0):
|
||||
"""
|
||||
Args:
|
||||
beta: The temperature parameter in the DPO paper.
|
||||
gamma: The margin parameter in the SimPO paper.
|
||||
length_normalization: Whether to normalize the loss by the length of chosen and rejected responses.
|
||||
Refer to the length normalization in the SimPO paper
|
||||
"""
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -104,7 +115,7 @@ class DpoLoss(nn.Module):
|
|||
chosen_mask: torch.Tensor,
|
||||
reject_mask: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute the DPO loss for a batch of policy and reference model log probabilities.
|
||||
"""Compute the DPO/SimPO loss for a batch of policy and reference model log probabilities.
|
||||
|
||||
# adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328
|
||||
|
||||
|
@ -113,6 +124,8 @@ class DpoLoss(nn.Module):
|
|||
logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
||||
logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
|
||||
logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
|
||||
chosen_mask: Mask tensor indicating which responses were chosen. Shape: (batch_size,)
|
||||
reject_mask: Mask tensor indicating which responses were rejected. Shape: (batch_size,)
|
||||
|
||||
Returns:
|
||||
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
||||
|
@ -127,13 +140,12 @@ class DpoLoss(nn.Module):
|
|||
if len(logprob_ref_chosen.shape) == 2:
|
||||
ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
|
||||
else:
|
||||
ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
|
||||
ref_logratios = logprob_ref_chosen - logprob_ref_reject
|
||||
else:
|
||||
# If no reference model is provided
|
||||
ref_logratios = 0.0
|
||||
|
||||
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
|
||||
logits = pi_logratios - ref_logratios
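# Subtracting gamma / beta here means the final loss is -logsigmoid(beta * (pi_logratios - ref_logratios) - gamma),
# i.e. a SimPO-style target reward margin; with gamma = 0 this reduces to standard DPO.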
|
||||
logits = pi_logratios - ref_logratios - self.gamma / self.beta
|
||||
losses = -torch.nn.functional.logsigmoid(self.beta * logits)
|
||||
|
||||
# Calculate rewards for logging
|
||||
|
@ -168,3 +180,28 @@ class LogExpLoss(nn.Module):
|
|||
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
||||
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
|
||||
return loss
|
||||
|
||||
|
||||
class OddsRatioLoss(nn.Module):
|
||||
"""
|
||||
Odds Ratio Loss in ORPO
|
||||
Details: https://arxiv.org/pdf/2403.07691
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
chosen_logp: torch.Tensor,
|
||||
reject_logp: torch.Tensor,
|
||||
chosen_loss_mask: torch.Tensor,
|
||||
reject_loss_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
chosen_logp = chosen_logp.to(dtype=torch.float32)
|
||||
reject_logp = reject_logp.to(dtype=torch.float32)
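# Convert the (masked, averaged) log probability p of a response into log-odds log(p / (1 - p));
# the 1.0001 constant below guards against log(0) when exp(logp) is numerically equal to 1.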
|
||||
chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001)
|
||||
chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
|
||||
reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
|
||||
reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
|
||||
# print("chosen_odds_masked", chosen_odds_masked[0], "reject_odds_masked", reject_odds_masked[0])
|
||||
log_odds_ratio = chosen_odds_masked - reject_odds_masked
|
||||
ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
|
||||
return ratio.to(dtype=torch.bfloat16), log_odds_ratio
|
||||
|
|
|
@ -89,7 +89,9 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
|
|||
return mean
|
||||
|
||||
|
||||
def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
def calc_masked_log_probs(
|
||||
logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor, length_normalization: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the masked log probabilities for a given sequence of logits.
|
||||
|
||||
|
@ -103,7 +105,11 @@ def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mas
|
|||
"""
|
||||
# logits are probabilities of the next token, so we shift them to the left by one
|
||||
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs * mask
|
||||
|
||||
if not length_normalization:
|
||||
return log_probs * mask
|
||||
else:
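# Length normalization (used by SimPO): divide by the number of unmasked tokens to get an
# average per-token log probability; the small 0.01 term avoids division by zero for empty masks.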
|
||||
return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from .base import OLTrainer, SLTrainer
|
||||
from .dpo import DPOTrainer
|
||||
from .orpo import ORPOTrainer
|
||||
from .ppo import PPOTrainer
|
||||
from .rm import RewardModelTrainer
|
||||
from .sft import SFTTrainer
|
||||
|
||||
__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"]
|
||||
__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer", "ORPOTrainer"]
|
||||
|
|
|
@ -53,6 +53,8 @@ class DPOTrainer(SLTrainer):
|
|||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
beta: float = 0.1,
|
||||
gamma: float = 0.0,
|
||||
length_normalization: bool = False,
|
||||
accumulation_steps: int = 1,
|
||||
start_epoch: int = 0,
|
||||
save_interval: int = 0,
|
||||
|
@ -63,7 +65,7 @@ class DPOTrainer(SLTrainer):
|
|||
self.ref_model = ref_model
|
||||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.actor_loss_fn = DpoLoss(beta)
|
||||
self.actor_loss_fn = DpoLoss(beta, gamma)
|
||||
self.save_interval = save_interval
|
||||
self.coordinator = coordinator
|
||||
self.save_dir = save_dir
|
||||
|
@ -71,6 +73,7 @@ class DPOTrainer(SLTrainer):
|
|||
self.accumulation_steps = accumulation_steps
|
||||
self.device = get_current_device()
|
||||
self.accumulative_meter = AccumulativeMeanMeter()
|
||||
self.length_normalization = length_normalization
|
||||
|
||||
def _before_fit(
|
||||
self,
|
||||
|
@ -131,18 +134,21 @@ class DPOTrainer(SLTrainer):
|
|||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
reject_loss_mask[:, -1] = False
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
|
||||
actor_all_logits = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
|
@ -150,14 +156,14 @@ class DPOTrainer(SLTrainer):
|
|||
ref_all_logits = self.ref_model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
else:
|
||||
logprob_ref_chosen = None
|
||||
|
@ -219,7 +225,7 @@ class DPOTrainer(SLTrainer):
|
|||
)
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
if (self.num_train_step + 1) % self.save_interval == 0:
|
||||
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
|
||||
# save checkpoint
|
||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
|
@ -283,16 +289,16 @@ class DPOTrainer(SLTrainer):
|
|||
actor_all_logits = self.model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
self.ref_model.eval()
|
||||
|
@ -300,11 +306,15 @@ class DPOTrainer(SLTrainer):
|
|||
ref_all_logits = self.ref_model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
|
||||
logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
logprob_ref_reject = calc_masked_log_probs(
|
||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||
)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||
logprob_actor_chosen,
|
||||
|
|
|
@ -0,0 +1,339 @@
|
|||
"""
|
||||
Orpo trainer
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from coati.models.loss import OddsRatioLoss
|
||||
from coati.models.utils import calc_masked_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import trange
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base import SLTrainer
|
||||
from .utils import is_rank_0, to_device
|
||||
|
||||
|
||||
class ORPOTrainer(SLTrainer):
|
||||
"""
|
||||
Trainer for ORPO algorithm.
|
||||
|
||||
Args:
|
||||
actor (Actor): the actor model to be aligned with ORPO
|
||||
booster (Strategy): the strategy to use for training
|
||||
actor_optim (Optimizer): the optimizer to use for actor model
|
||||
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
|
||||
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
|
||||
max_epochs (int, defaults to 1): the max number of epochs to train
|
||||
lam (float, defaults to 0.1): the lambda parameter in ORPO loss
|
||||
accumulation_steps (int): the number of steps to accumulate gradients
|
||||
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
|
||||
save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during trainning
|
||||
save_dir (str): the directory to save checkpoints
|
||||
coordinator (DistCoordinator): the coordinator to use for distributed logging
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: Any,
|
||||
booster: Booster,
|
||||
actor_optim: Optimizer,
|
||||
actor_lr_scheduler: _LRScheduler,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
lam: float = 0.1,
|
||||
accumulation_steps: int = 1,
|
||||
start_epoch: int = 0,
|
||||
save_interval: int = 0,
|
||||
save_dir: str = None,
|
||||
coordinator: DistCoordinator = None,
|
||||
) -> None:
|
||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
||||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.odds_ratio_loss_fn = OddsRatioLoss()
|
||||
self.sft_loss_fn = CrossEntropyLoss()
|
||||
self.save_interval = save_interval
|
||||
self.coordinator = coordinator
|
||||
self.save_dir = save_dir
|
||||
self.num_train_step = 0
|
||||
self.lam = lam
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.device = get_current_device()
|
||||
self.accumulative_meter = AccumulativeMeanMeter()
|
||||
|
||||
def _before_fit(
|
||||
self,
|
||||
train_preference_dataloader: DataLoader = None,
|
||||
eval_preference_dataloader: DataLoader = None,
|
||||
log_dir: Optional[str] = None,
|
||||
use_wandb: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
prompt_dataloader (DataLoader): the dataloader to use for prompt data
|
||||
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
|
||||
"""
|
||||
self.train_dataloader = train_preference_dataloader
|
||||
self.eval_dataloader = eval_preference_dataloader
|
||||
self.writer = None
|
||||
if use_wandb and is_rank_0():
|
||||
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
|
||||
import wandb
|
||||
|
||||
self.wandb_run = wandb.init(project="Coati-orpo", sync_tensorboard=True)
|
||||
if log_dir is not None and is_rank_0():
|
||||
import os
|
||||
import time
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
log_dir = os.path.join(log_dir, "orpo")
|
||||
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||
self.writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
def _train(self, epoch: int):
|
||||
"""
|
||||
Args:
|
||||
epoch int: the number of current epoch
|
||||
"""
|
||||
self.model.train()
|
||||
self.accumulative_meter.reset()
|
||||
step_bar = trange(
|
||||
len(self.train_dataloader) // self.accumulation_steps,
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
chosen_attention_mask,
|
||||
chosen_loss_mask,
|
||||
reject_input_ids,
|
||||
reject_attention_mask,
|
||||
reject_loss_mask,
|
||||
) = (
|
||||
batch["chosen_input_ids"],
|
||||
batch["chosen_attention_mask"],
|
||||
batch["chosen_loss_mask"],
|
||||
batch["reject_input_ids"],
|
||||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
actor_out = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
actor_all_logits = actor_out["logits"].to(torch.float32)
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
|
||||
chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
|
||||
label_chosen = chosen_input_ids[:, 1:].contiguous()
|
||||
label_chosen_masked = (
|
||||
label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
|
||||
)
|
||||
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
|
||||
chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
|
||||
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
|
||||
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
|
||||
)
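# ORPO objective: SFT cross-entropy on the chosen response plus lam * (-log sigmoid(log odds ratio))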
|
||||
loss = chosen_nll - odds_ratio_loss * self.lam
|
||||
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.actor_scheduler.step()
|
||||
|
||||
chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
|
||||
rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
|
||||
reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
|
||||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||
|
||||
if i % self.accumulation_steps == self.accumulation_steps - 1:
|
||||
self.num_train_step += 1
|
||||
step_bar.update()
|
||||
# logging
|
||||
if self.writer and is_rank_0():
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||
self.writer.add_scalar(
|
||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/rejected_rewards",
|
||||
self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/margin",
|
||||
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/accuracy",
|
||||
self.accumulative_meter.get("accuracy"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/log_odds_ratio",
|
||||
self.accumulative_meter.get("log_odds_ratio"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
|
||||
# save checkpoint
|
||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
save_dir=self.save_dir,
|
||||
booster=self.booster,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.actor_scheduler,
|
||||
epoch=epoch,
|
||||
step=i + 1,
|
||||
batch_size=batch_size,
|
||||
coordinator=self.coordinator,
|
||||
)
|
||||
self.coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
|
||||
)
|
||||
|
||||
step_bar.close()
|
||||
|
||||
def _eval(self, epoch: int):
|
||||
"""
|
||||
Args:
|
||||
epoch int: the number of current epoch
|
||||
"""
|
||||
if self.eval_dataloader is None:
|
||||
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
|
||||
return
|
||||
self.model.eval()
|
||||
self.coordinator.print_on_master("\nStart evaluation...")
|
||||
|
||||
step_bar = trange(
|
||||
len(self.eval_dataloader),
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
with torch.no_grad():
|
||||
for i, batch in enumerate(self.eval_dataloader):
|
||||
batch = to_device(batch, self.device)
|
||||
(
|
||||
chosen_input_ids,
|
||||
chosen_attention_mask,
|
||||
chosen_loss_mask,
|
||||
reject_input_ids,
|
||||
reject_attention_mask,
|
||||
reject_loss_mask,
|
||||
) = (
|
||||
batch["chosen_input_ids"],
|
||||
batch["chosen_attention_mask"],
|
||||
batch["chosen_loss_mask"],
|
||||
batch["reject_input_ids"],
|
||||
batch["reject_attention_mask"],
|
||||
batch["reject_loss_mask"],
|
||||
)
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
actor_out = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
labels=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)
|
||||
actor_all_logits = actor_out["logits"].to(torch.float32)
|
||||
chosen_nll = torch.mean(actor_out["loss"][:batch_size]).to(dtype=torch.bfloat16)
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
|
||||
)
|
||||
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
|
||||
)
|
||||
|
||||
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
    logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
)
|
||||
|
||||
loss = chosen_nll - odds_ratio_loss * self.lam
|
||||
|
||||
chosen_rewards = torch.mean(logprob_actor_chosen).item()
|
||||
rejected_rewards = torch.mean(logprob_actor_reject).item()
|
||||
reward_accuracies = (log_odds_ratio > 0).float().mean().item()
|
||||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
|
||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||
|
||||
# logging
|
||||
if self.writer and is_rank_0():
|
||||
self.writer.add_scalar("eval/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||
self.writer.add_scalar(
|
||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/rejected_rewards",
|
||||
self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/log",
|
||||
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/accuracy",
|
||||
self.accumulative_meter.get("accuracy"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/log_odds_ratio",
|
||||
self.accumulative_meter.get("log_odds_ratio"),
|
||||
self.num_train_step,
|
||||
)
|
||||
step_bar.update()
|
||||
|
||||
msg = "Evaluation Result:\n"
|
||||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "log_odds_ratio", "accuracy"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
step_bar.close()
|
|
@ -102,6 +102,8 @@ class SFTTrainer(SLTrainer):
|
|||
batch_size = batch["input_ids"].size(0)
|
||||
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
||||
loss = outputs.loss
|
||||
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
|
|
|
@ -29,6 +29,7 @@
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
- [List of Supported Models](#list-of-supported-models)
- [Hardware Requirements](#hardware-requirements)
- [Inference example](#inference-example)
@ -717,17 +718,53 @@ For DPO training, you only need the preference dataset. Please follow the instru

#### Step 2: Training
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options. Following the trend of recent research on DPO-like alignment methods, we added several options for the user to choose from, including whether to apply length normalization, whether to use reward shaping, and whether to use a reference model when calculating the implicit reward. The relevant options are:

```
--beta 0.1 \ # the temperature in the DPO loss; defaults to 0.1
--gamma 0.0 \ # the reward target margin in the SimPO paper; defaults to 0.0
--disable_reference_model \ # disable the reference model; if set, the implicit reward is calculated solely from the actor. The reference model is enabled by default in DPO
--length_normalization \ # apply length normalization; disabled by default
```
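For example, a SimPO-style configuration (the values shown are illustrative, not recommended defaults; `loss_type` is described in the SimPO section below) combines these flags as follows:

```
--loss_type "simpo_loss" \
--beta 0.1 \
--gamma 1.4 \
--disable_reference_model \
--length_normalization \
```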

#### DPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/DPO.png">
</p>

### Alternative Option For RLHF: Simple Preference Optimization

We support the method introduced in the paper [SimPO: Simple Preference Optimization with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO), a reference-model-free alignment method that adds length normalization and reward shaping to the DPO loss to enhance training stability and efficiency. Since the method does not deviate much from DPO, we support length normalization and SimPO reward shaping directly in our DPO implementation. To use SimPO for alignment, run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) script, set `loss_type` to `simpo_loss`, and optionally set the temperature (`beta`) and the reward target margin (`gamma`).
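Conceptually, SimPO scores each response by the length-normalized average log probability under the policy itself, with no reference model, and requires the chosen response to beat the rejected one by a target margin. Below is a minimal sketch of this objective; the function and argument names are illustrative and do not correspond to this repo's exact implementation:

```
import torch
import torch.nn.functional as F

def simpo_loss(logp_chosen, logp_rejected, chosen_mask, rejected_mask, beta=0.1, gamma=1.4):
    # logp_* are per-token log probabilities, *_mask are 0/1 float masks over response tokens
    # Reference-free reward: beta * average per-token log probability of each response
    reward_chosen = beta * (logp_chosen * chosen_mask).sum(-1) / chosen_mask.sum(-1)
    reward_rejected = beta * (logp_rejected * rejected_mask).sum(-1) / rejected_mask.sum(-1)
    # Bradley-Terry objective with a target reward margin gamma between chosen and rejected
    return -F.logsigmoid(reward_chosen - reward_rejected - gamma).mean()
```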

#### SimPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/SimPO_margin.png">
</p>

### Alternative Option For RLHF: Odds Ratio Preference Optimization
We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), a reference-model-free alignment method that mixes the SFT loss with a reinforcement-learning loss whose implicit reward is based on the odds ratio, improving training stability and efficiency. To use ORPO for alignment, run the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script; you can optionally set `lambda`, which determines how strongly the reinforcement-learning loss affects training.
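In ORPO, the preference signal is the log odds ratio between the chosen and rejected responses, which is added to the ordinary SFT loss. A rough sketch of the combined objective, again with illustrative names rather than this repo's exact implementation:

```
import torch
import torch.nn.functional as F

def orpo_loss(nll_chosen, avg_logp_chosen, avg_logp_rejected, lam=0.1):
    # avg_logp_* are average per-token log probabilities of each response under the policy
    eps = 1e-4  # numerical guard when exp(logp) is close to 1
    # log-odds of each response: log(p / (1 - p))
    log_odds_chosen = avg_logp_chosen - torch.log(1.0 - torch.exp(avg_logp_chosen) + eps)
    log_odds_rejected = avg_logp_rejected - torch.log(1.0 - torch.exp(avg_logp_rejected) + eps)
    # odds-ratio term favours the chosen response over the rejected one
    odds_ratio_term = F.logsigmoid(log_odds_chosen - log_odds_rejected)
    # total loss: SFT negative log likelihood on the chosen response minus the weighted OR term
    return nll_chosen - lam * odds_ratio_term.mean()
```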

#### ORPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ORPO_margin.png">
</p>

## Hardware Requirements
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use an H800 GPU with 80GB VRAM.

For SFT, we recommend using zero2 or zero2-cpu for a 7B model and tp if your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.
- 2 H800 GPUs
  - zero2-cpu, micro batch size=4, VRAM Usage=22457.98 MB
  - zero2, micro batch size=4, VRAM Usage=72390.95 MB
- 4 H800 GPUs
  - zero2_cpu, micro batch size=8, VRAM Usage=19412.77 MB
  - zero2, micro batch size=8, VRAM Usage=43446.31 MB
  - zero2, micro batch size=16, VRAM Usage=58082.30 MB
  - zero2, micro batch size=8, lora_rank=8, VRAM Usage=21167.73 MB
  - zero2, micro batch size=8, lora_rank=32, VRAM Usage=21344.17 MB

For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model (llama2-7B-hf) on a dummy dataset with a sequence length of 2048 and a layout length of 512 with different tp_size (equal to the number of GPUs).

| PPO | tp=8 | tp=4 |
|-------|---------------|---------------|
| bs=1 | 18485.19 MB | 42934.45 MB |
@ -738,12 +775,31 @@ For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM

For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.

- 1 H800 GPU
  - zero2-cpu, batch size=2, VRAM Usage=49873.90 MB
  - zero2-cpu, batch size=4, VRAM Usage=60998.22 MB
- 2 H800 GPUs
  - zero2-cpu, micro batch size=2, VRAM Usage=36989.37 MB
  - zero2-cpu, micro batch size=4, VRAM Usage=48081.67 MB
- 4 H800 GPUs
  - zero2, batch size=4, VRAM Usage=67544.47 MB
  - zero2, micro batch size=4, VRAM Usage=67483.44 MB

For SimPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.

- 2 H800 GPUs
  - zero2-cpu, micro batch size=4, VRAM Usage=25705.26 MB
  - zero2, micro batch size=4, VRAM Usage=73375.04 MB
- 4 H800 GPUs
  - zero2_cpu, micro batch size=8, VRAM Usage=36709.36 MB
  - zero2, micro batch size=4, VRAM Usage=44330.90 MB
  - zero2, micro batch size=8, VRAM Usage=56086.12 MB

For ORPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length.

- 2 H800 GPUs
  - zero2-cpu, micro batch size=4, VRAM Usage=26693.38 MB
  - zero2, micro batch size=4, VRAM Usage=74332.65 MB
- 4 H800 GPUs
  - zero2_cpu, micro batch size=8, VRAM Usage=38709.73 MB
  - zero2, micro batch size=4, VRAM Usage=45309.52 MB
  - zero2, micro batch size=8, VRAM Usage=58086.37 MB

## List of Supported Models
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ rm -rf $SAVE_DIR/jsonl
|
|||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type preference \
|
||||
--data_input_dirs "PATH/TO/PREFERENCE/DATA" \
|
||||
--data_input_dirs /PATH/TO/PREFERENCE/DATASET \
|
||||
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
|
||||
--tokenizer_dir "" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
|
|
|
@ -5,7 +5,7 @@ rm -rf $SAVE_DIR/jsonl
|
|||
rm -rf $SAVE_DIR/arrow
|
||||
|
||||
python prepare_dataset.py --type sft \
|
||||
--data_input_dirs "PATH/TO/SFT/DATA" \
|
||||
--data_input_dirs /PATH/TO/SFT/DATASET \
|
||||
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
|
||||
--tokenizer_dir "" \
|
||||
--data_cache_dir $SAVE_DIR/cache \
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
pandas>=1.4.1
|
||||
sentencepiece
|
||||
colossalai
|
||||
colossalai==0.4.0
|
||||
prompt_toolkit
|
||||
|
|
|
@ -1,5 +1 @@
|
|||
XXX.XX.XXX.XXX # Your master IP
|
||||
XXX.XX.XXX.XXX # Your slave IPs
|
||||
XXX.XX.XXX.XXX # Your slave IPs
|
||||
XXX.XX.XXX.XXX # Your slave IPs
|
||||
XXX.XX.XXX.XXX # Your slave IPs
|
||||
localhost
|
||||
|
|
|
@ -116,7 +116,7 @@ def train(args):
|
|||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(model)
|
||||
if args.enable_reference_model:
|
||||
if not args.disable_reference_model:
|
||||
if args.use_flash_attn:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
|
@ -128,15 +128,13 @@ def train(args):
|
|||
disable_dropout(ref_model)
|
||||
else:
|
||||
ref_model = None
|
||||
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint and args.lora_rank == 0:
|
||||
if args.grad_checkpoint:
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
elif args.lora_rank > 0:
|
||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
|
@ -255,6 +253,9 @@ def train(args):
|
|||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
beta=args.beta,
|
||||
gamma=args.gamma,
|
||||
length_normalization=args.length_normalization,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
|
@ -296,6 +297,10 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss")
|
||||
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
|
||||
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
|
||||
parser.add_argument("--length_normalization", default=False, action="store_true")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
|
@ -312,7 +317,12 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--enable_reference_model", type=bool, default=True)
|
||||
parser.add_argument(
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
|
@ -330,6 +340,12 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# fool proof hyperparameter setup
|
||||
if args.loss_type == "simpo_loss":
|
||||
args.length_normalization = True
|
||||
args.gamma = args.gamma if args.gamma > 0 else 1.4
|
||||
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
|
|
|
@ -13,8 +13,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
# export CUDA_VISIBLE_DEVICES=6
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
PROJECT_NAME="dpo"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
|
@ -24,16 +23,16 @@ PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
|||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
declare -a dataset=(
|
||||
YOUR/DATA/DIR/arrow/part-00000
|
||||
YOUR/DATA/DIR/arrow/part-00001
|
||||
YOUR/DATA/DIR/arrow/part-00002
|
||||
YOUR/DATA/DIR/arrow/part-00003
|
||||
YOUR/DATA/DIR/arrow/part-00004
|
||||
YOUR/DATA/DIR/arrow/part-00005
|
||||
YOUR/DATA/DIR/arrow/part-00006
|
||||
YOUR/DATA/DIR/arrow/part-00007
|
||||
YOUR/DATA/DIR/arrow/part-00008
|
||||
YOUR/DATA/DIR/arrow/part-00009
|
||||
/Your/Preference/Data/arrow/part-00000
|
||||
/Your/Preference/Data/arrow/part-00001
|
||||
/Your/Preference/Data/arrow/part-00002
|
||||
/Your/Preference/Data/arrow/part-00003
|
||||
/Your/Preference/Data/arrow/part-00004
|
||||
/Your/Preference/Data/arrow/part-00005
|
||||
/Your/Preference/Data/arrow/part-00006
|
||||
/Your/Preference/Data/arrow/part-00007
|
||||
/Your/Preference/Data/arrow/part-00008
|
||||
/Your/Preference/Data/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
|
@ -41,7 +40,7 @@ FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
|||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_dpo.py \
|
||||
colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_dpo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--checkpoint_path $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
|
@ -51,12 +50,14 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_
|
|||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps 4 \
|
||||
--batch_size 2 \
|
||||
--accumulation_steps 2 \
|
||||
--batch_size 16 \
|
||||
--lr 1e-6 \
|
||||
--beta 0.1 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--max_length 4096 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 100 \
|
||||
--warmup_steps 60 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
||||
|
|
|
@ -0,0 +1,325 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||
from coati.models import convert_to_lora_module, disable_dropout
|
||||
from coati.trainer import ORPOTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def train(args):
|
||||
# check lora compatibility
|
||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
"""
|
||||
Default torch ddp plugin without any acceleration, for
|
||||
debugging purpose acceleration, for debugging purpose
|
||||
"""
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=True,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# Temp Fix: Disable lazy init due to version conflict
|
||||
# init_ctx = (
|
||||
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
# )
|
||||
|
||||
init_ctx = nullcontext()
|
||||
with init_ctx:
|
||||
if args.use_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
use_flash_attention_2=True,
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(model)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
if args.grad_checkpoint:
|
||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
|
||||
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
|
||||
try:
|
||||
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
except AttributeError as e:
|
||||
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
|
||||
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
|
||||
logger.warning(
|
||||
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
|
||||
)
|
||||
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
# configure optimizer
|
||||
optim = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
# configure dataset
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
mode_map = {"train": "train", "valid": "validation", "test": "test"}
|
||||
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
|
||||
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
|
||||
train_dataloader = plugin.prepare_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optim,
|
||||
total_steps=args.max_epochs * num_update_steps_per_epoch,
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=train_dataloader,
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
sampler_start_idx = 0
|
||||
start_step = 0
|
||||
if args.checkpoint_path is not None:
|
||||
if "modeling" in args.checkpoint_path:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
|
||||
booster.load_model(model, args.checkpoint_path)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.checkpoint_path,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
|
||||
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
trainer = ORPOTrainer(
|
||||
actor=model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
start_epoch=start_epoch,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
lam=args.lam,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
train_preference_dataloader=train_dataloader,
|
||||
eval_preference_dataloader=None,
|
||||
log_dir=args.log_dir,
|
||||
use_wandb=args.use_wandb,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
||||
from coati.models.lora import LORA_MANAGER
|
||||
|
||||
# NOTE: set model to eval to merge LoRA weights
|
||||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
|
||||
)
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--save_dir", type=str, default="output")
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
"--lora_train_bias",
|
||||
type=str,
|
||||
default="none",
|
||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
||||
)
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
train(args)
|
|
@ -0,0 +1,63 @@
|
|||
#!/bin/bash
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
||||
|
||||
PROJECT_NAME="dpo"
|
||||
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
|
||||
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
|
||||
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
|
||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||
|
||||
declare -a dataset=(
|
||||
/Your/Preference/Data/arrow/part-00000
|
||||
/Your/Preference/Data/arrow/part-00001
|
||||
/Your/Preference/Data/arrow/part-00002
|
||||
/Your/Preference/Data/arrow/part-00003
|
||||
/Your/Preference/Data/arrow/part-00004
|
||||
/Your/Preference/Data/arrow/part-00005
|
||||
/Your/Preference/Data/arrow/part-00006
|
||||
/Your/Preference/Data/arrow/part-00007
|
||||
/Your/Preference/Data/arrow/part-00008
|
||||
/Your/Preference/Data/arrow/part-00009
|
||||
)
|
||||
|
||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
|
||||
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \
|
||||
--pretrain $PRETRAINED_MODEL_PATH \
|
||||
--checkpoint_path $PRETRAINED_MODEL_PATH \
|
||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
--save_interval 1000 \
|
||||
--save_dir $SAVE_DIR \
|
||||
--config_file $CONFIG_FILE \
|
||||
--max_epochs 3 \
|
||||
--accumulation_steps 1 \
|
||||
--batch_size 16 \
|
||||
--lr 8e-6 \
|
||||
--lam 0.5 \
|
||||
--mixed_precision "bf16" \
|
||||
--grad_clip 1.0 \
|
||||
--max_length 1024 \
|
||||
--weight_decay 0.01 \
|
||||
--warmup_steps 60 \
|
||||
--grad_checkpoint \
|
||||
--use_wandb
|
|
@@ -122,12 +122,10 @@ def train(args):
     #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )
 
-    if args.grad_checkpoint and args.lora_rank == 0:
-        # lora layers are not supported by gradient checkpointing
+    if args.grad_checkpoint:
+        # Note, for some models, lora may not be compatible with gradient checkpointing
         model.gradient_checkpointing_enable()
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
 
     # configure tokenizer
     tokenizer = AutoTokenizer.from_pretrained(
@@ -271,7 +269,7 @@ def train(args):
     # save model checkpoint after fitting on only rank0
     coordinator.print_on_master("Start saving final model checkpoint")
 
-    # booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
+    booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
     coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
 
     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
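The change above turns the final save from a commented-out line into an actual sharded checkpoint written under `<save_path>/modeling`. A quick way to sanity-check the output after a run (the save directory is a hypothetical example built by the launch scripts; the exact shard file names depend on the booster plugin used):

```bash
SAVE_DIR="sft-2024-01-01-00-00-00"   # hypothetical save directory from a launch script
ls -lh "${SAVE_DIR}/modeling"        # list the saved shards and index files
du -sh "${SAVE_DIR}/modeling"        # total size of the final checkpoint
```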
@@ -13,9 +13,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }
 
-
-# export CUDA_VISIBLE_DEVICES=4,5,6
-set_n_least_used_CUDA_VISIBLE_DEVICES 2
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
 PROJECT_NAME="sft"
 PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
 PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
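The SFT launcher now auto-selects the 4 least-loaded GPUs instead of 2. If you prefer to pin specific devices rather than auto-select, a simple alternative (device ids are illustrative) is to comment out the `set_n_least_used_CUDA_VISIBLE_DEVICES 4` call and export the variable yourself before launching:

```bash
# Bypass auto-selection and pin four specific GPUs (ids shown are an example).
export CUDA_VISIBLE_DEVICES=0,1,2,3
```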
@@ -23,16 +21,16 @@ PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
 PRETRAINED_MODEL_PATH="" # huggingface or local model path
 PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
 declare -a dataset=(
-    YOUR/SFT/DATA/DIR/arrow/part-00000
-    YOUR/SFT/DATA/DIR/arrow/part-00001
-    YOUR/SFT/DATA/DIR/arrow/part-00002
-    YOUR/SFT/DATA/DIR/arrow/part-00003
-    YOUR/SFT/DATA/DIR/arrow/part-00004
-    YOUR/SFT/DATA/DIR/arrow/part-00005
-    YOUR/SFT/DATA/DIR/arrow/part-00006
-    YOUR/SFT/DATA/DIR/arrow/part-00007
-    YOUR/SFT/DATA/DIR/arrow/part-00008
-    YOUR/SFT/DATA/DIR/arrow/part-00009
+    /Your/SFT/Data/arrow/part-00000
+    /Your/SFT/Data/arrow/part-00001
+    /Your/SFT/Data/arrow/part-00002
+    /Your/SFT/Data/arrow/part-00003
+    /Your/SFT/Data/arrow/part-00004
+    /Your/SFT/Data/arrow/part-00005
+    /Your/SFT/Data/arrow/part-00006
+    /Your/SFT/Data/arrow/part-00007
+    /Your/SFT/Data/arrow/part-00008
+    /Your/SFT/Data/arrow/part-00009
 )
 
 TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@@ -43,7 +41,7 @@ CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
 echo $(which colossalai)
 echo $(which python)
 # the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
-colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \
+colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
     --pretrain $PRETRAINED_MODEL_PATH \
     --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
     --save_interval 4000 \
@@ -51,15 +49,12 @@ colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile trai
     --save_path $SAVE_DIR \
     --config_file $CONFIG_FILE \
     --lora_rank 0 \
-    --plugin 3d \
-    --tp 2 \
-    --pp 1 \
-    --zero_stage 0 \
-    --batch_size 2 \
-    --max_epochs 3 \
-    --accumulation_steps 1 \
+    --plugin zero2 \
+    --batch_size 4 \
+    --max_epochs 1 \
+    --accumulation_steps 4 \
     --lr 5e-5 \
-    --max_len 400 \
+    --max_len 4096 \
     --grad_checkpoint \
     --use_wandb \
     --use_flash_attn
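The comment in the launcher notes that the real batch size for gradient descent is `number_of_node_in_hostfile * nproc_per_node * train_batch_size`. A quick sketch of that arithmetic for the updated flags above, assuming a single-node hostfile (adjust `NODES` to however many hosts you list):

```bash
NODES=1               # assumed: one host listed in ./hostfile
NPROC_PER_NODE=4      # --nproc_per_node above
BATCH_SIZE=4          # --batch_size above (per device)
ACC_STEPS=4           # --accumulation_steps above

# Samples processed per forward/backward step across all devices:
echo $((NODES * NPROC_PER_NODE * BATCH_SIZE))              # 16
# Samples per optimizer update once gradient accumulation is counted:
echo $((NODES * NPROC_PER_NODE * BATCH_SIZE * ACC_STEPS))  # 64
```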
@@ -1,9 +1,9 @@
-transformers>=4.36.2
+transformers==4.39.3
 tqdm
 datasets==2.14.7
 loralib
-colossalai>=0.3.7
-torch>=1.12.1
+colossalai==0.4.0
+torch>=2.1.0
 langchain
 tokenizers
 fastapi
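With `transformers` and `colossalai` now pinned to exact versions and `torch` bumped to 2.1+, installing into a fresh environment avoids clashes with older packages. A minimal setup sketch (the virtual-environment name and the location of this requirements.txt are assumptions; adapt the paths to your checkout):

```bash
# Create an isolated environment and install the pinned dependencies.
python -m venv coati-env && source coati-env/bin/activate
pip install -r requirements.txt

# Confirm the pins took effect.
pip show transformers colossalai | grep -E "^(Name|Version)"
```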