import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import KTOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam

logger = get_dist_logger()


def train(args):
    lora_config = None
    if args.lora_config is not None:
        lora_config = LoraConfig.from_file(args.lora_config)

    # check lora compatibility
if " gemini " in args . plugin and lora_config is not None and lora_config . r > 0 :
raise ValueError ( " LoRA is not supported in GeminiPlugin. Please use other plugin " )
if args . plugin == " gemini_auto " and args . accumulation_steps > 1 :
raise ValueError ( " Gradient accumulation is not supported in GeminiPlugin. Please use other plugin " )

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        """
        Default torch ddp plugin without any acceleration, for debugging purpose
        """
        plugin = TorchDDPPlugin(find_unused_parameters=True)
elif args . plugin == " gemini " :
plugin = GeminiPlugin (
precision = args . mixed_precision ,
placement_policy = " static " ,
initial_scale = 2 * * 16 ,
max_norm = args . grad_clip ,
enable_gradient_accumulation = True ,
enable_flash_attention = args . use_flash_attn ,
)
elif args . plugin == " gemini_auto " :
plugin = GeminiPlugin (
precision = args . mixed_precision ,
placement_policy = " auto " ,
initial_scale = 2 * * 16 ,
max_norm = args . grad_clip ,
enable_flash_attention = args . use_flash_attn ,
)
elif args . plugin == " zero2 " :
plugin = LowLevelZeroPlugin (
stage = 2 ,
precision = args . mixed_precision ,
initial_scale = 2 * * 16 ,
max_norm = args . grad_clip ,
)
elif args . plugin == " zero2_cpu " :
plugin = LowLevelZeroPlugin (
stage = 2 ,
precision = args . mixed_precision ,
initial_scale = 2 * * 16 ,
cpu_offload = True ,
max_norm = args . grad_clip ,
)
elif args . plugin == " 3d " :
plugin = HybridParallelPlugin (
tp_size = args . tp ,
pp_size = args . pp ,
sp_size = args . sp ,
sequence_parallelism_mode = args . sp_mode ,
zero_stage = args . zero_stage ,
enable_flash_attention = args . use_flash_attn ,
enable_sequence_parallelism = args . enable_sequence_parallelism ,
cpu_offload = True if args . zero_stage > = 1 and args . zero_cpu_offload else False ,
parallel_output = False ,
max_norm = args . grad_clip ,
precision = args . mixed_precision ,
)
else :
raise ValueError ( f " Unknown plugin { args . plugin } " )
    booster = Booster(plugin=plugin)
    ref_booster = Booster(plugin=plugin)

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # Temp Fix: Disable lazy init due to version conflict
    # init_ctx = (
    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    # )
    init_ctx = nullcontext()
    with init_ctx:
        if args.use_flash_attn:
            model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
            coordinator.print_on_master(msg="Flash-attention enabled successfully")
        else:
            model = AutoModelForCausalLM.from_pretrained(args.pretrain)

        if args.use_flash_attn:
            ref_model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
            )
        else:
            ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)

        if args.lora_config is not None:
            model = convert_to_lora_module(model, lora_config=lora_config)
            for name, module in model.named_modules():
                if "norm" in name or "gate" in name:
                    module = module.to(torch.float32)
        disable_dropout(ref_model)
        disable_dropout(model)

    if args.grad_checkpoint:
        # Note: for some models, LoRA may not be compatible with gradient checkpointing.
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        logger.warning(
            "The tokenizer does not have a pad token, which is required; this may lead to unintended behavior during training. Please consider setting it manually."
        )
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    # configure optimizer
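    # HybridAdam is ColossalAI's Adam implementation with fused CPU/GPU hybrid support;
    # adamw_mode=True applies decoupled (AdamW-style) weight decay.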
    optim = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # configure dataset
    coordinator.print_on_master(f"Load dataset: {args.dataset}")
    mode_map = {"train": "train", "valid": "validation", "test": "test"}
    train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)

    num_desirable = 0
    num_undesirable = 0
    for i in range(len(train_dataset)):
        if train_dataset[i]["label"]:
            num_desirable += 1
        else:
            num_undesirable += 1
    logger.info(f"Dataset Statistics:\nDesirable: {num_desirable}\nUndesirable: {num_undesirable}")

    # Check if the user-specified weights fit into the theoretical lower and upper bounds from Eq. (8) of https://arxiv.org/abs/2402.01306
    actual_ratio = (args.desirable_weight * num_desirable) / (args.undesirable_weight * num_undesirable)
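    # Worked example (hypothetical numbers): 3000 desirable vs. 1000 undesirable samples with
    # unit weights gives actual_ratio = 3.0, outside [1, 4/3]; --auto_weight would rescale
    # desirable_weight to 1.0 / 3.0 ≈ 0.33, bringing the ratio back into the bound.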
    if actual_ratio < 1 or actual_ratio > 4 / 3:
        if not args.auto_weight:
            raise AssertionError(
                f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, please increase/decrease desirable weight or decrease/increase undesirable weight."
            )
        else:
            args.desirable_weight = args.desirable_weight / actual_ratio
            coordinator.print_on_master(
                f"Desirable weight and undesirable weight are not within the theoretical bounds, [1, 4/3]. Actual ratio: {actual_ratio}, auto weight is enabled, set desirable weight to {args.desirable_weight} and undesirable weight to {args.undesirable_weight}"
            )

    data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)

    train_dataloader = plugin.prepare_dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
        distributed_sampler_cls=StatefulDistributedSampler,
    )

    eval_dataloader = None
    if args.eval_dataset:
        eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
        eval_data_collator = DataCollatorForKTODataset(tokenizer=tokenizer, max_length=args.max_length)
        eval_dataloader = plugin.prepare_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=eval_data_collator,
            distributed_sampler_cls=StatefulDistributedSampler,
        )
    else:
        logger.warning("No evaluation dataset is provided; skipping evaluation")

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    if args.warmup_steps is None:
        # Default to warming up over roughly 2.5% of the total number of update steps.
        args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optim,
        total_steps=args.max_epochs * num_update_steps_per_epoch,
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )
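
    # Boost (wrap) the models under the mixed-precision default dtype so tensors
    # created during boosting match it; the default is restored to fp32 afterwards.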
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optim,
        lr_scheduler=lr_scheduler,
        dataloader=train_dataloader,
    )
    if ref_model is not None:
        ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
    torch.set_default_dtype(torch.float)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    sampler_start_idx = 0
    start_step = 0
    if args.checkpoint_path is not None:
        if "modeling" in args.checkpoint_path:
            coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}")
            booster.load_model(model, args.checkpoint_path)
        else:
            coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}")
            start_epoch, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.checkpoint_path,
                booster=booster,
                model=model,
                optimizer=optim,
                lr_scheduler=lr_scheduler,
            )
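            # StatefulDistributedSampler supports mid-epoch resume: set_start_index below
            # skips the samples consumed before the checkpoint, so no data is replayed.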
            assert isinstance(train_dataloader.sampler, StatefulDistributedSampler)
            train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
            coordinator.print_on_master(
                f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}"
            )
            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

        coordinator.print_on_master(
            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
        )

    trainer = KTOTrainer(
        actor=model,
        ref_model=ref_model,
        booster=booster,
        actor_optim=optim,
        plugin=plugin,
        actor_lr_scheduler=lr_scheduler,
        tokenizer=tokenizer,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
        start_epoch=start_epoch,
        save_interval=args.save_interval,
        save_dir=args.save_dir,
        coordinator=coordinator,
        beta=args.beta,
        desirable_weight=args.desirable_weight,
        undesirable_weight=args.undesirable_weight,
        apply_loss_mask=not args.disable_loss_mask,
    )

    trainer.fit(
        train_preference_dataloader=train_dataloader,
        eval_preference_dataloader=eval_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    if lora_config is not None and lora_config.r > 0:
        # NOTE: set model to eval to merge LoRA weights
        model.eval()
    # save model checkpoint after fitting on only rank0
    if args.save_dir is not None:
        coordinator.print_on_master("Start saving final model checkpoint")
        booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
        coordinator.print_on_master(
            f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}"
        )

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")


if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "ddp", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
parser . add_argument ( " --grad_clip " , type = float , default = 1.0 , help = " Gradient clipping value " )
parser . add_argument ( " --weight_decay " , type = float , default = 0.1 , help = " Weight decay " )
parser . add_argument ( " --warmup_steps " , type = int , default = None , help = " Warmup steps " )
parser . add_argument ( " --tp " , type = int , default = 1 )
parser . add_argument ( " --pp " , type = int , default = 1 )
parser . add_argument ( " --sp " , type = int , default = 1 )
parser . add_argument ( " --beta " , type = float , default = 0.1 , help = " beta in KTO loss " )
parser . add_argument ( " --desirable_weight " , type = float , default = 1.0 , help = " desirable_weight in KTO loss " )
parser . add_argument ( " --undesirable_weight " , type = float , default = 1.0 , help = " undesirable_weight in KTO loss " )
parser . add_argument ( " --disable_loss_mask " , default = False , action = " store_true " )
parser . add_argument ( " --enable_sequence_parallelism " , default = False , action = " store_true " )
parser . add_argument ( " --zero_stage " , type = int , default = 0 , help = " Zero stage " , choices = [ 0 , 1 , 2 ] )
parser . add_argument ( " --zero_cpu_offload " , default = False , action = " store_true " )
parser . add_argument ( " --sp_mode " , type = str , default = " split_gather " , choices = [ " split_gather " , " ring " , " all_to_all " ] )
parser . add_argument ( " --pretrain " , type = str , default = None )
parser . add_argument ( " --tokenizer_dir " , type = str , default = None )
parser . add_argument ( " --dataset " , nargs = " + " , default = [ ] )
parser . add_argument ( " --eval_dataset " , nargs = " + " , default = [ ] )
parser . add_argument (
" --checkpoint_path " , type = str , default = None , help = " Checkpoint path if need to resume training form a checkpoint "
)
parser . add_argument ( " --config_file " , type = str , default = None , help = " Config file " )
parser . add_argument ( " --save_dir " , type = str , default = None )
parser . add_argument ( " --max_length " , type = int , default = 2048 , help = " Model max length " )
parser . add_argument ( " --max_epochs " , type = int , default = 3 )
parser . add_argument ( " --batch_size " , type = int , default = 4 )
parser . add_argument ( " --mixed_precision " , type = str , default = " fp16 " , choices = [ " fp16 " , " bf16 " ] , help = " Mixed precision " )
parser . add_argument ( " --lora_config " , type = str , default = None , help = " low-rank adaptation config file path " )
parser . add_argument ( " --save_interval " , type = int , default = 1000 , help = " number of step between two checkpoints " )
parser . add_argument ( " --auto_weight " , default = False , action = " store_true " )
parser . add_argument ( " --lr " , type = float , default = 5e-6 )
parser . add_argument ( " --accumulation_steps " , type = int , default = 8 )
parser . add_argument ( " --log_dir " , default = None , type = str )
parser . add_argument ( " --use_wandb " , default = False , action = " store_true " )
parser . add_argument ( " --grad_checkpoint " , default = False , action = " store_true " )
parser . add_argument ( " --use_flash_attn " , default = False , action = " store_true " )
args = parser . parse_args ( )

    if args.config_file is not None:
        os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
        with open(args.config_file, "w") as f:
            json.dump(args.__dict__, f, indent=4)

    train(args)
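
# Example launch (a sketch; paths and the script filename are hypothetical, and
# colossalai.launch_from_torch expects a torchrun-managed environment):
#   torchrun --standalone --nproc_per_node=4 train_kto.py \
#       --pretrain /path/to/base/model --dataset /path/to/tokenized/data \
#       --plugin zero2 --max_epochs 1 --batch_size 4 --lr 5e-6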