mirror of https://github.com/hpcaitech/ColossalAI

add kto

parent b3594d4d68
commit 09d5ffca1a
@@ -52,6 +52,7 @@ jobs:
           mkdir sft_data
           mkdir prompt_data
           mkdir preference_data
+          mkdir kto_data
           ./tests/test_data_preparation.sh
           ./tests/test_train.sh
         env:
@@ -61,3 +62,4 @@ jobs:
           SFT_DATASET: ./sft_data
           PROMPT_DATASET: ./prompt_data
           PREFERENCE_DATASET: ./preference_data
+          KTO_DATASET: ./kto_data

@@ -24,7 +24,9 @@
   - [Limitation for LLaMA-finetuned models](#limitation)
   - [Limitation of dataset](#limitation)
   - [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
-  - [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization)
+  - [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
+  - [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
+  - [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
 - [FAQ](#faq)
   - [How to save/load checkpoint](#faq)
   - [How to train with limited resources](#faq)

@@ -284,6 +286,9 @@ Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference-model-free alignment method that uses a mixture of the SFT loss and a reinforcement-learning loss computed from an odds-ratio-based implicit reward, which makes training more efficient and stable. Read this [README](./examples/README.md) for more information.

## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
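
For reference, a sketch of the per-sample losses KTO optimizes (notation follows the paper and mirrors the `KTOLoss` implementation added in this commit: $r_\theta(x,y)$ is the implicit reward, $z_{\mathrm{ref}}$ a detached KL estimate clamped at zero, and $\lambda_D$, $\lambda_U$ the desirable/undesirable weights):

```latex
r_\theta(x, y) = \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}
L_D(x, y) = \lambda_D \left( 1 - \sigma\big( \beta ( r_\theta(x, y) - z_{\mathrm{ref}} ) \big) \right) \quad \text{(desirable } y\text{)}
L_U(x, y) = \lambda_U \left( 1 - \sigma\big( \beta ( z_{\mathrm{ref}} - r_\theta(x, y) ) \big) \right) \quad \text{(undesirable } y\text{)}
```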

### Inference Quantization and Serving - After Training

We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.

@@ -0,0 +1,332 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext

import torch
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler
from coati.models import convert_to_lora_module, disable_dropout
from coati.trainer import KTOTrainer
from coati.utils import load_checkpoint
from dummy_dataset import DummyLLMDataset
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam

logger = get_dist_logger()


def train(args):
    # check lora compatibility
    if "gemini" in args.plugin and args.lora_rank > 0:
        raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
        raise ValueError("Gradient accumulation is not supported in gemini_auto plugin. Please use other plugin")

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        """
        Default torch DDP plugin without any acceleration, for debugging purposes
        """
        plugin = TorchDDPPlugin(find_unused_parameters=True)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="static",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=True,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            zero_stage=args.zero_stage,
            enable_flash_attention=args.use_flash_attn,
            enable_sequence_parallelism=args.enable_sequence_parallelism,
            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
            parallel_output=False,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)
    ref_booster = Booster(plugin=plugin)

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # Temp Fix: Disable lazy init due to version conflict
    # init_ctx = (
    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    # )

    init_ctx = nullcontext()
    with init_ctx:
        if args.use_flash_attn:
            model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
            coordinator.print_on_master(msg="Flash-attention enabled successfully")
        else:
            model = AutoModelForCausalLM.from_pretrained(args.pretrain)
        disable_dropout(model)
        if not args.disable_reference_model:
            if args.use_flash_attn:
                ref_model = AutoModelForCausalLM.from_pretrained(
                    args.pretrain,
                    torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                    use_flash_attention_2=True,
                )
            else:
                ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
            disable_dropout(ref_model)
        else:
            ref_model = None
    if args.lora_rank > 0:
        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

    if args.grad_checkpoint:
        # Note: for some models, LoRA may not be compatible with gradient checkpointing
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        logger.warning(
            "The tokenizer does not have a pad token, which is required; this may lead to unintended behavior during training. Please consider setting it manually."
        )

    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    # configure optimizer
    optim = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # configure dataset
    train_dataset = DummyLLMDataset(
        ["prompt", "completion", "label"],
        args.max_length - 512,
        args.dataset_size,
        gen_fn={
            "completion": lambda x: torch.ones(512, dtype=torch.long),
            "label": lambda x: torch.tensor(x % 2, dtype=torch.long),
        },
    )
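    # Each dummy sample pairs a (max_length - 512)-token prompt with a fixed
    # 512-token completion; labels alternate with the sample index, so the data
    # is evenly split between desirable (1) and undesirable (0) samples.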
    data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)

    train_dataloader = plugin.prepare_dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
        distributed_sampler_cls=StatefulDistributedSampler,
    )

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    if args.warmup_steps is None:
        args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optim,
        total_steps=args.max_epochs * num_update_steps_per_epoch,
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optim,
        lr_scheduler=lr_scheduler,
        dataloader=train_dataloader,
    )
    if ref_model is not None:
        ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
    torch.set_default_dtype(torch.float)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    sampler_start_idx = 0
    start_step = 0
    if args.checkpoint_path is not None:
        if "modeling" in args.checkpoint_path:
            coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
            booster.load_model(model, args.checkpoint_path)
        else:
            coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
            start_epoch, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.checkpoint_path,
                booster=booster,
                model=model,
                optimizer=optim,
                lr_scheduler=lr_scheduler,
            )
            assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
            train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)

            coordinator.print_on_master(
                f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
            )
            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

        coordinator.print_on_master(
            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
        )

    trainer = KTOTrainer(
        actor=model,
        ref_model=ref_model,
        booster=booster,
        actor_optim=optim,
        actor_lr_scheduler=lr_scheduler,
        tokenizer=tokenizer,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
        start_epoch=start_epoch,
        save_interval=None,
        save_dir=None,
        coordinator=coordinator,
        beta=args.beta,
    )

    trainer.fit(
        train_preference_dataloader=train_dataloader,
        eval_preference_dataloader=None,
        log_dir=None,
        use_wandb=False,
    )
    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--sp", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
    parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--tokenizer_dir", type=str, default=None)
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training from a checkpoint"
    )
    parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
    parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
    parser.add_argument("--max_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--dataset_size", type=int, default=500)
    parser.add_argument(
        "--disable_reference_model",
        action="store_true",
        default=False,
        help="Disable the reference model (enabled by default)",
    )
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument(
        "--lora_train_bias",
        type=str,
        default="none",
        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
    )
    parser.add_argument("--merge_lora_weights", type=bool, default=True)
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")
    args = parser.parse_args()

    os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)
    train(args)

@@ -0,0 +1,45 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4

PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"

colossalai run --nproc_per_node 2 --master_port 31313 benchmark_kto.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --plugin "zero2_cpu" \
    --config_file $CONFIG_FILE \
    --max_epochs 1 \
    --accumulation_steps 1 \
    --batch_size 2 \
    --lr 1e-5 \
    --beta 0.1 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --max_length 2048 \
    --dataset_size 80 \
    --weight_decay 0.01 \
    --warmup_steps 60 \
    --grad_checkpoint \
    --use_flash_attn

@@ -17,19 +17,19 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 4
 # export CUDA_VISIBLE_DEVICES=3,4
 PROJECT_NAME="sft"
 PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
+PRETRAINED_MODEL_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="/root/commonData/Llama-2-7b-hf" # huggingface or local tokenizer path

 TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
 FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
 CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"

 # the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
-colossalai run --nproc_per_node 4 --master_port 31312 benchmark_sft.py \
+colossalai run --nproc_per_node 1 --master_port 31312 benchmark_sft.py \
     --pretrain $PRETRAINED_MODEL_PATH \
     --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
     --config_file $CONFIG_FILE \
-    --plugin zero2 \
+    --plugin ddp \
     --batch_size 8 \
     --max_epochs 1 \
     --accumulation_steps 1 \

@@ -1,10 +1,13 @@
+from typing import Callable
+
 import torch
 from torch.utils.data import Dataset


 class DummyLLMDataset(Dataset):
-    def __init__(self, keys, seq_len, size=500):
+    def __init__(self, keys, seq_len, size=500, gen_fn={}):
         self.keys = keys
+        self.gen_fn = gen_fn
         self.seq_len = seq_len
         self.data = self._generate_data()
         self.size = size
@@ -12,6 +15,9 @@ class DummyLLMDataset(Dataset):
     def _generate_data(self):
         data = {}
         for key in self.keys:
-            data[key] = torch.ones(self.seq_len, dtype=torch.long)
+            if key in self.gen_fn:
+                data[key] = self.gen_fn[key]
+            else:
+                data[key] = torch.ones(self.seq_len, dtype=torch.long)
         return data

@@ -19,4 +25,7 @@ class DummyLLMDataset(Dataset):
         return self.size

     def __getitem__(self, idx):
-        return {key: self.data[key] for key in self.keys}
+        return {
+            key: self.data[key] if not isinstance(self.data[key], Callable) else self.data[key](idx)
+            for key in self.keys
+        }
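
As an illustration of the new `gen_fn` hook (the callable receives the sample index at `__getitem__` time; the names and values below are made up):

```python
import torch

# Alternate desirable/undesirable labels across a tiny dummy dataset
dataset = DummyLLMDataset(
    ["prompt", "completion", "label"],
    seq_len=1536,
    size=8,
    gen_fn={"label": lambda idx: torch.tensor(idx % 2, dtype=torch.long)},
)
print(dataset[3]["label"])  # tensor(1)
```
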
@@ -1,12 +1,13 @@
 from .conversation import Conversation, setup_conversation_template
 from .loader import (
+    DataCollatorForKTODataset,
     DataCollatorForPreferenceDataset,
     DataCollatorForPromptDataset,
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
 )
-from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
+from .tokenization_utils import supervised_tokenize_sft, tokenize_kto, tokenize_prompt_dataset, tokenize_rlhf

 __all__ = [
     "tokenize_prompt_dataset",
@@ -14,11 +15,13 @@ __all__ = [
     "is_rank_0",
     "DataCollatorForPreferenceDataset",
     "DataCollatorForSupervisedDataset",
+    "DataCollatorForKTODataset",
     "StatefulDistributedSampler",
     "load_tokenized_dataset",
     "supervised_tokenize_pretrain",
     "supervised_tokenize_sft",
     "tokenize_rlhf",
+    "tokenize_kto",
     "setup_conversation_template",
     "Conversation",
 ]

@@ -235,6 +235,91 @@ class DataCollatorForPreferenceDataset(object):
        )


@dataclass
class DataCollatorForKTODataset(object):
    """
    Collate instances for a KTO dataset.
    Each input instance is a tokenized dictionary with fields
    `prompt`(List[int]), `completion`(List[int]) and `label`(bool).
    Each output instance is a tokenized dictionary with fields
    `input_ids`(List[int]), `attention_mask`(List[int]), `loss_mask`(List[int]) and `label`(bool),
    plus the mismatched-pair fields `kl_input_ids`(List[int]), `kl_attention_mask`(List[int]) and `kl_loss_mask`(List[int]).
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096
    ignore_index: int = -100

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples; each sample is stored in an individual dictionary containing the following fields:
                `prompt`(List[int]), `completion`(List[int]) and `label`(bool, whether the sample is desirable or not).

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`s:
                `input_ids`/`kl_input_ids`: `torch.Tensor` of shape (bsz, max_len);
                `attention_mask`/`kl_attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
                `loss_mask`/`kl_loss_mask`: `torch.Tensor` of shape (bsz, max_len), 1 on completion tokens;
                `label`: `torch.BoolTensor` of shape (bsz,).
        """
        assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{self.tokenizer.pad_token_id}`"
        )
        # prepare the preference data
        prompt = [torch.LongTensor(instance["prompt"]) for instance in instances]
        prompt_zeros = [torch.zeros_like(t) for t in prompt]
        completion = [torch.LongTensor(instance["completion"]) for instance in instances]
        completion_ones = [torch.ones_like(t) for t in completion]
        label = [torch.tensor(instance["label"], dtype=torch.bool) for instance in instances]
        input_ids = [torch.cat([prompt[i], completion[i]], dim=-1) for i in range(len(instances))]
        loss_mask = [torch.cat([prompt_zeros[i], completion_ones[i]], dim=-1) for i in range(len(instances))]
        # right padding
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )  # (bsz, max_len)
        loss_mask = torch.nn.utils.rnn.pad_sequence(
            sequences=loss_mask, batch_first=True, padding_value=0
        )  # (bsz, max_len)
        to_pad = self.max_length - input_ids.size(1)
        input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
        loss_mask = F.pad(loss_mask, (0, to_pad), value=0)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)  # `torch.BoolTensor`, (bsz, max_len)

        # prepare the KL (mismatched-pair) data
        kl_completion = completion[::-1]  # y'
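        # Reversing the batch pairs each prompt x with the completion y' of a
        # different sample, giving the mismatched (x, y') sequences used to
        # estimate the KL reference point of the KTO loss.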
        kl_completion_ones = [torch.ones_like(t) for t in kl_completion]
        kl_input_ids = [torch.cat([prompt[i], kl_completion[i]], dim=-1) for i in range(len(instances))]
        kl_loss_mask = [torch.cat([prompt_zeros[i], kl_completion_ones[i]], dim=-1) for i in range(len(instances))]
        # right padding
        kl_input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=kl_input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )  # (bsz, max_len)
        kl_loss_mask = torch.nn.utils.rnn.pad_sequence(
            sequences=kl_loss_mask, batch_first=True, padding_value=0
        )  # (bsz, max_len)
        to_pad = self.max_length - kl_input_ids.size(1)
        kl_input_ids = F.pad(kl_input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
        kl_loss_mask = F.pad(kl_loss_mask, (0, to_pad), value=0)
        kl_attention_mask = kl_input_ids.ne(self.tokenizer.pad_token_id)  # `torch.BoolTensor`, (bsz, max_len)
        data_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
            "label": torch.stack(label),
            "kl_input_ids": kl_input_ids,
            "kl_attention_mask": kl_attention_mask,
            "kl_loss_mask": kl_loss_mask,
        }
        return data_dict


class StatefulDistributedSampler(DistributedSampler):
    def __init__(
        self,

@@ -405,3 +405,66 @@ def tokenize_rlhf(
        "rejected_loss_mask": rejected_loss_mask,
        "rejected_label_decode": rejected_label_decode,
    }


def tokenize_kto(
    data_point: Dict[str, str],
    tokenizer: PreTrainedTokenizer,
    conversation_template: Conversation = None,
    ignore_index: int = None,
    max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
    """
    Tokenize a dataset for KTO training.
    The raw input data is a conversation with the following format:
        {
            "prompt": [{"from": "human", "content": "xxx"}...],
            "completion": {"from": "assistant", "content": "xxx"},
            "label": true/false
        }
    It returns three fields:
    the context, which contains the query and the assistant prefix (generation prompt);
    the completion, which contains only the assistant's answer;
    and a binary label, which indicates whether the sample is desirable or not.
    """
    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    prompt = data_point["prompt"]
    completion = data_point["completion"]
    template = deepcopy(conversation_template)
    template.clear()

    if prompt[0].get("from", None) != "human":
        raise ValueError("conversation should start with human")
    if completion.get("from", None) != "assistant":
        raise ValueError("conversation should end with assistant")

    for mess in prompt:
        if mess.get("from", None) == "human":
            template.append_message("user", mess["content"])
        elif mess.get("from", None) == "assistant":
            template.append_message("assistant", mess["content"])
        else:
            raise ValueError(f"Unsupported role {mess.get('from', None)}")
    generation_prompt = template.get_prompt(len(prompt), add_generation_prompt=True)
    template.append_message("assistant", completion["content"])
    full_prompt = template.get_prompt(len(prompt) + 1, add_generation_prompt=False)
    tokenized_full_prompt = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]
    if len(tokenized_full_prompt) + 1 > max_length:
        return dict(prompt=None, completion=None, label=None, input_id_decode=None, completion_decode=None)
    tokenized_generation_prompt = tokenizer(generation_prompt, add_special_tokens=False)["input_ids"]
    tokenized_completion = tokenized_full_prompt[len(tokenized_generation_prompt) :]
    tokenized_completion = deepcopy(tokenized_completion)
    if tokenizer.bos_token_id is not None and tokenized_generation_prompt[0] != tokenizer.bos_token_id:
        tokenized_generation_prompt = [tokenizer.bos_token_id] + tokenized_generation_prompt
    decoded_full_prompt = tokenizer.decode(tokenized_full_prompt, skip_special_tokens=False)
    decoded_completion = tokenizer.decode(tokenized_completion, skip_special_tokens=False)

    return {
        "prompt": tokenized_generation_prompt,
        "completion": tokenized_completion,
        "label": data_point["label"],
        "input_id_decode": decoded_full_prompt,
        "completion_decode": decoded_completion,
    }
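
For reference, a raw sample in the format this function expects might look like the following (a hypothetical sample; the field values are illustrative):

```python
data_point = {
    "prompt": [{"from": "human", "content": "tell me a story"}],
    "completion": {"from": "assistant", "content": "Great, let's hear a story."},
    "label": True,
}
```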

@@ -2,7 +2,7 @@ from .base import BaseModel
 from .critic import Critic
 from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
 from .lora import convert_to_lora_module
-from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
+from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
 from .reward_model import RewardModel
 from .utils import disable_dropout

@@ -16,7 +16,8 @@ __all__ = [
     "LogExpLoss",
     "convert_to_lora_module",
     "DpoLoss",
+    "KTOLoss",
     "generate",
     "generate_streaming",
     "disable_dropout",
     "update_model_kwargs_fn",

@@ -42,7 +42,6 @@ class BaseModel(nn.Module):
         out = self.model(dummy_input)
         self.last_hidden_state_size = out.last_hidden_state.shape[-1]
         self.model = self.model.cpu()
-        # print("self.last_hidden_state_size: ",self.last_hidden_state_size)

     def resize_token_embeddings(self, *args, **kwargs):
         """

@@ -50,7 +50,7 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         self.fan_in_fan_out = fan_in_fan_out
         # Actual trainable parameters
         if r > 0:
-            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)), requires_grad=False)
             self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
             self.scaling = self.lora_alpha / self.r
             # Freezing the pre-trained weight matrix

@@ -5,6 +5,7 @@ loss functions
 from typing import Optional, Tuple

 import torch
+import torch.distributed as dist
 import torch.nn as nn

 from .utils import masked_mean
@@ -201,7 +202,79 @@ class OddsRatioLoss(nn.Module):
        chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
        reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
        reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
        log_odds_ratio = chosen_odds_masked - reject_odds_masked
        ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
        return ratio.to(dtype=torch.bfloat16), log_odds_ratio


class KTOLoss(nn.Module):
    def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0):
        """
        Args:
            beta: The temperature parameter in the KTO paper.
            desirable_weight: The weight for the desirable responses.
            undesirable_weight: The weight for the undesirable responses.
        """
        super().__init__()
        self.beta = beta
        self.desirable_weight = desirable_weight
        self.undesirable_weight = undesirable_weight

    def forward(
        self,
        chosen_logps: torch.Tensor,
        rejected_logps: torch.Tensor,
        kl_logps: torch.Tensor,
        ref_chosen_logps: torch.Tensor,
        ref_rejected_logps: torch.Tensor,
        ref_kl_logps: torch.Tensor,
    ):
        """
        Reference:
        https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585

        Compute the KTO loss for a batch of policy and reference model log probabilities.
        Args:
            chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            kl_logps: Log probabilities of the policy model on the mismatched pairs. Shape: (batch_size,)
            ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            ref_kl_logps: Log probabilities of the reference model on the mismatched pairs. Shape: (batch_size,)

        Refer to the KTO paper for details about the hyperparameters `beta`, `desirable_weight`
        and `undesirable_weight` (set in `__init__`): https://arxiv.org/pdf/2402.01306
        """
        kl = (kl_logps - ref_kl_logps).mean().detach()
        # average the KL estimate across all ranks
        dist.all_reduce(kl, op=dist.ReduceOp.SUM)
        kl = (kl / dist.get_world_size()).clamp(min=0)
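        # The detached, cross-rank-averaged estimate above is the KL reference
        # point z_ref of the KTO paper; clamping it at zero follows the
        # referenced TRL implementation.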
        # kl = 0

        if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
            chosen_logratios = chosen_logps - ref_chosen_logps
            chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
            chosen_rewards = self.beta * chosen_logratios.detach()
        else:
            # important to cast to policy_dtype; otherwise an error will occur during all_gather
            chosen_losses = torch.Tensor([]).to(kl_logps.device)
            chosen_rewards = torch.Tensor([]).to(kl_logps.device)

        if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
            rejected_logratios = rejected_logps - ref_rejected_logps
            rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
            rejected_rewards = self.beta * rejected_logratios.detach()
        else:
            # important to cast to policy_dtype; otherwise an error will occur during all_gather
            rejected_losses = torch.Tensor([]).to(kl_logps.device)
            rejected_rewards = torch.Tensor([]).to(kl_logps.device)

        losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()

        return losses, chosen_rewards, rejected_rewards, kl
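
For a quick shape sanity check of `KTOLoss`, a minimal single-process sketch (the loss calls `dist.all_reduce`, so even one process needs an initialized process group; the gloo backend and the address below are illustrative choices):

```python
import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0)
losses, chosen_rewards, rejected_rewards, kl = loss_fn(
    chosen_logps=torch.randn(2),    # policy log-probs for 2 desirable samples
    rejected_logps=torch.randn(2),  # policy log-probs for 2 undesirable samples
    kl_logps=torch.randn(4),        # policy log-probs on the mismatched pairs
    ref_chosen_logps=torch.randn(2),
    ref_rejected_logps=torch.randn(2),
    ref_kl_logps=torch.randn(4),
)
print(losses.shape, chosen_rewards.shape, kl.shape)  # torch.Size([]) torch.Size([2]) torch.Size([])
```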

@@ -1,8 +1,18 @@
 from .base import OLTrainer, SLTrainer
 from .dpo import DPOTrainer
+from .kto import KTOTrainer
 from .orpo import ORPOTrainer
 from .ppo import PPOTrainer
 from .rm import RewardModelTrainer
 from .sft import SFTTrainer

-__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer", "ORPOTrainer"]
+__all__ = [
+    "SLTrainer",
+    "OLTrainer",
+    "RewardModelTrainer",
+    "SFTTrainer",
+    "PPOTrainer",
+    "DPOTrainer",
+    "ORPOTrainer",
+    "KTOTrainer",
+]

@@ -0,0 +1,318 @@
"""
KTO trainer
"""

import os
from typing import Any, Optional

import torch
import torch.distributed
from coati.models.loss import KTOLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from transformers import PreTrainedTokenizerBase

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device

from .base import SLTrainer
from .utils import is_rank_0, to_device


class KTOTrainer(SLTrainer):
    """
    Trainer for the KTO algorithm.

    Args:
        actor (Actor): the actor model to be aligned
        ref_model (Actor): the reference model
        booster (Booster): the booster to use for training
        actor_optim (Optimizer): the optimizer to use for the actor model
        actor_lr_scheduler (_LRScheduler): the lr scheduler to use for the actor model
        tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
        max_epochs (int, defaults to 1): the max number of epochs to train
        accumulation_steps (int): the number of steps to accumulate gradients
        start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
        save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training
        save_dir (str): the directory to save checkpoints
        coordinator (DistCoordinator): the coordinator to use for distributed logging
        beta (float, defaults to 0.1): the beta parameter in the KTO loss
        desirable_weight (float, defaults to 1.0): the weight for the desirable reward
        undesirable_weight (float, defaults to 1.0): the weight for the undesirable reward
    """

    def __init__(
        self,
        actor: Any,
        ref_model: Any,
        booster: Booster,
        actor_optim: Optimizer,
        actor_lr_scheduler: _LRScheduler,
        tokenizer: PreTrainedTokenizerBase,
        max_epochs: int = 1,
        beta: float = 0.1,
        desirable_weight: float = 1.0,
        undesirable_weight: float = 1.0,
        accumulation_steps: int = 1,
        start_epoch: int = 0,
        save_interval: int = 0,
        save_dir: str = None,
        coordinator: DistCoordinator = None,
    ) -> None:
        super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
        self.ref_model = ref_model
        self.actor_scheduler = actor_lr_scheduler
        self.tokenizer = tokenizer
        self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
        self.save_interval = save_interval
        self.coordinator = coordinator
        self.save_dir = save_dir
        self.num_train_step = 0
        self.accumulation_steps = accumulation_steps
        self.device = get_current_device()
        self.accumulative_meter = AccumulativeMeanMeter()
        self.desirable_weight = desirable_weight
        self.undesirable_weight = undesirable_weight
        self.beta = beta

    def _before_fit(
        self,
        train_preference_dataloader: DataLoader = None,
        eval_preference_dataloader: DataLoader = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            train_preference_dataloader (DataLoader): the dataloader to use for training
            eval_preference_dataloader (DataLoader): the dataloader to use for evaluation
        """
        self.train_dataloader = train_preference_dataloader
        self.eval_dataloader = eval_preference_dataloader
        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            self.wandb_run = wandb.init(project="Coati-kto", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "kto")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

    def _train(self, epoch: int):
        """
        Args:
            epoch (int): the number of the current epoch
        """
        self.model.train()
        self.accumulative_meter.reset()
        step_bar = trange(
            len(self.train_dataloader) // self.accumulation_steps,
            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
            disable=not is_rank_0(),
        )
        for i, batch in enumerate(self.train_dataloader):
            batch = to_device(batch, self.device)
            (input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["loss_mask"],
                batch["label"],
                batch["kl_input_ids"],
                batch["kl_attention_mask"],
                batch["kl_loss_mask"],
            )
            batch_size = input_ids.size()[0]

            # actor logits
            with torch.no_grad():
                # calculate the KL term with the mismatched-pair data; no gradient
                # is needed here because KTOLoss detaches the KL reference point
                kl_logits = self.model(
                    input_ids=kl_input_ids,
                    attention_mask=kl_attention_mask,
                )["logits"]

            logits = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )["logits"]

            logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
            kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
            chosen_index = [i for i in range(batch_size) if label[i] == 1]
            rejected_index = [i for i in range(batch_size) if label[i] == 0]
            chosen_logprob = logprob[chosen_index]
            rejected_logprob = logprob[rejected_index]
            with torch.no_grad():
                ref_kl_logits = self.ref_model(
                    input_ids=kl_input_ids,
                    attention_mask=kl_attention_mask,
                )["logits"]
                ref_logits = self.ref_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )["logits"]

            ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
            ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
            ref_chosen_logprob = ref_logprob[chosen_index]
            ref_rejected_logprob = ref_logprob[rejected_index]

            loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
                chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
            )

            self.booster.backward(loss=loss, optimizer=self.optimizer)
            if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.actor_scheduler.step()

            # sync
            loss_mean = all_reduce_mean(tensor=loss)
            chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
            rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
            self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
            self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())

            if i % self.accumulation_steps == self.accumulation_steps - 1:
                self.num_train_step += 1
                step_bar.update()
                # logging
                if self.writer and is_rank_0():
                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
                    self.writer.add_scalar(
                        "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
                    )
                    self.writer.add_scalar(
                        "train/rejected_rewards",
                        self.accumulative_meter.get("rejected_rewards"),
                        self.num_train_step,
                    )
                    self.writer.add_scalar(
                        "train/margin",
                        self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
                        self.num_train_step,
                    )
                self.accumulative_meter.reset()

                if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
                    # save checkpoint
                    self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
                    save_checkpoint(
                        save_dir=self.save_dir,
                        booster=self.booster,
                        model=self.model,
                        optimizer=self.optimizer,
                        lr_scheduler=self.actor_scheduler,
                        epoch=epoch,
                        step=i + 1,
                        batch_size=batch_size,
                        coordinator=self.coordinator,
                    )
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {i + 1} at folder {self.save_dir}"
                    )

        step_bar.close()

    def _eval(self, epoch: int):
        """
        Args:
            epoch (int): the number of the current epoch
        """
        if self.eval_dataloader is None:
            self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
            return
        self.model.eval()
        self.accumulative_meter.reset()
        step_bar = trange(
            len(self.eval_dataloader) // self.accumulation_steps,
            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
            disable=not is_rank_0(),
        )
        for i, batch in enumerate(self.eval_dataloader):
            batch = to_device(batch, self.device)
            (input_ids, attention_mask, loss_mask, label, kl_input_ids, kl_attention_mask, kl_loss_mask) = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["loss_mask"],
                batch["label"],
                batch["kl_input_ids"],
                batch["kl_attention_mask"],
                batch["kl_loss_mask"],
            )
            batch_size = input_ids.size()[0]

            # actor logits
            with torch.no_grad():
                # calculate the KL term with the mismatched-pair data
                kl_logits = self.model(
                    input_ids=kl_input_ids,
                    attention_mask=kl_attention_mask,
                )["logits"]

                logits = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )["logits"]

                logprob = calc_masked_log_probs(logits, input_ids, loss_mask[:, 1:]).sum(-1)
                kl_logprob = calc_masked_log_probs(kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
                chosen_index = [i for i in range(batch_size) if label[i] == 1]
                rejected_index = [i for i in range(batch_size) if label[i] == 0]
                chosen_logprob = logprob[chosen_index]
                rejected_logprob = logprob[rejected_index]
                ref_kl_logits = self.ref_model(
                    input_ids=kl_input_ids,
                    attention_mask=kl_attention_mask,
                )["logits"]
                ref_logits = self.ref_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )["logits"]

                ref_logprob = calc_masked_log_probs(ref_logits, input_ids, loss_mask[:, 1:]).sum(-1)
                ref_kl_logprob = calc_masked_log_probs(ref_kl_logits, kl_input_ids, kl_loss_mask[:, 1:]).sum(-1)
                ref_chosen_logprob = ref_logprob[chosen_index]
                ref_rejected_logprob = ref_logprob[rejected_index]

                loss, chosen_rewards, rejected_rewards, kl = self.kto_loss(
                    chosen_logprob, rejected_logprob, kl_logprob, ref_chosen_logprob, ref_rejected_logprob, ref_kl_logprob
                )

            # sync
            loss_mean = all_reduce_mean(tensor=loss)
            chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
            rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
            self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
            self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
            self.accumulative_meter.add(
                "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
            )
            step_bar.update()
        msg = "Evaluation Result:\n"
        for tag in ["loss", "chosen_rewards", "rejected_rewards", "margin"]:
            msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
        self.coordinator.print_on_master(msg)
        os.makedirs(self.save_dir, exist_ok=True)
        with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
            f.write(msg)
        step_bar.close()

@@ -30,6 +30,8 @@
 - [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
 - [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
 - [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
+- [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
+- [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
 - [List of Supported Models](#list-of-supported-models)
 - [Hardware Requirements](#hardware-requirements)
 - [Inference example](#inference-example)
@@ -744,13 +746,21 @@ with a Reference-Free Reward](https://arxiv.org/pdf/2405.14734) (SimPO). Which i


 ### Alternative Option For RLHF: Odds Ratio Preference Optimization
-We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), a reference-model-free alignment method that mixes the SFT loss with a reinforcement-learning loss that uses the odds ratio as the implicit reward to enhance training stability and efficiency. Simply set the flag to disable the use of the reference model, set the reward target margin and enable length normalization in the DPO training script. To use ORPO in alignment, use the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script. You can set the value for `lambda` (which determines how strongly the reinforcement-learning loss affects training), but it is optional.
+We support the method introduced in the paper [ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691) (ORPO), a reference-model-free alignment method that mixes the SFT loss with a reinforcement-learning loss that uses the odds ratio as the implicit reward to enhance training stability and efficiency. To use ORPO in alignment, use the [train_orpo.sh](./examples/training_scripts/train_orpo.sh) script. You can set the value for `lambda` (which determines how strongly the reinforcement-learning loss affects training), but it is optional.

 #### ORPO Result
 <p align="center">
 <img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ORPO_margin.png">
 </p>

+### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
+We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. To use KTO in alignment, use the [train_kto.sh](./examples/training_scripts/train_kto.sh) script. You may need to set the value for `beta` (which determines how strongly the reinforcement-learning loss affects training), as well as `desirable_weight` and `undesirable_weight` if your dataset is biased (has an unequal number of chosen and rejected samples).
+
+#### KTO Result
+<p align="center">
+<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/KTO.png">
+</p>

 ## Hardware Requirements

 For SFT, we recommend using zero2 or zero2-cpu for a 7B model, and tensor parallelism if your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.

@@ -801,6 +811,14 @@ For ORPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption
 - zero2, micro batch size=4, VRAM Usage=45309.52 MB
 - zero2, micro batch size=8, VRAM Usage=58086.37 MB

+For KTO, we recommend using the zero2-cpu or zero2 plugin. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048.
+- 2 H800 GPUs
+  - zero2-cpu, micro batch size=2, VRAM Usage=35241.98 MB
+  - zero2-cpu, micro batch size=4, VRAM Usage=38989.37 MB
+- 4 H800 GPUs
+  - zero2-cpu, micro batch size=2, VRAM Usage=32443.22 MB
+  - zero2, micro batch size=4, VRAM Usage=59307.97 MB
+
 ## List of Supported Models

 For SFT, we support the following models/series:

@@ -40,7 +40,13 @@ import random
 import time
 from multiprocessing import cpu_count

-from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
+from coati.dataset import (
+    setup_conversation_template,
+    supervised_tokenize_sft,
+    tokenize_kto,
+    tokenize_prompt_dataset,
+    tokenize_rlhf,
+)
 from datasets import dataset_dict, load_dataset
 from transformers import AutoTokenizer

@@ -56,8 +62,8 @@ def main():
         type=str,
         required=True,
         default=None,
-        choices=["sft", "prompt", "preference"],
-        help="Type of dataset, choose from 'sft', 'prompt', 'preference'.",
+        choices=["sft", "prompt", "preference", "kto"],
+        help="Type of dataset, choose from 'sft', 'prompt', 'preference', 'kto'.",
     )
     parser.add_argument(
         "--data_input_dirs",
@@ -204,6 +210,8 @@ def main():
         preparation_function = tokenize_prompt_dataset
     elif args.type == "preference":
         preparation_function = tokenize_rlhf
+    elif args.type == "kto":
+        preparation_function = tokenize_kto
     else:
         raise ValueError("Unknown dataset type. Please choose one from ['sft', 'prompt', 'preference', 'kto']")

@@ -228,10 +236,13 @@ def main():
             keep_in_memory=False,
             num_proc=min(len(dataset), cpu_count()),
         )

-        dataset = dataset.filter(
-            lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None
-        )
+        if args.type == "kto":
+            filter_by = "completion"
+        elif args.type == "preference":
+            filter_by = "chosen_input_ids"
+        else:
+            filter_by = "input_ids"
+        dataset = dataset.filter(lambda data: data[filter_by] is not None)

         # Save each jsonl spliced dataset.
         output_index = "0" * (5 - len(str(index))) + str(index)

@@ -0,0 +1,14 @@
SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto"

rm -rf $SAVE_DIR/cache
rm -rf $SAVE_DIR/jsonl
rm -rf $SAVE_DIR/arrow

python prepare_dataset.py --type kto \
    --data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/kto_format/data \
    --conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
    --tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
    --data_cache_dir $SAVE_DIR/cache \
    --data_jsonl_output_dir $SAVE_DIR/jsonl \
    --data_arrow_output_dir $SAVE_DIR/arrow \
    --max_length 1024

@@ -1,13 +1,13 @@
-SAVE_DIR=""
+SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/sft"

 rm -rf $SAVE_DIR/cache
 rm -rf $SAVE_DIR/jsonl
 rm -rf $SAVE_DIR/arrow

 python prepare_dataset.py --type sft \
-    --data_input_dirs /PATH/TO/SFT/DATASET \
-    --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
-    --tokenizer_dir "" \
+    --data_input_dirs /home/nvme-share/home/yeanbang/data/dataset/hh_rlhf/sft \
+    --conversation_template_config /home/nvme-share/home/yeanbang/ColossalAI/applications/ColossalChat/config/conversation_template/llama2.json \
+    --tokenizer_dir "/home/nvme-share/share/models/Sheared-LLaMA-1.3B" \
     --data_cache_dir $SAVE_DIR/cache \
     --data_jsonl_output_dir $SAVE_DIR/jsonl \
     --data_arrow_output_dir $SAVE_DIR/arrow \

@ -0,0 +1,104 @@


==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] Great, let’s hear a story. </s>

==========


==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s>

==========


==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] Great, let’s hear a story. </s><s>[INST] calculate 1+1 [/INST] 1+1=2 </s><s>[INST] who is the first president of the USA [/INST] The first president of the United States was George Washington. </s>

==========


==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s>

==========


==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

who is the first president of the USA? [/INST] The first president of the United States was George Washington. </s><s>[INST] tell me a story [/INST] One story that might be interesting is the story of how the United States was founded. In 1776, the Thirteen Colonies united together to form the new nation of America. The first president of the new nation was George Washington. The first state was Pennsylvania. The first capital was Philadelphia. The first presidential election was held in 1787, and was between George Read, a Federalist, and John Adams, a Republican. The Federalists wanted to keep the power of the federal government limited, while the Republicans wanted the power to be spread around to the states. The Federalists won, and George Washington was elected president. </s>

==========


==========
round 1:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s>

==========


==========
round 2:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s>

==========


==========
round 3:
<s>[INST] <<SYS>>
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.


<</SYS>>

tell me a story [/INST] I do enjoy telling stories. Is there a particular story you’d like to hear? </s><s>[INST] about Donald Trump [/INST] I’d be happy to listen to a story about Donald Trump. I have no idea what a “story” is, but I do know a lot about human life and human behavior. I can try. Do you have a theme or a question you’d like to ask about Donald Trump? </s><s>[INST] Is Donald Trump the president of the United States [/INST] Yes, Donald Trump became the 45th president of the United States in January of 2016. </s>

==========
@ -0,0 +1,372 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext

import torch
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module, disable_dropout
from coati.trainer import KTOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam

logger = get_dist_logger()


def train(args):
    # check lora compatibility
    if "gemini" in args.plugin and args.lora_rank > 0:
        raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
        raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        """
        Default torch ddp plugin without any acceleration, for debugging purpose
        """
        plugin = TorchDDPPlugin(find_unused_parameters=True)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="static",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=True,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            zero_stage=args.zero_stage,
            enable_flash_attention=args.use_flash_attn,
            enable_sequence_parallelism=args.enable_sequence_parallelism,
            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
            parallel_output=False,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)
    ref_booster = Booster(plugin=plugin)

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # Temp Fix: Disable lazy init due to version conflict
    # init_ctx = (
    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    # )

    init_ctx = nullcontext()
    with init_ctx:
        if args.use_flash_attn:
            model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
            coordinator.print_on_master(msg="Flash-attention enabled successfully")
        else:
            model = AutoModelForCausalLM.from_pretrained(args.pretrain)
        disable_dropout(model)
        if args.use_flash_attn:
            ref_model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
        else:
            ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
        disable_dropout(ref_model)
        if args.lora_rank > 0:
            model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

    if args.grad_checkpoint:
        # Note: for some models, LoRA may not be compatible with gradient checkpointing
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        logger.warning(
            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting one manually."
        )

    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    # configure optimizer
    optim = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # configure dataset
    coordinator.print_on_master(f"Load dataset: {args.dataset}")
    mode_map = {"train": "train", "valid": "validation", "test": "test"}
    train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
    num_desirable = 0
    num_undesirable = 0
    for i in range(len(train_dataset)):
        if train_dataset[i]["label"]:
            num_desirable += 1
        else:
            num_undesirable += 1
    logger.info(f"Dataset Statistics:\nDesirable: {num_desirable}\nUndesirable: {num_undesirable}")

    # Check if the user-specified weights fit into the theoretical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
    actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
    if actual_ratio < 1:
        raise AssertionError(
            f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase desirable weight or decrease undesirable weight."
        )
    elif actual_ratio > 4 / 3:
        raise AssertionError(
            f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please decrease desirable weight or increase undesirable weight."
        )
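    # Worked example: with 800 desirable and 1000 undesirable samples and the default
    # weights (1.0 each), the ratio is (1.0 * 800) / (1.0 * 1000) = 0.8 < 1, so training
    # aborts; raising --desirable_weight to 1.5 gives 1.2, inside the [1, 4/3] bound.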

    data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)

    train_dataloader = plugin.prepare_dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
        distributed_sampler_cls=StatefulDistributedSampler,
    )
    eval_dataloader = None
    if args.eval_dataset:
        eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
        eval_data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)

        eval_dataloader = plugin.prepare_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=eval_data_collator,
            distributed_sampler_cls=StatefulDistributedSampler,
        )
    else:
        logger.warning("No evaluation dataset provided; skipping evaluation")

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    if args.warmup_steps is None:
        args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optim,
        total_steps=args.max_epochs * num_update_steps_per_epoch,
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optim,
        lr_scheduler=lr_scheduler,
        dataloader=train_dataloader,
    )
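    # boost the frozen reference model separately; it needs no optimizer or LR scheduler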
    if ref_model is not None:
        ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
    torch.set_default_dtype(torch.float)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    sampler_start_idx = 0
    start_step = 0
    if args.checkpoint_path is not None:
        if "modeling" in args.checkpoint_path:
            coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
            booster.load_model(model, args.checkpoint_path)
        else:
            coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
            start_epoch, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.checkpoint_path,
                booster=booster,
                model=model,
                optimizer=optim,
                lr_scheduler=lr_scheduler,
            )
            assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
            train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)

            coordinator.print_on_master(
                f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
            )
            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

        coordinator.print_on_master(
            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
        )
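
    # beta sets the strength of the KL penalty in the KTO implicit reward;
    # desirable_weight / undesirable_weight rebalance the two sample classes.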

    trainer = KTOTrainer(
        actor=model,
        ref_model=ref_model,
        booster=booster,
        actor_optim=optim,
        actor_lr_scheduler=lr_scheduler,
        tokenizer=tokenizer,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
        start_epoch=start_epoch,
        save_interval=args.save_interval,
        save_dir=args.save_dir,
        coordinator=coordinator,
        beta=args.beta,
        desirable_weight=args.desirable_weight,
        undesirable_weight=args.undesirable_weight,
    )

    trainer.fit(
        train_preference_dataloader=train_dataloader,
        eval_preference_dataloader=eval_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    if args.lora_rank > 0 and args.merge_lora_weights:
        from coati.models.lora import LORA_MANAGER

        # NOTE: set model to eval to merge LoRA weights
        LORA_MANAGER.merge_weights = True
        model.eval()
    # save model checkpoint after fitting on only rank0
    coordinator.print_on_master("Start saving final model checkpoint")
    booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
    coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}")

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--sp", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
    parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
    parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
    parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--tokenizer_dir", type=str, default=None)
    parser.add_argument("--dataset", nargs="+", default=[])
    parser.add_argument("--eval_dataset", nargs="+", default=[])
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, help="Checkpoint path if resuming training from a checkpoint"
    )
    parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
    parser.add_argument("--save_dir", type=str, default="output")
    parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
    parser.add_argument("--max_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)

    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument(
        "--lora_train_bias",
        type=str,
        default="none",
        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
    )
    parser.add_argument("--save_interval", type=int, default=1000, help="number of steps between two checkpoints")
    parser.add_argument("--merge_lora_weights", type=bool, default=True)
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--log_dir", default="logs", type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")
    args = parser.parse_args()
    if os.path.dirname(args.config_file):  # guard against a bare filename with no directory part
        os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)
    train(args)
@ -0,0 +1,61 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
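# Ranks GPUs by current memory.used (via nvidia-smi) and exports the N least-loaded
# device indices as CUDA_VISIBLE_DEVICES; called below with N=4.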
set_n_least_used_CUDA_VISIBLE_DEVICES 4

PROJECT_NAME="kto"
PARENT_SAVE_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto/checkpoint" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save logs
PARENT_CONFIG_FILE="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="/home/nvme-share/home/yeanbang/data/model/hh_rlhf_sheared_llamasft-2024-07-17-07-29-29/modeling" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local tokenizer path

declare -a dataset=(
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00000
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00001
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00002
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00003
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00004
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00005
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00006
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00007
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00008
    /home/nvme-share/home/yeanbang/data/experiments/kto/arrow/part-00009
)

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"

colossalai run --nproc_per_node 4 --master_port 31313 train_kto.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --dataset ${dataset[@]} \
    --plugin "zero2" \
    --save_interval 1000 \
    --save_dir $SAVE_DIR \
    --config_file $CONFIG_FILE \
    --max_epochs 1 \
    --accumulation_steps 1 \
    --batch_size 8 \
    --lr 1e-5 \
    --beta 0.1 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --max_length 1024 \
    --weight_decay 0.01 \
    --warmup_steps 60 \
    --grad_checkpoint
@ -42,7 +42,6 @@ CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"

colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --checkpoint_path /home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --dataset ${dataset[@]} \
    --plugin "zero2" \

@ -15,22 +15,22 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {

set_n_least_used_CUDA_VISIBLE_DEVICES 4
PROJECT_NAME="sft"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
PARENT_SAVE_DIR="/home/nvme-share/home/yeanbang/data/model/hh_rlhf_sheared_llama" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="/home/nvme-share/home/yeanbang/data/experiments/sft/log" # Path to a folder to save logs
PARENT_CONFIG_FILE="/home/nvme-share/home/yeanbang/data/experiments/kto/log" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="/home/nvme-share/share/models/Sheared-LLaMA-1.3B" # huggingface or local tokenizer path
declare -a dataset=(
    /Your/SFT/Data/arrow/part-00000
    /Your/SFT/Data/arrow/part-00001
    /Your/SFT/Data/arrow/part-00002
    /Your/SFT/Data/arrow/part-00003
    /Your/SFT/Data/arrow/part-00004
    /Your/SFT/Data/arrow/part-00005
    /Your/SFT/Data/arrow/part-00006
    /Your/SFT/Data/arrow/part-00007
    /Your/SFT/Data/arrow/part-00008
    /Your/SFT/Data/arrow/part-00009
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00000
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00001
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00002
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00003
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00004
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00005
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00006
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00007
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00008
    /home/nvme-share/home/yeanbang/data/experiments/sft/arrow/part-00009
)

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@ -44,17 +44,16 @@ echo $(which python)
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --save_interval 4000 \
    --save_interval 2000 \
    --dataset ${dataset[@]} \
    --save_path $SAVE_DIR \
    --config_file $CONFIG_FILE \
    --lora_rank 0 \
    --plugin zero2 \
    --batch_size 4 \
    --batch_size 8 \
    --max_epochs 1 \
    --accumulation_steps 4 \
    --accumulation_steps 2 \
    --lr 5e-5 \
    --max_len 4096 \
    --grad_checkpoint \
    --use_wandb \
    --use_flash_attn

@ -37,6 +37,19 @@ preference_seed = {
    "chosen": [{"from": "assistant", "content": "You can read?"}],
    "rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
}
kto_seed = {
    "prompt": [
        {"from": "human", "content": "What are some cuss words in english?"},
        {
            "from": "assistant",
            "content": "Here's an incomplete list.\n\nAss, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, jizz-licking, prostitute, slut, cheater, fornicator, floozy, wetback, Mexican, Hispanic, sodomite, midget, mama's boy, faggot, pervert, queer, scumbag, bitch,",
        },
        {"from": "human", "content": "What's your favorite one?"},
    ],
    "completion": {"from": "assistant", "content": "Ass."},
    "label": False,
}
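# KTO sample schema: "prompt" is the multi-turn context, "completion" a single assistant
# reply to score, and "label" marks the completion desirable (True) or undesirable (False).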


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

@ -61,12 +74,21 @@ if __name__ == "__main__":
        seed = prompt_seed
    elif args.data_type == "preference":
        seed = preference_seed
    elif args.data_type == "kto":
        seed = kto_seed
    else:
        raise ValueError(f"Unknown data type {args.data_type}")

    if args.data_type != "kto":
        line = json.dumps(seed, ensure_ascii=False) + "\n"
        for idx in [1, 2, 3]:
            with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
                for i in range(1000):
                    f.write(line)
    else:
        for idx in [1, 2, 3]:
            with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
                for i in range(1000):
                    # alternate labels so both desirable and undesirable samples exist
                    seed["label"] = not seed["label"]
                    line = json.dumps(seed, ensure_ascii=False) + "\n"
                    f.write(line)

@ -71,6 +71,8 @@ get_data_input_dirs() {
        echo "$PROMPT_DATASET"
    elif [[ $data_type == "preference" ]]; then
        echo "$PREFERENCE_DATASET"
    elif [[ $data_type == "kto" ]]; then
        echo "$KTO_DATASET"
    else
        echo "Unknown data type $data_type"
        exit 1

@ -121,6 +123,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
    --data_dir $(get_data_input_dirs prompt) \
    --data_type "prompt"

python $TEST_DIR/generate_dummy_datasets_for_testing.py \
    --data_dir $(get_data_input_dirs kto) \
    --data_type "kto"

echo "[Test]: testing prepare_preference_dataset.py ..."

# FIXME: This is a hack to skip tests that are not working
@ -258,3 +264,50 @@ for model in ${MODELS[@]}; do
        exit 1
    fi
done


echo "[Test]: testing prepare_kto_dataset.py ..."

# FIXME: This is a hack to skip tests that are not working
SKIPPED_TESTS=(
)

# test prepare_kto_dataset
for model in ${MODELS[@]}; do
    data_type="kto"
    if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
        echo "[Test]: Skipped $model-$data_type"
        continue
    fi
    cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
    jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
    arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
    data_input_dirs=$(get_data_input_dirs $data_type)
    tokenizer_dir=$(get_tokenizer_dirs $model)
    conversation_template=$(get_conversation_template_config $model)
    for i in $(seq $NUM_RETRY); do
        rm -rf $cache_dir
        rm -rf $jsonl_dir
        rm -rf $arrow_dir
        echo "[Test]: $model-$data_type, attempt $i"
        python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
            --type kto \
            --data_input_dirs $data_input_dirs \
            --conversation_template_config $conversation_template \
            --tokenizer_dir $tokenizer_dir \
            --data_cache_dir $cache_dir \
            --data_jsonl_output_dir $jsonl_dir \
            --data_arrow_output_dir $arrow_dir \
            --max_length 400 \
            --num_samples_per_datafile 100 \
            --num_spliced_dataset_bins 1
        passed=$?
        if [ $passed -eq 0 ]; then
            break
        fi
    done
    if [ $passed -ne 0 ]; then
        echo "[Test]: Failed $model-$data_type"
        exit 1
    fi
done

@ -193,8 +193,8 @@ for lora_rank in ${LORA_RANK[@]}; do
            --use_flash_attn
        passed=$?
        if [ $passed -eq 0 ]; then
            rm -rf $MODEL_SAVE_PATH/*
            rm -rf $MODELS_DIR/*
            rm -rf ${MODEL_SAVE_PATH:?}/*
            rm -rf ${MODELS_DIR:?}/*
            break
        fi
    done

@ -264,8 +264,8 @@ for lora_rank in ${LORA_RANK[@]}; do
            --use_flash_attn
        passed=$?
        if [ $passed -eq 0 ]; then
            rm -rf $MODEL_SAVE_PATH/*
            rm -rf $MODELS_DIR/*
            rm -rf ${MODEL_SAVE_PATH:?}/*
            rm -rf ${MODELS_DIR:?}/*
            break
        fi
    done

@ -363,8 +363,8 @@ for lora_rank in ${LORA_RANK[@]}; do
            # --use_flash_attn
        passed=$?
        if [ $passed -eq 0 ]; then
            rm -rf $MODEL_SAVE_PATH/*
            rm -rf $MODELS_DIR/*
            rm -rf ${MODEL_SAVE_PATH:?}/*
            rm -rf ${MODELS_DIR:?}/*
            break
        fi
    done

@ -440,8 +440,8 @@ for lora_rank in ${LORA_RANK[@]}; do
            --use_flash_attn
        passed=$?
        if [ $passed -eq 0 ]; then
            rm -rf $MODEL_SAVE_PATH/*
            rm -rf $MODELS_DIR/*
            rm -rf ${MODEL_SAVE_PATH:?}/*
            rm -rf ${MODELS_DIR:?}/*
            break
        fi
    done

@ -518,8 +518,87 @@ for lora_rank in ${LORA_RANK[@]}; do
            --use_flash_attn
        passed=$?
        if [ $passed -eq 0 ]; then
            rm -rf $MODEL_SAVE_PATH/*
            rm -rf $MODELS_DIR/*
            rm -rf ${MODEL_SAVE_PATH:?}/*
            rm -rf ${MODELS_DIR:?}/*
            break
        fi
    done
    if [ $passed -ne 0 ]; then
        echo "[Test]: Failed $model-$plugin-$lora_rank"
        exit 1
    fi
    done
  done
done


echo "[Test]: testing KTO ..."

SKIPPED_TESTS=(
    llama-3d-20 # 3d plugin doesn't support lora
    llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
    llama-gemini-20 # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
    for model in ${MODELS[@]}; do
        for plugin in ${PLUGINS[@]}; do
            if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
                echo "[Test]: Skipped $model-$plugin-$lora_rank"
                continue
            elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
                echo "[Test]: Skipped $model-$plugin"
                continue
            fi
            pretrain=$(get_pretrain $model)
            tokenizer_dir=$(get_tokenizer_dirs $model)
            grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
            tp='1'
            bs='2'
            if [[ $plugin == "3d" ]]; then
                tp='4'
                bs='8'
            fi
            grad_accu='2'
            # gemini_auto and gemini don't support gradient accumulation
            if [[ $plugin == "gemini_auto" ]]; then
                grad_accu='1'
            fi
            # gemini_auto doesn't support generation
            # (need to calculate ref_model logits through forwarding in inference mode)
            if [[ $plugin == "gemini_auto" ]]; then
                echo "[Test]: Skipped $model-$plugin"
                continue
            fi
            for i in $(seq $NUM_RETRY); do
                echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
                declare -a dataset=()
                for split in $(seq -f "%05g" 0 0); do
                    dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
                done
                colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
                    --pretrain $pretrain \
                    --tokenizer_dir $tokenizer_dir \
                    --dataset ${dataset[@]} \
                    --eval_dataset ${dataset[@]} \
                    --save_dir $MODEL_SAVE_PATH \
                    --config_file $MODELS_DIR/config.jsonl \
                    --lora_rank $lora_rank \
                    --plugin $plugin \
                    --batch_size $bs \
                    --max_epochs 1 \
                    --accumulation_steps $grad_accu \
                    --tp $tp \
                    --lr 2e-5 \
                    --desirable_weight 1.2 \
                    $grad_ckpt \
                    --max_length 400 \
                    --use_flash_attn
                passed=$?
                if [ $passed -eq 0 ]; then
                    rm -rf ${MODEL_SAVE_PATH:?}/*
                    rm -rf ${MODELS_DIR:?}/*
                    break
                fi
            done