Merge pull request #5922 from hpcaitech/kto

[Chat] Add KTO
pull/5947/head
YeAnbang 2024-07-29 13:27:00 +08:00 committed by GitHub
commit c8332b9cb5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
49 changed files with 1548 additions and 965 deletions

View File

@ -52,6 +52,7 @@ jobs:
mkdir sft_data
mkdir prompt_data
mkdir preference_data
mkdir kto_data
./tests/test_data_preparation.sh
./tests/test_train.sh
env:
@ -61,3 +62,4 @@ jobs:
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
KTO_DATASET: ./kto_data

View File

@ -24,7 +24,9 @@
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@ -284,6 +286,9 @@ Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference model free alignment method that use a mixture of SFT loss and a reinforcement leanring loss calculated based on odds-ratio-based implicit reward to makes the training more efficient and stable. Read this [README](./examples/README.md) for more information.
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. Read this [README](./examples/README.md) for more information.
### Inference Quantization and Serving - After Training
We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.

View File

@ -19,30 +19,33 @@ PROJECT_NAME="dpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data
DATASET_SIZE=320
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
colossalai run --nproc_per_node 4 --master_port 31313 benchmark_dpo.py \
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--config_file $CONFIG_FILE \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 8 \
--batch_size 4 \
--lr 1e-6 \
--beta 0.1 \
--gamma 0.6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--dataset_size 640 \
--weight_decay 0.01 \
--warmup_steps 60 \
--disable_reference_model \
--length_normalization \
--grad_checkpoint \
--use_flash_attn

View File

@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data
DATASET_SIZE=80
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type kto
colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_kto.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 2 \
--lr 1e-5 \
--beta 0.1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint \
--use_flash_attn

View File

@ -1,315 +0,0 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler
from coati.models import convert_to_lora_module, disable_dropout
from coati.trainer import ORPOTrainer
from coati.utils import load_checkpoint
from dummy_dataset import DummyLLMDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
logger = get_dist_logger()
def train(args):
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# Temp Fix: Disable lazy init due to version conflict
# init_ctx = (
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(model)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
# configure optimizer
optim = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = DummyLLMDataset(
["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
args.max_length,
args.dataset_size,
)
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
train_dataloader = plugin.prepare_dataloader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optim,
total_steps=args.max_epochs * num_update_steps_per_epoch,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
start_epoch = 0
sampler_start_idx = 0
start_step = 0
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
booster.load_model(model, args.checkpoint_path)
else:
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=booster,
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
)
assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
coordinator.print_on_master(
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
trainer = ORPOTrainer(
actor=model,
booster=booster,
actor_optim=optim,
actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
start_epoch=start_epoch,
save_interval=None,
save_dir=None,
coordinator=coordinator,
lam=args.lam,
)
trainer.fit(
train_preference_dataloader=train_dataloader,
eval_preference_dataloader=None,
log_dir=None,
use_wandb=False,
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--dataset_size", type=int, default=500)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -15,20 +15,28 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="dpo"
PROJECT_NAME="orpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/orpo" # Path to benchmark data
DATASET_SIZE=160
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
colossalai run --nproc_per_node 2 --master_port 31313 benchmark_orpo.py \
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_orpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--config_file $CONFIG_FILE \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 4 \
@ -39,6 +47,5 @@ colossalai run --nproc_per_node 2 --master_port 31313 benchmark_orpo.py \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--dataset_size 160 \
--grad_checkpoint \
--use_flash_attn

View File

@ -1,315 +0,0 @@
import argparse
import json
import math
import os
import resource
from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler
from coati.models import convert_to_lora_module
from coati.trainer import SFTTrainer
from coati.utils import load_checkpoint
from dummy_dataset import DummyLLMDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
logger = get_dist_logger()
def train(args):
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================
# Initialize Booster
# ==============================
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
attn_implementation="flash_attention_2",
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.batch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# Temp Fix: Disable lazy init due to version conflict
# init_ctx = (
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
)
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
tokenizer.padding_side = "right"
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
# configure optimizer
optim = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# configure dataset
coordinator.print_on_master(
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_len, args.dataset_size)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
train_dataloader = plugin.prepare_dataloader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
math.ceil(args.max_epochs * num_update_steps_per_epoch)
if args.warmup_steps is None:
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optim,
total_steps=args.max_epochs * num_update_steps_per_epoch,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
# Flash attention will be disabled because it does NOT support fp32.
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
start_epoch = 0
sampler_start_idx = 0
start_step = 0
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
booster.load_model(model, args.checkpoint_path)
else:
coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=booster,
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
)
train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
coordinator.print_on_master(
f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
trainer = SFTTrainer(
model=model,
booster=booster,
optim=optim,
lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
start_epoch=start_epoch,
save_interval=None,
save_dir=None,
coordinator=coordinator,
)
trainer.fit(
train_dataloader=train_dataloader,
eval_dataloader=None,
log_dir=None,
use_wandb=False,
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
parser.add_argument("--dataset_size", type=int, default=500)
args = parser.parse_args()
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -14,21 +14,31 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
# export CUDA_VISIBLE_DEVICES=3,4
PROJECT_NAME="sft"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/sft" # Path to benchmark data
DATASET_SIZE=640
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type sft
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
colossalai run --nproc_per_node 4 --master_port 31312 benchmark_sft.py \
colossalai run --nproc_per_node 1 --master_port 31312 ../examples/training_scripts/train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--config_file $CONFIG_FILE \
--dataset ${dataset[@]} \
--plugin zero2 \
--batch_size 8 \
--max_epochs 1 \
@ -36,6 +46,5 @@ colossalai run --nproc_per_node 4 --master_port 31312 benchmark_sft.py \
--lr 5e-5 \
--lora_rank 32 \
--max_len 2048 \
--dataset_size 640 \
--grad_checkpoint \
--use_flash_attn

View File

@ -0,0 +1,55 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="simpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/simpo" # Path to benchmark data
DATASET_SIZE=640
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference
colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--loss_type "simpo_loss" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 8 \
--lr 1e-6 \
--beta 0.1 \
--gamma 0.6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--disable_reference_model \
--length_normalization \
--grad_checkpoint \
--use_flash_attn

View File

@ -1,10 +1,12 @@
import torch
from typing import Callable
from torch.utils.data import Dataset
class DummyLLMDataset(Dataset):
def __init__(self, keys, seq_len, size=500):
def __init__(self, keys, seq_len, size=500, gen_fn={}):
self.keys = keys
self.gen_fn = gen_fn
self.seq_len = seq_len
self.data = self._generate_data()
self.size = size
@ -12,11 +14,17 @@ class DummyLLMDataset(Dataset):
def _generate_data(self):
data = {}
for key in self.keys:
data[key] = torch.ones(self.seq_len, dtype=torch.long)
if key in self.gen_fn:
data[key] = self.gen_fn[key]
else:
data[key] = [1] * self.seq_len
return data
def __len__(self):
return self.size
def __getitem__(self, idx):
return {key: self.data[key] for key in self.keys}
return {
key: self.data[key] if not isinstance(self.data[key], Callable) else self.data[key](idx)
for key in self.keys
}

View File

@ -0,0 +1,105 @@
import argparse
import json
import os
import time
from multiprocessing import cpu_count
from datasets import load_dataset
from dummy_dataset import DummyLLMDataset
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir",
type=str,
required=True,
default=None,
help="The output dir",
)
parser.add_argument(
"--dataset_size",
type=int,
required=True,
default=None,
help="The size of data",
)
parser.add_argument(
"--max_length",
type=int,
required=True,
default=None,
help="The max length of data",
)
parser.add_argument(
"--data_type",
type=str,
required=True,
default=None,
help="The type of data, choose one from ['sft', 'prompt', 'preference', 'kto']",
)
args = parser.parse_args()
if args.data_type == "sft":
dataset = DummyLLMDataset(["input_ids", "attention_mask", "labels"], args.max_length, args.dataset_size)
elif args.data_type == "prompt":
# pass PPO dataset is prepared separately
pass
elif args.data_type == "preference":
dataset = DummyLLMDataset(
["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
args.max_length,
args.dataset_size,
)
elif args.data_type == "kto":
dataset = DummyLLMDataset(
["prompt", "completion", "label"],
args.max_length - 512,
args.dataset_size,
gen_fn={
"completion": lambda x: [1] * 512,
"label": lambda x: x % 2,
},
)
else:
raise ValueError(f"Unknown data type {args.data_type}")
# Save each jsonl spliced dataset.
output_index = "0"
output_name = f"part-{output_index}"
os.makedirs(args.data_dir, exist_ok=True)
output_jsonl_path = os.path.join(args.data_dir, "json")
output_arrow_path = os.path.join(args.data_dir, "arrow")
output_cache_path = os.path.join(args.data_dir, "cache")
os.makedirs(output_jsonl_path, exist_ok=True)
os.makedirs(output_arrow_path, exist_ok=True)
output_jsonl_file_path = os.path.join(output_jsonl_path, output_name + ".jsonl")
st = time.time()
with open(file=output_jsonl_file_path, mode="w", encoding="utf-8") as fp_writer:
count = 0
for i in range(len(dataset)):
data_point = dataset[i]
if count % 500 == 0:
logger.info(f"processing {count} spliced data points for {fp_writer.name}")
count += 1
fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
logger.info(
f"Current file {fp_writer.name}; "
f"Data size: {len(dataset)}; "
f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
)
# Save each arrow spliced dataset
output_arrow_file_path = os.path.join(output_arrow_path, output_name)
logger.info(f"Start to save {output_arrow_file_path}")
dataset = load_dataset(
path="json",
data_files=[output_jsonl_file_path],
cache_dir=os.path.join(output_cache_path, "tokenized"),
keep_in_memory=False,
num_proc=cpu_count(),
split="train",
)
dataset.save_to_disk(dataset_path=output_arrow_file_path, num_proc=min(len(dataset), cpu_count()))

View File

@ -1,24 +1,26 @@
from .conversation import Conversation, setup_conversation_template
from .loader import (
DataCollatorForKTODataset,
DataCollatorForPreferenceDataset,
DataCollatorForPromptDataset,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
)
from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
from .tokenization_utils import tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
__all__ = [
"tokenize_prompt_dataset",
"tokenize_prompt",
"DataCollatorForPromptDataset",
"is_rank_0",
"DataCollatorForPreferenceDataset",
"DataCollatorForSupervisedDataset",
"DataCollatorForKTODataset",
"StatefulDistributedSampler",
"load_tokenized_dataset",
"supervised_tokenize_pretrain",
"supervised_tokenize_sft",
"tokenize_sft",
"tokenize_rlhf",
"tokenize_kto",
"setup_conversation_template",
"Conversation",
]

View File

@ -18,6 +18,7 @@ class Conversation:
chat_template: str
stop_ids: List[int]
end_of_assistant: str
roles = ["user", "assistant"]
@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
@ -85,7 +86,7 @@ class Conversation:
Raises:
AssertionError: If the role is not 'user' or 'assistant'.
"""
assert role in ["user", "assistant"]
assert role in self.roles
self.messages.append({"role": role, "content": message})
def copy(self):

View File

@ -235,6 +235,91 @@ class DataCollatorForPreferenceDataset(object):
)
@dataclass
class DataCollatorForKTODataset(object):
"""
Collate instances for kto dataset.
Each input instance is a tokenized dictionary with fields
`prompt`(List[int]), `completion`(List[int]) and `label`(bool).
Each output instance is a tokenized dictionary with fields
`kl_input_ids`(List[int]), `kl_attention_mask`(List[int]) and `kl_loss_mask`(List[int]).
`input_ids`(List[int]), `attention_mask`(List[int]), `loss_mask`(List[int]) and `label`(bool).
"""
tokenizer: PreTrainedTokenizer
max_length: int = 4096
ignore_index: int = -100
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
Args:
instances (`Sequence[Dict[str, List[int]]]`):
Mini-batch samples, each sample is stored in an individual dictionary contains the following fields:
`prompt`(List[int]), `completion`(List[int]) and `label`(bool, if the sample is desirable or not).
Returns:
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
"""
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
f"but now `{self.tokenizer.pad_token_id}`"
)
# prepare the preference data
prompt = [torch.LongTensor(instance["prompt"]) for instance in instances]
prompt_zeros = [torch.zeros_like(t) for t in prompt]
completion = [torch.LongTensor(instance["completion"]) for instance in instances]
completion_ones = [torch.ones_like(t) for t in completion]
label = [torch.tensor(instance["label"], dtype=torch.bool) for instance in instances]
input_ids = [torch.cat([prompt[i], completion[i]], dim=-1) for i in range(len(instances))]
loss_mask = [torch.cat([prompt_zeros[i], completion_ones[i]], dim=-1) for i in range(len(instances))]
# right padding
input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
loss_mask = torch.nn.utils.rnn.pad_sequence(
sequences=loss_mask, batch_first=True, padding_value=0
) # (bsz, max_len)
to_pad = self.max_length - input_ids.size(1)
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
loss_mask = F.pad(loss_mask, (0, to_pad), value=0)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
# prepare kt data
kl_completion = completion[::-1] # y'
kl_completion_ones = [torch.ones_like(t) for t in kl_completion]
kl_input_ids = [torch.cat([prompt[i], kl_completion[i]], dim=-1) for i in range(len(instances))]
kl_loss_mask = [torch.cat([prompt_zeros[i], kl_completion_ones[i]], dim=-1) for i in range(len(instances))]
# right padding
kl_input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=kl_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
kl_loss_mask = torch.nn.utils.rnn.pad_sequence(
sequences=kl_loss_mask, batch_first=True, padding_value=0
) # (bsz, max_len)
to_pad = self.max_length - kl_input_ids.size(1)
kl_input_ids = F.pad(kl_input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
kl_loss_mask = F.pad(kl_loss_mask, (0, to_pad), value=0)
kl_attention_mask = kl_input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
data_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"label": torch.stack(label),
"kl_input_ids": kl_input_ids,
"kl_attention_mask": kl_attention_mask,
"kl_loss_mask": kl_loss_mask,
}
return data_dict
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,

View File

@ -23,11 +23,10 @@ IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
def supervised_tokenize_sft(
def tokenize_sft(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@ -39,54 +38,37 @@ def supervised_tokenize_sft(
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 0:
# Force to end with assistant response
template.messages = template.messages[0:-1]
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
turns = [i for i in range(1, len(messages) // 2 + 1)]
lo, hi = 0, len(turns)
while lo < hi:
mid = (lo + hi) // 2
prompt = template.get_prompt(2 * turns[mid] - 1)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
if max_length - 1 < len(tokenized):
hi = mid
else:
lo = mid + 1
target_turn_index = lo
# The tokenized length for first turn already exceeds `max_length - 1`.
if target_turn_index - 1 < 0:
warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
# tokenize and calculate masked labels -100 for positions corresponding to non-assistant lines
prompt = template.get_prompt()
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages, prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=max_length)
if tokenized is None:
return dict(
input_ids=None,
labels=None,
@ -96,45 +78,18 @@ def supervised_tokenize_sft(
seq_category=None,
)
target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
labels = [ignore_index] * len(tokenized)
for start, end in zip(starts, ends):
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [ignore_index]
labels[start:end] = tokenized[start:end]
# truncate the sequence at the last token that requires loss calculation
to_truncate_len = 0
for i in range(len(tokenized) - 1, -1, -1):
if labels[i] == ignore_index:
to_truncate_len += 1
else:
break
to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
tokenized = tokenized[: len(tokenized) - to_truncate_len]
labels = labels[: len(labels) - to_truncate_len]
if tokenizer.bos_token_id is not None:
# Force to add bos token at the beginning of the tokenized sequence if the input ids doesn;t starts with bos
if tokenized[0] != tokenizer.bos_token_id:
# Some chat templates already include bos token
tokenized = [tokenizer.bos_token_id] + tokenized
labels = [ignore_index] + labels
labels = [-100] + labels
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [tokenizer.eos_token_id]
else:
labels[-1] = tokenizer.eos_token_id
# For some model without bos/eos may raise the following errors
# log decoded inputs and labels for debugging
inputs_decode = tokenizer.decode(tokenized)
start = 0
end = 0
@ -171,11 +126,10 @@ def supervised_tokenize_sft(
)
def tokenize_prompt_dataset(
def tokenize_prompt(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
@ -183,48 +137,39 @@ def tokenize_prompt_dataset(
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
{"messages": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
target_turn = len(template.messages)
if target_turn % 2 != 1:
if len(template.messages) % 2 != 1:
# exclude the answer if provided. keep only the prompt
target_turn = target_turn - 1
template.messages = template.messages[:-1]
# Prepare data
prompt = template.get_prompt(target_turn, add_generation_prompt=True)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[:target_turn], prompt, conversation_template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
# Skip overlength data
if max_length - 1 < len(tokenized):
if len(tokenized) > max_length:
return dict(
input_ids=None,
inputs_decode=None,
@ -235,47 +180,32 @@ def tokenize_prompt_dataset(
# `inputs_decode` can be used to check whether the tokenization method is true.
return dict(
input_ids=tokenized,
inputs_decode=tokenizer.decode(tokenized),
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
def apply_rlhf_data_format(
template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
):
def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks(
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
loss_mask = [0] * len(tokenized)
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
if mask_token is None:
mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
# no truncation applied
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None)
loss_mask = [0] * len(tokenized)
label_decode = []
for start, end in zip(starts[-1:], ends[-1:]):
# only the last round (chosen/rejected) counts
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
loss_mask[start:end] = [1] * len(loss_mask[start:end])
label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
# only the last round (chosen/rejected) is used to calculate loss
for i in range(starts[-1], ends[-1]):
loss_mask[i] = 1
label_decode.append(tokenizer.decode(tokenized[starts[-1] : ends[-1]], skip_special_tokens=False))
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
loss_mask = [0] + loss_mask
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
else:
loss_mask[-1] = 1
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
@ -283,39 +213,29 @@ def tokenize_rlhf(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as following:
{"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
{"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
context = data_point["context"]
template = deepcopy(conversation_template)
template.clear()
for mess in context:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
# Concate adjacent message from the same role
template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
else:
template.append_message(from_str, mess["content"])
for idx, mess in enumerate(context):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{context}"
)
template.append_message(mess["from"], mess["content"])
if len(template.messages) % 2 != 1:
warnings.warn(
"Please make sure leading context starts and ends with a line from human\nLeading context: "
"Please make sure leading context starts and ends with a line from user\nLeading context: "
+ str(template.messages)
)
return dict(
@ -326,31 +246,27 @@ def tokenize_rlhf(
rejected_loss_mask=None,
rejected_label_decode=None,
)
round_of_context = int((len(template.messages) - 1) / 2)
assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
assert context[-1]["from"].lower() == template.roles[0], "The last message in context should be from user."
chosen = deepcopy(template)
rejected = deepcopy(template)
chosen_continuation = data_point["chosen"]
rejected_continuation = data_point["rejected"]
for round in range(len(chosen_continuation)):
if chosen_continuation[round]["from"] != template.roles[(round + 1) % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{chosen_continuation}"
)
chosen.append_message(chosen_continuation[round]["from"], chosen_continuation[round]["content"])
for round in range(len(data_point["chosen"])):
from_str = data_point["chosen"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
chosen.append_message(from_str, data_point["chosen"][round]["content"])
for round in range(len(data_point["rejected"])):
from_str = data_point["rejected"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
rejected.append_message(from_str, data_point["rejected"][round]["content"])
for round in range(len(rejected_continuation)):
if rejected_continuation[round]["from"] != template.roles[(round + 1) % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{rejected_continuation}"
)
rejected.append_message(rejected_continuation[round]["from"], rejected_continuation[round]["content"])
(
chosen_input_ids,
@ -361,16 +277,14 @@ def tokenize_rlhf(
rejected_label_decode,
) = (None, None, None, None, None, None)
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer)
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
chosen_data_packed["input_ids"],
chosen_data_packed["loss_mask"],
chosen_data_packed["label_decode"],
)
rejected_data_packed = apply_rlhf_data_format(
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
)
rejected_data_packed = apply_rlhf_data_format(rejected, tokenizer)
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
rejected_data_packed["input_ids"],
rejected_data_packed["loss_mask"],
@ -387,7 +301,7 @@ def tokenize_rlhf(
rejected_label_decode=None,
)
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0:
if chosen_loss_mask.count(1) == 0 or rejected_loss_mask.count(1) == 0:
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
@ -405,3 +319,62 @@ def tokenize_rlhf(
"rejected_loss_mask": rejected_loss_mask,
"rejected_label_decode": rejected_label_decode,
}
def tokenize_kto(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
Tokenize a dataset for KTO training
The raw input data is conversation that have the following format
{
"prompt": [{"from": "user", "content": "xxx"}...],
"completion": {"from": "assistant", "content": "xxx"},
"label": true/false
}
It returns three fields
The context, which contain the query and the assistant start,
the completion, which only contains the assistance's answer,
and a binary label, which indicates if the sample is prefered or not
"""
prompt = data_point["prompt"]
completion = data_point["completion"]
template = deepcopy(conversation_template)
template.clear()
if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant":
raise ValueError("conversation should end with assistant")
for mess in prompt:
if mess.get("from", None) == "user":
template.append_message("user", mess["content"])
elif mess.get("from", None) == "assistant":
template.append_message("assistant", mess["content"])
else:
raise ValueError(f"Unsupported role {mess.get('from', None)}")
generation_prompt = template.get_prompt(len(prompt), add_generation_prompt=True)
template.append_message("assistant", completion["content"])
full_prompt = template.get_prompt(len(prompt) + 1, add_generation_prompt=False)
tokenized_full_prompt = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]
if len(tokenized_full_prompt) + 1 > max_length:
return dict(prompt=None, completion=None, label=None, input_id_decode=None, completion_decode=None)
tokenized_generation_prompt = tokenizer(generation_prompt, add_special_tokens=False)["input_ids"]
tokenized_completion = tokenized_full_prompt[len(tokenized_generation_prompt) :]
tokenized_completion = deepcopy(tokenized_completion)
if tokenizer.bos_token_id is not None and tokenized_generation_prompt[0] != tokenizer.bos_token_id:
tokenized_generation_prompt = [tokenizer.bos_token_id] + tokenized_generation_prompt
decoded_full_prompt = tokenizer.decode(tokenized_full_prompt, skip_special_tokens=False)
decoded_completion = tokenizer.decode(tokenized_completion, skip_special_tokens=False)
return {
"prompt": tokenized_generation_prompt,
"completion": tokenized_completion,
"label": data_point["label"],
"input_id_decode": decoded_full_prompt,
"completion_decode": decoded_completion,
}

View File

@ -88,7 +88,13 @@ def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, s
return -1
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
def tokenize_and_concatenate(
tokenizer: PreTrainedTokenizer,
text: List[str],
require_loss: List[bool],
max_length: int,
discard_non_loss_tokens_at_tail: bool = True,
):
"""
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
@ -96,6 +102,13 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
text (List[str]): The list of texts to tokenize.
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
max_length: used to truncate the input ids
discard_non_loss_tokens_at_tail: whether to discard the non-loss tokens at the tail
if the first round has already exeeded max length
- if the user query already exeeded max length, discard the sample
- if only the first assistant response exeeded max length, truncate the response to fit the max length
else keep the first several complete rounds of the conversations until max length is reached
Returns:
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
@ -106,10 +119,18 @@ def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], re
loss_ends = []
for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
if max_length and loss_starts[0] >= max_length:
return None, None, None
if discard_non_loss_tokens_at_tail:
input_ids = input_ids[: loss_ends[-1]]
if max_length:
input_ids = input_ids[:max_length]
loss_ends[-1] = min(max_length, loss_ends[-1])
return input_ids, loss_starts, loss_ends
@ -125,6 +146,12 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
content_length = (
prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
)
# if the tokenized content start with a leading space, we want to keep it in loss calculation
# e.g., Assistant: I am saying...
# if the tokenized content doesn't start with a leading space, we only need to keep the content in loss calculation
# e.g.,
# Assistant: # '\n' as line breaker
# I am saying...
if prompt[first_occur - 1] != " ":
chunks.append(prompt[start_idx:first_occur])
chunks.append(prompt[first_occur : first_occur + content_length])

View File

@ -2,7 +2,7 @@ from .base import BaseModel
from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
from .lora import convert_to_lora_module
from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .utils import disable_dropout
@ -16,7 +16,7 @@ __all__ = [
"LogExpLoss",
"convert_to_lora_module",
"DpoLoss",
"generate",
"KTOLoss" "generate",
"generate_streaming",
"disable_dropout",
"update_model_kwargs_fn",

View File

@ -42,7 +42,6 @@ class BaseModel(nn.Module):
out = self.model(dummy_input)
self.last_hidden_state_size = out.last_hidden_state.shape[-1]
self.model = self.model.cpu()
# print("self.last_hidden_state_size: ",self.last_hidden_state_size)
def resize_token_embeddings(self, *args, **kwargs):
"""

View File

@ -50,7 +50,7 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)), requires_grad=False)
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix

View File

@ -5,6 +5,7 @@ loss functions
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from .utils import masked_mean
@ -201,7 +202,72 @@ class OddsRatioLoss(nn.Module):
chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
# print("chosen_odds_masked", chosen_odds_masked[0], "reject_odds_masked", reject_odds_masked[0])
log_odds_ratio = chosen_odds_masked - reject_odds_masked
ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
return ratio.to(dtype=torch.bfloat16), log_odds_ratio
class KTOLoss(nn.Module):
def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0):
"""
Args:
beta: The temperature parameter in the KTO paper.
desirable_weight: The weight for the desirable responses.
undesirable_weight: The weight for the undesirable
"""
super().__init__()
self.beta = beta
self.desirable_weight = desirable_weight
self.undesirable_weight = undesirable_weight
def forward(
self,
chosen_logps: torch.Tensor,
rejected_logps: torch.Tensor,
kl_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
ref_kl_logps: torch.Tensor,
):
"""
Reference:
https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585
Compute the KTO loss for a batch of policy and reference model log probabilities.
Args:
chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
kl_logps: KL divergence of the policy model. Shape: (batch_size,)
ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
ref_kl_logps: KL divergence of the reference model. Shape: (batch_size,)
beta: The temperature parameter in the DPO paper.
desirable_weight: The weight for the desirable responses.
undesirable_weight: The weight for the undesirable responses.
Refer to the KTO paper for details about hyperparameters https://arxiv.org/pdf/2402.01306
"""
kl = (kl_logps - ref_kl_logps).mean().detach()
# all gather
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
kl = (kl / dist.get_world_size()).clamp(min=0)
if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
chosen_logratios = chosen_logps - ref_chosen_logps
chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
chosen_rewards = self.beta * chosen_logratios.detach()
else:
chosen_losses = torch.Tensor([]).to(kl_logps.device)
chosen_rewards = torch.Tensor([]).to(kl_logps.device)
if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
rejected_logratios = rejected_logps - ref_rejected_logps
rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
rejected_losses = torch.Tensor([]).to(kl_logps.device)
rejected_rewards = torch.Tensor([]).to(kl_logps.device)
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
return losses, chosen_rewards, rejected_rewards, kl

View File

@ -1,8 +1,18 @@
from .base import OLTrainer, SLTrainer
from .dpo import DPOTrainer
from .kto import KTOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer", "ORPOTrainer"]
__all__ = [
"SLTrainer",
"OLTrainer",
"RewardModelTrainer",
"SFTTrainer",
"PPOTrainer",
"DPOTrainer",
"ORPOTrainer",
"KTOTrainer",
]

View File

@ -26,7 +26,7 @@ from .utils import is_rank_0, to_device
class DPOTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Trainer for DPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm

View File

@ -0,0 +1,318 @@
"""
KTO trainer
"""
import os
from typing import Any, Optional
import torch
import torch.distributed
from coati.models.loss import KTOLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import SLTrainer
from .utils import is_rank_0, to_device
class KTOTrainer(SLTrainer):
"""
Trainer for KTO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
ref_model (Critic): the reference model in ppo algorithm
booster (Strategy): the strategy to use for training
actor_optim (Optimizer): the optimizer to use for actor model
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
max_epochs (int, defaults to 1): the max number of epochs to train
accumulation_steps (int): the number of steps to accumulate gradients
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during trainning
save_dir (str): the directory to save checkpoints
coordinator (DistCoordinator): the coordinator to use for distributed logging
beta (float, defaults to 0.1): the beta parameter in kto loss
desirable_weight (float, defaults to 1.0): the weight for desirable reward
undesirable_weight (float, defaults to 1.0): the weight for undesirable reward
"""
def __init__(
self,
actor: Any,
ref_model: Any,
booster: Booster,
actor_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
beta: float = 0.1,
desirable_weight: float = 1.0,
undesirable_weight: float = 1.0,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
self.desirable_weight = desirable_weight
self.undesirable_weight = undesirable_weight
self.beta = beta
def _before_fit(
self,
train_preference_dataloader: DataLoader = None,
eval_preference_dataloader: DataLoader = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-kto", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "kto")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _train(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
self.model.train()
self.accumulative_meter.reset()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
batch["input_ids"],
batch["attention_mask"],
batch["loss_mask"],
batch["label"],
batch["kl_input_ids"],
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
batch_size = input_ids.size()[0]
# actor logits
with torch.no_grad():
# calculate KL term with KT data
kl_logits = self.model(
input_ids=kl_input_ids,
attention_mask=kl_attention_mask,
)["logits"]
logits = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]
logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
chosen_index = [i for i in range(batch_size) if label[i] == 1]
rejected_index = [i for i in range(batch_size) if label[i] == 0]
chosen_logprob = logprob[chosen_index]
rejected_logprob = logprob[rejected_index]
with torch.no_grad():
ref_kl_logits = self.ref_model(
input_ids=kl_input_ids,
attention_mask=kl_attention_mask,
)["logits"]
ref_logits = self.ref_model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]
ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
ref_chosen_logprob = ref_logprob[chosen_index]
ref_rejected_logprob = ref_logprob[rejected_index]
loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
)
self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.accumulative_meter.reset()
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
# save checkpoint
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch: int):
"""
Args:
epoch int: the number of current epoch
"""
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
self.accumulative_meter.reset()
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
batch["input_ids"],
batch["attention_mask"],
batch["loss_mask"],
batch["label"],
batch["kl_input_ids"],
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
batch_size = input_ids.size()[0]
# actor logits
with torch.no_grad():
# calculate KL term with KT data
kl_logits = self.model(
input_ids=kl_input_ids,
attention_mask=kl_attention_mask,
)["logits"]
logits = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]
logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
chosen_index = [i for i in range(batch_size) if label[i] == 1]
rejected_index = [i for i in range(batch_size) if label[i] == 0]
chosen_logprob = logprob[chosen_index]
rejected_logprob = logprob[rejected_index]
with torch.no_grad():
ref_kl_logits = self.ref_model(
input_ids=kl_input_ids,
attention_mask=kl_attention_mask,
)["logits"]
ref_logits = self.ref_model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]
ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
ref_chosen_logprob = ref_logprob[chosen_index]
ref_rejected_logprob = ref_logprob[rejected_index]
loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
)
# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
self.accumulative_meter.add(
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
)
step_bar.update()
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -26,7 +26,7 @@ from .utils import is_rank_0, to_device
class ORPOTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Trainer for ORPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm

View File

@ -30,6 +30,8 @@
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
- [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
- [List of Supported Models](#list-of-supported-models)
- [Hardware Requirements](#hardware-requirements)
- [Inference example](#inference-example)
@ -446,7 +448,7 @@ The first step in Stage 1 is to collect a dataset of human demonstrations of the
{"messages":
[
{
"from": "human",
"from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
@ -527,7 +529,7 @@ Below shows the preference dataset format used in training the reward model.
[
{"context": [
{
"from": "human",
"from": "user",
"content": "Introduce butterflies species in Oregon."
}
]
@ -596,7 +598,7 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
#### Step 1: Data Collection
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
PPO uses two kinds of training data--- the prompt data and the pretrain data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "user" and thus the "assistant" needs to generate a response to answer to the "user". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json
@ -604,7 +606,7 @@ PPO uses two kinds of training data--- the prompt data and the pretrain data (op
{"messages":
[
{
"from": "human",
"from": "user",
"content": "what are some pranks with a pen i can do?"
}
...
@ -744,13 +746,40 @@ with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which i
### Alternative Option For RLHF: Odds Ratio Preference Optimization
We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO). Which is a reference model free aligment method that mixes the SFT loss with a reinforcement learning loss that uses odds ratio as the implicit reward to enhance training stability and efficiency. Simply set the flag to disable the use of the reference model, set the reward target margin and enable length normalization in the DPO training script. To use ORPO in alignment, use the [train_orpo.sh](./training_scripts/train_orpo.sh) script, You can set the value for `lambda` (which determine how strongly the reinforcement learning loss affect the training) but it is optional.
We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO). Which is a reference model free aligment method that mixes the SFT loss with a reinforcement learning loss that uses odds ratio as the implicit reward to enhance training stability and efficiency. To use ORPO in alignment, use the [train_orpo.sh](./training_scripts/train_orpo.sh) script, You can set the value for `lambda` (which determine how strongly the reinforcement learning loss affect the training) but it is optional.
#### ORPO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ORPO_margin.png">
</p>
### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results.
For KTO data preparation, please use the script [prepare_kto_dataset.sh](./examples/data_preparation_scripts/prepare_kto_dataset.sh). You will need preference data, different from DPO and its derivatives, you no longer need a pair of chosen/rejected response for the same input. You only need data whose response is associated with a preference label--- whether the response is okay or not, read the papre for more details. You also need to convert your data to the following intermediate format before you run the data preparation script.
```jsonl
{
"prompt": [
{"from": "user", "content": "What are some cuss words in english?"},
{
"from": "assistant",
"content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, ...",
},
{"from": "user", "content": "What's your favorite one?"},
],
"completion": {"from": "assistant", "content": "Ass."}, # the completion must contain a single line from the assistant.
"label": False, # whether the response is favorable or not
}
```
For training, use the [train_kto.sh](./examples/training_scripts/train_orpo.sh) script, You may need to set the value for `beta` (which determine how strongly the reinforcement learning loss affect the training), `desirable_weight` and `undesirable_weight` if your data is biased (has unequal number of chosen and rejected samples).
#### KTO Result
<p align="center">
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/KTO.png">
</p>
## Hardware Requirements
For SFT, we recommend using zero2 or zero2-cpu for 7B model and tp is your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.
@ -801,6 +830,14 @@ For ORPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption
- zero2, micro batch size=4, VRAM Usage=45309.52 MB
- zero2, micro batch size=8, VRAM Usage=58086.37 MB
For KTO, we recommend using zero2-cpu or zero2 plugin, We tested the VRAM consumption on a dummy dataset with 2048 sequence length.
- 2 H800 GPU
- zero2-cpu, micro batch size=2, VRAM Usage=35241.98 MB
- zero2-cpu, micro batch size=4, VRAM Usage=38989.37 MB
- 4 H800 GPUs
- zero2_cpu, micro batch size=2, VRAM_USAGE=32443.22 MB
- zero2, micro batch size=4, VRAM_USAGE=59307.97 MB
## List of Supported Models
For SFT, we support the following models/series:

View File

@ -40,7 +40,7 @@ import random
import time
from multiprocessing import cpu_count
from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AutoTokenizer
@ -56,8 +56,8 @@ def main():
type=str,
required=True,
default=None,
choices=["sft", "prompt", "preference"],
help="Type of dataset, chose from 'sft', 'prompt', 'preference'.",
choices=["sft", "prompt", "preference", "kto"],
help="Type of dataset, chose from 'sft', 'prompt', 'preference'. 'kto'",
)
parser.add_argument(
"--data_input_dirs",
@ -199,11 +199,13 @@ def main():
)
if args.type == "sft":
preparation_function = supervised_tokenize_sft
preparation_function = tokenize_sft
elif args.type == "prompt":
preparation_function = tokenize_prompt_dataset
preparation_function = tokenize_prompt
elif args.type == "preference":
preparation_function = tokenize_rlhf
elif args.type == "kto":
preparation_function = tokenize_kto
else:
raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']")
@ -228,10 +230,13 @@ def main():
keep_in_memory=False,
num_proc=min(len(dataset), cpu_count()),
)
dataset = dataset.filter(
lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
)
if args.type == "kto":
filter_by = "completion"
elif args.type == "preference":
filter_by = "chosen_input_ids"
else:
filter_by = "input_ids"
dataset = dataset.filter(lambda data: data[filter_by] is not None)
# Save each jsonl spliced dataset.
output_index = "0" * (5 - len(str(index))) + str(index)

View File

@ -0,0 +1,14 @@
SAVE_DIR=""
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
python prepare_dataset.py --type kto \
--data_input_dirs /PATH/TO/KTO/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024

View File

@ -10,4 +10,5 @@ python prepare_dataset.py --type preference \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024

View File

@ -10,4 +10,5 @@ python prepare_dataset.py --type prompt \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024

View File

@ -11,3 +11,4 @@ python prepare_dataset.py --type sft \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 4096

View File

@ -0,0 +1,104 @@
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, lets hear a story. </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, lets hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s>
==========
==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] Great, lets hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s><s>[INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. </s>
==========
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s><s>[INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. </s>
==========
==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story youd like to hear? </s>
==========
==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story youd like to hear? </s><s>[INST] about Donald Trump [/INST] Id be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question youd like to ask about Donald Trump? </s>
==========
==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
<</SYS>>
tell me a story [/INST] I do enjoy telling stories. Is there a particular story youd like to hear? </s><s>[INST] about Donald Trump [/INST] Id be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question youd like to ask about Donald Trump? </s><s>[INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. </s>
==========

View File

@ -287,9 +287,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
if args.save_dir is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -328,8 +331,8 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--save_dir", type=str, default="output")
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
@ -351,7 +354,7 @@ if __name__ == "__main__":
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
@ -362,7 +365,8 @@ if __name__ == "__main__":
args.length_normalization = True
args.gamma = args.gamma if args.gamma > 0 else 1.4
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -18,6 +18,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="DPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@ -38,6 +39,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
@ -47,6 +49,7 @@ colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--max_epochs 1 \
--accumulation_steps 2 \
--batch_size 16 \

View File

@ -5,11 +5,10 @@ import resource
from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module, disable_dropout
from coati.trainer import DPOTrainer
from coati.trainer import KTOTrainer
from coati.utils import load_checkpoint
from dummy_dataset import DummyLLMDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
@ -117,18 +116,15 @@ def train(args):
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(model)
if not args.disable_reference_model:
if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
else:
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(ref_model)
if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
else:
ref_model = None
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(ref_model)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@ -164,13 +160,32 @@ def train(args):
)
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = DummyLLMDataset(
["chosen_input_ids", "chosen_loss_mask", "rejected_input_ids", "rejected_loss_mask"],
args.max_length,
args.dataset_size,
)
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
num_desirable = 0
num_undesirable = 0
for i in range(len(train_dataset)):
if train_dataset[i]["label"]:
num_desirable += 1
else:
num_undesirable += 1
logger.info(f"Dataset Statistics:\nDesirable: {num_desirable}\nUndesirable: {num_undesirable}")
# Check if the user specified weights fit into the theoratical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
if actual_ratio < 1 or actual_ratio > 4 / 3:
if not args.auto_weight:
raise AssertionError(
f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase/decrease desirable weight or decrease/increase undesirable weight."
)
else:
args.desirable_weight = args.desirable_weight / actual_ratio
coordinator.print_on_master(
f"Desirable weight and undesirable weight are not within the theoratical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, auto weight is enabled, set desirable weight to {args.desirable_weight} and undesirable weight to {args.undesirable_weight}"
)
data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
train_dataloader = plugin.prepare_dataloader(
dataset=train_dataset,
@ -180,6 +195,21 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
@ -244,7 +274,7 @@ def train(args):
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
trainer = DPOTrainer(
trainer = KTOTrainer(
actor=model,
ref_model=ref_model,
booster=booster,
@ -254,20 +284,35 @@ def train(args):
max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps,
start_epoch=start_epoch,
save_interval=None,
save_dir=None,
save_interval=args.save_interval,
save_dir=args.save_dir,
coordinator=coordinator,
beta=args.beta,
gamma=args.gamma,
length_normalization=args.length_normalization,
desirable_weight=args.desirable_weight,
undesirable_weight=args.undesirable_weight,
)
trainer.fit(
train_preference_dataloader=train_dataloader,
eval_preference_dataloader=None,
log_dir=None,
use_wandb=False,
eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
if args.save_dir is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -289,31 +334,26 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--loss_type", type=str, default="dpo_loss", help="dpo_loss or simpo_loss")
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
parser.add_argument("--length_normalization", default=False, action="store_true")
parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--dataset_size", type=int, default=500)
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(
@ -322,19 +362,18 @@ if __name__ == "__main__":
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--auto_weight", default=False, action="store_true")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
# fool proof hyperparameter setup
if args.loss_type == "simpo_loss":
args.length_normalization = True
args.gamma = args.gamma if args.gamma > 0 else 1.4
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -0,0 +1,65 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="kto"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
/Your/KTO/Data/arrow/part-00000
/Your/KTO/Data/arrow/part-00001
/Your/KTO/Data/arrow/part-00002
/Your/KTO/Data/arrow/part-00003
/Your/KTO/Data/arrow/part-00004
/Your/KTO/Data/arrow/part-00005
/Your/KTO/Data/arrow/part-00006
/Your/KTO/Data/arrow/part-00007
/Your/KTO/Data/arrow/part-00008
/Your/KTO/Data/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 4 --master_port 31313 train_kto.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 8 \
--auto_weight \
--lr 1e-5 \
--beta 0.1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 1024 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint

View File

@ -269,9 +269,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
if args.save_dir is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -307,8 +310,8 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--save_dir", type=str, default="output")
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
@ -330,12 +333,13 @@ if __name__ == "__main__":
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -18,6 +18,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="ORPO"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@ -38,6 +39,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_orpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
@ -47,6 +49,7 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31313 train_
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--max_epochs 3 \
--accumulation_steps 1 \
--batch_size 16 \

View File

@ -284,9 +284,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")
if args.save_dir is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -320,8 +323,8 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--save_dir", type=str, default="output")
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--save_dir", type=str, default=None)
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
@ -338,12 +341,13 @@ if __name__ == "__main__":
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -18,6 +18,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 8
PROJECT_NAME="RM"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
@ -38,6 +39,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
--pretrain $PRETRAINED_MODEL_PATH \
@ -47,6 +49,7 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_
--save_interval 1000 \
--save_dir $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--max_epochs 3 \
--accumulation_steps 1 \
--batch_size 8 \

View File

@ -284,10 +284,12 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
if args.save_path is not None:
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@ -321,7 +323,7 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
parser.add_argument("--save_path", type=str, default="output")
parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
@ -336,14 +338,15 @@ if __name__ == "__main__":
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
train(args)

View File

@ -17,6 +17,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="SFT"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=(
@ -36,6 +37,7 @@ TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
LOG_DIR="${PARENT_LOG_DIR}${FULL_PROJECT_NAME}"
echo $(which colossalai)
echo $(which python)
@ -43,17 +45,17 @@ echo $(which python)
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--save_interval 4000 \
--save_interval 2000 \
--dataset ${dataset[@]} \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--lora_rank 0 \
--plugin zero2 \
--batch_size 4 \
--batch_size 8 \
--max_epochs 1 \
--accumulation_steps 4 \
--accumulation_steps 2 \
--lr 5e-5 \
--max_len 4096 \
--grad_checkpoint \
--use_wandb \
--use_flash_attn

View File

@ -4,7 +4,7 @@ import os
sft_seed = {
"messages": [
{"from": "human", "content": "Give three tips for staying healthy."},
{"from": "user", "content": "Give three tips for staying healthy."},
{
"from": "assistant",
"content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.",
@ -13,7 +13,7 @@ sft_seed = {
}
prompt_seed = {
"messages": [
{"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."},
{"from": "user", "content": "Describe the impacts of climate change on communities living in coastal areas."},
{
"from": "assistant",
"content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.",
@ -22,21 +22,34 @@ prompt_seed = {
}
preference_seed = {
"context": [
{"from": "human", "content": "What kind of noises did dinosaurs make?"},
{"from": "user", "content": "What kind of noises did dinosaurs make?"},
{
"from": "assistant",
"content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. The best place to find out what noises dinosaurs made would be",
},
{"from": "human", "content": "yes they did"},
{"from": "user", "content": "yes they did"},
{
"from": "assistant",
"content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.",
},
{"from": "human", "content": "you cant read"},
{"from": "user", "content": "you cant read"},
],
"chosen": [{"from": "assistant", "content": "You can read?"}],
"rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
}
kto_seed = {
"prompt": [
{"from": "user", "content": "What are some praise words in english?"},
{
"from": "assistant",
"content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ...",
},
{"from": "user", "content": "What's your favorite one?"},
],
"completion": {"from": "assistant", "content": "Impressive."},
"label": True,
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@ -61,12 +74,21 @@ if __name__ == "__main__":
seed = prompt_seed
elif args.data_type == "preference":
seed = preference_seed
elif args.data_type == "kto":
seed = kto_seed
else:
raise ValueError(f"Unknown data type {args.data_type}")
line = json.dumps(seed, ensure_ascii=False) + "\n"
for idx in [1, 2, 3]:
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
for i in range(1000):
if args.data_type != "kto":
line = json.dumps(seed, ensure_ascii=False) + "\n"
for idx in [1, 2, 3]:
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
for i in range(1000):
f.write(line)
f.write(line)
f.write(line)
else:
for idx in [1, 2, 3]:
with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
for i in range(1000):
seed["label"] = not seed["label"]
line = json.dumps(seed, ensure_ascii=False) + "\n"
f.write(line)

View File

@ -1 +1 @@
{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}
{"context": [{"from": "user", "content": "If I were to giv6rkkurjetes, could you tell me where the location is?"}, {"from": "assistant", "content": "If I were to give you some coordinates, could you tell me where the location is?"}, {"from": "user", "content": "If I were to give you 76i85l5e4henates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinatestw434yy45wn any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give45wywydfgw46 about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]}

View File

@ -0,0 +1 @@
{"prompt": [{"from": "user", "content": "What are some praise words in english?"}, {"from": "assistant", "content": "Here's an incomplete list.\n\nexcellent, fantastic, impressive ..."},{"from": "user", "content": "What's your favorite one?"}],"completion": {"from": "assistant", "content": "impressive."},"label": true}

View File

@ -1 +1 @@
{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}
{"messages": [{"from": "user", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "user", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "user", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]}

View File

@ -71,6 +71,8 @@ get_data_input_dirs() {
echo "$PROMPT_DATASET"
elif [[ $data_type == "preference" ]]; then
echo "$PREFERENCE_DATASET"
elif [[ $data_type == "kto" ]]; then
echo "$KTO_DATASET"
else
echo "Unknown data type $data_type"
exit 1
@ -121,6 +123,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs prompt) \
--data_type "prompt"
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs kto) \
--data_type "kto"
echo "[Test]: testing prepare_preference_dataset.py ..."
# FIXME: This is a hack to skip tests that are not working
@ -258,3 +264,50 @@ for model in ${MODELS[@]}; do
exit 1
fi
done
echo "[Test]: testing prepare_kto_dataset.py ..."
# FIXME: This is a hack to skip tests that are not working
SKIPPED_TESTS=(
)
# test prepare_kto_dataset
for model in ${MODELS[@]}; do
data_type="kto"
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
echo "[Test]: Skipped $model-$data_type"
continue
fi
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
data_input_dirs=$(get_data_input_dirs $data_type)
tokenizer_dir=$(get_tokenizer_dirs $model)
conversation_template=$(get_conversation_template_config $model)
for i in $(seq $NUM_RETRY); do
rm -rf $cache_dir
rm -rf $jsonl_dir
rm -rf $arrow_dir
echo "[Test]: $model-$data_type, attempt $i"
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
--type kto \
--data_input_dirs $data_input_dirs \
--conversation_template_config $conversation_template \
--tokenizer_dir $tokenizer_dir \
--data_cache_dir $cache_dir \
--data_jsonl_output_dir $jsonl_dir \
--data_arrow_output_dir $arrow_dir \
--max_length 400 \
--num_samples_per_datafile 100 \
--num_spliced_dataset_bins 1
passed=$?
if [ $passed -eq 0 ]; then
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$data_type"
exit 1
fi
done

View File

@ -94,7 +94,7 @@ done
# Test DPO/PPO data Preparation
for model in ${MODELS[@]}; do
echo "Testing DPO/PPO data templating for $model"
echo "Testing DPO/RM data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/dpo/$model
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
@ -109,14 +109,44 @@ for model in ${MODELS[@]}; do
--data_arrow_output_dir $SAVE_DIR/arrow
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the DPO data templating for $model"
echo "[Test]: Failed in the DPO/RM data templating for $model"
exit 1
fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the DPO data templating test for $model"
echo "[Test]: Failed in the DPO/RM data templating test for $model"
exit 1
fi
done
# Test KTO data Preparation
for model in ${MODELS[@]}; do
echo "Testing KTO data templating for $model"
SAVE_DIR=$DATA_SAVE_PATH/kto/$model
rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow
pretrain=$(get_pretrain $model)
conversation_template_config=$(get_conversation_template_config $model)
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type kto --data_input_dirs $TEST_DATA_DIR/kto \
--tokenizer_dir $pretrain \
--conversation_template_config $conversation_template_config \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the KTO data templating for $model"
exit 1
fi
python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/kto/test_kto_data.jsonl \
--to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type kto
passed=$?
if [ $passed -ne 0 ]; then
echo "[Test]: Failed in the KTO data templating test for $model"
exit 1
fi
done

View File

@ -193,8 +193,8 @@ for lora_rank in ${LORA_RANK[@]}; do
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
@ -264,8 +264,8 @@ for lora_rank in ${LORA_RANK[@]}; do
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
@ -363,8 +363,8 @@ for lora_rank in ${LORA_RANK[@]}; do
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
@ -440,8 +440,8 @@ for lora_rank in ${LORA_RANK[@]}; do
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
@ -518,8 +518,88 @@ for lora_rank in ${LORA_RANK[@]}; do
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing KTO ..."
SKIPPED_TESTS=(
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto doesn't support generation
# (need to calculate ref_model logits through forwarding in inference mode)
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a dataset=()
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--lr 2e-5 \
--auto_weight \
--desirable_weight 1.2 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done

View File

@ -62,3 +62,11 @@ if __name__ == "__main__":
assert any(
[rejected_lable in s for s in to_verify_lable_rejected]
), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}"
elif args.data_type == "kto":
sample = data[0]
to_verify_data = to_verify_data[0]
for line in sample["prompt"]:
assert line["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["input_id_decode"]
assert sample["completion"]["content"] in to_verify_data["completion_decode"]
assert sample["label"] == to_verify_data["label"]