@@ -6,6 +6,7 @@ import os
from typing import Any, Optional
import torch
import torch.distributed as dist
from coati.models.loss import DpoLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
@@ -13,10 +14,11 @@ from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from tqdm import tqdm, trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster, Plugin
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
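# Note: tqdm (alongside trange) is used for the per-rank progress bars in the new
# pipeline-parallel branches below, and HybridParallelPlugin lets the trainer detect
# whether pipeline parallelism (pp_size > 1) is enabled.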
@@ -96,18 +98,25 @@ class DPOTrainer(SLTrainer):
        self.train_dataloader = train_preference_dataloader
        self.eval_dataloader = eval_preference_dataloader
        self.writer = None
        if use_wandb and is_rank_0():
        init_criterion = (
            dist.get_rank() == dist.get_world_size() - 1
            if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1
            else is_rank_0()
        )
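        # Under pipeline parallelism the loss is only materialized on the last pipeline stage,
        # so wandb/tensorboard logging is initialized on the last rank rather than on rank 0.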
        if use_wandb and init_criterion:
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb
            self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
        if log_dir is not None and init_criterion:
            import os
            import time
            from torch.utils.tensorboard import SummaryWriter
            log_dir = os.path.join(log_dir, "dpo")
            log_dir = os.path.join(log_dir, "DPO")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)
@@ -117,6 +126,140 @@ class DPOTrainer(SLTrainer):
            epoch int: the number of current epoch
        """
        self.model.train()
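        # Two code paths: a pipeline-parallel path driven by booster.execute_pipeline when
        # HybridParallelPlugin is used with pp_size > 1, and the original data-parallel path below.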
        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
            step_bar = tqdm(
                range(len(self.train_dataloader)),
                desc="Step",
                disable=not (dist.get_rank() == dist.get_world_size() - 1),
            )
            for i, batch in enumerate(self.train_dataloader):
                batch = to_device(batch, self.device)
                (
                    chosen_input_ids,
                    chosen_attention_mask,
                    chosen_loss_mask,
                    reject_input_ids,
                    reject_attention_mask,
                    reject_loss_mask,
                ) = (
                    batch["chosen_input_ids"],
                    batch["chosen_attention_mask"],
                    batch["chosen_loss_mask"],
                    batch["reject_input_ids"],
                    batch["reject_attention_mask"],
                    batch["reject_loss_mask"],
                )
                batch_size = chosen_input_ids.size()[0]
                # Calculate logits from reference model.
                if self.ref_model is not None:
                    self.ref_model.eval()
                    with torch.no_grad():
                        ref_all_logits = self.ref_model(
                            input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                            attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
                        )["logits"]
                        ref_chosen_logits = ref_all_logits[:batch_size]
                        ref_reject_logits = ref_all_logits[batch_size:]
                        logprob_ref_chosen = calc_masked_log_probs(
                            ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
                        )
                        logprob_ref_reject = calc_masked_log_probs(
                            ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
                        )
                else:
                    logprob_ref_chosen = None
                    logprob_ref_reject = None
                # Merge chosen and reject
                inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
                attention_mask = torch.stack(
                    [item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
                )
                loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
                logprob_ref = torch.stack([item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup])
                data_iter = iter(
                    [
                        {
                            "input_ids": inputs_ids,
                            "attention_mask": attention_mask,
                            "loss_mask": loss_mask,
                            "logprob_ref": logprob_ref,
                        }
                    ]
                )
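                # Chosen and rejected samples are interleaved row-wise (c0, r0, c1, r1, ...) so the
                # criterion can recover them with stride-2 slicing: [0::2] -> chosen, [1::2] -> rejected.
                # The merged batch is wrapped in a one-element iterator because execute_pipeline
                # consumes a data iterator. Note this path appears to assume a reference model is
                # provided, since logprob_ref_chosen/logprob_ref_reject are stacked unconditionally.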
                rewards = []
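                # _criterion computes the DPO loss from the interleaved pipeline outputs and stashes
                # the per-sample rewards in the enclosing `rewards` list, since only the criterion's
                # return value (the loss) is propagated by the pipeline schedule.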
                def _criterion(outputs, inputs):
                    loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
                        calc_masked_log_probs(
                            outputs["logits"][0::2],
                            inputs["input_ids"][0::2],
                            inputs["loss_mask"][0::2][:, 1:],
                            self.length_normalization,
                        ),
                        calc_masked_log_probs(
                            outputs["logits"][1::2],
                            inputs["input_ids"][1::2],
                            inputs["loss_mask"][1::2][:, 1:],
                            self.length_normalization,
                        ),
                        inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
                        inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
                        inputs["loss_mask"][0::2][:, 1:],
                        inputs["loss_mask"][1::2][:, 1:],
                    )
                    rewards.append(chosen_rewards)
                    rewards.append(rejected_rewards)
                    return loss
                outputs = self.booster.execute_pipeline(
                    data_iter,
                    self.model,
                    criterion=_criterion,
                    optimizer=self.optimizer,
                    return_loss=True,
                )
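                # execute_pipeline runs the scheduled forward/backward across the pipeline stages;
                # the returned dict carries the loss only on the last stage, hence the
                # is_last_stage() guard below.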
                loss = outputs["loss"]
                if self.booster.plugin.stage_manager.is_last_stage():
                    chosen_rewards, rejected_rewards = rewards[0], rewards[1]
                    global_loss = all_reduce_mean(loss, self.plugin)
                    if dist.get_rank() == dist.get_world_size() - 1:
                        step_bar.set_postfix(
                            {
                                "train/loss": global_loss.item(),
                                "train/lr": self.actor_scheduler.get_last_lr()[0],
                                "train/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
                                "train/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
                            }
                        )
                        step_bar.update()
                        self.accumulative_meter.add("loss", global_loss.item())
                        self.accumulative_meter.add("chosen_rewards", chosen_rewards.to(torch.float16).mean().item())
                        self.accumulative_meter.add(
                            "rejected_rewards", rejected_rewards.to(torch.float16).mean().item()
                        )
                        if self.writer is not None:
                            self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), i)
                            self.writer.add_scalar(
                                "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), i
                            )
                            self.writer.add_scalar(
                                "train/rejected_rewards",
                                self.accumulative_meter.get("rejected_rewards"),
                                i,
                            )
                            self.writer.add_scalar(
                                "train/margin",
                                self.accumulative_meter.get("chosen_rewards")
                                - self.accumulative_meter.get("rejected_rewards"),
                                i,
                            )
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.actor_scheduler.step()
        else:
            self.accumulative_meter.reset()
            step_bar = trange(
                len(self.train_dataloader) // self.accumulation_steps,
@@ -179,7 +322,7 @@ class DPOTrainer(SLTrainer):
                    logprob_ref_chosen = None
                    logprob_ref_reject = None
                losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
                loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
                    logprob_actor_chosen,
                    logprob_actor_reject,
                    logprob_ref_chosen if logprob_ref_chosen is not None else None,
@@ -189,15 +332,7 @@ class DPOTrainer(SLTrainer):
                )
                reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
                # DPO Loss
                loss = losses.mean()
                self.booster.backward(loss=loss, optimizer=self.optimizer)
                if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.actor_scheduler.step()
                # sync
                loss_mean = all_reduce_mean(tensor=loss)
                chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
@@ -208,10 +343,20 @@ class DPOTrainer(SLTrainer):
                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
                self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
                if i % self.accumulation_steps == self.accumulation_steps - 1:
                    self.num_train_step += 1
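                # The optimizer step is now deferred until `accumulation_steps` micro-batches have
                # accumulated gradients, instead of firing right after the backward pass: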
                if (i + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.actor_scheduler.step()
                    step_bar.set_postfix(
                        {
                            "train/loss": self.accumulative_meter.get("loss"),
                            "train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
                            "train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
                            "train/accuracy": self.accumulative_meter.get("accuracy"),
                        }
                    )
                    step_bar.update()
                    # logging
                    if self.writer and is_rank_0():
                        self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
                        self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
@@ -225,7 +370,8 @@ class DPOTrainer(SLTrainer):
                        )
                        self.writer.add_scalar(
                            "train/margin",
                            self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
                            self.accumulative_meter.get("chosen_rewards")
                            - self.accumulative_meter.get("rejected_rewards"),
                            self.num_train_step,
                        )
                        self.writer.add_scalar(
@@ -233,9 +379,10 @@ class DPOTrainer(SLTrainer):
                            self.accumulative_meter.get("accuracy"),
                            self.num_train_step,
                        )
                    self.num_train_step += 1
                    self.accumulative_meter.reset()
                    if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
                    if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
                        # save checkpoint
                        self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
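                        # The checkpoint step is now recorded as self.num_train_step (the number of
                        # optimizer steps) rather than the raw batch index i + 1.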
                        save_checkpoint(
@@ -245,7 +392,7 @@ class DPOTrainer(SLTrainer):
                            optimizer=self.optimizer,
                            lr_scheduler=self.actor_scheduler,
                            epoch=epoch,
                            step=i + 1,
                            step=self.num_train_step,
                            batch_size=batch_size,
                            coordinator=self.coordinator,
                        )
@@ -265,16 +412,141 @@ class DPOTrainer(SLTrainer):
            return
        self.model.eval()
        self.ref_model.eval()
        self.accumulative_meter.reset()
        self.coordinator.print_on_master("\nStart evaluation...")
        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
            step_bar = tqdm(
                range(len(self.eval_dataloader)),
                desc="Step",
                disable=not (dist.get_rank() == dist.get_world_size() - 1),
            )
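            # The pipeline-parallel evaluation path mirrors the training path above, but runs
            # under torch.no_grad() and performs no optimizer updates.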
            with torch.no_grad():
                for _, batch in enumerate(self.eval_dataloader):
                    batch = to_device(batch, self.device)
                    (
                        chosen_input_ids,
                        chosen_attention_mask,
                        chosen_loss_mask,
                        reject_input_ids,
                        reject_attention_mask,
                        reject_loss_mask,
                    ) = (
                        batch["chosen_input_ids"],
                        batch["chosen_attention_mask"],
                        batch["chosen_loss_mask"],
                        batch["reject_input_ids"],
                        batch["reject_attention_mask"],
                        batch["reject_loss_mask"],
                    )
                    batch_size = chosen_input_ids.size()[0]
                    # Calculate logits from reference model.
                    if self.ref_model is not None:
                        self.ref_model.eval()
                        with torch.no_grad():
                            ref_all_logits = self.ref_model(
                                input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                                attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
                            )["logits"]
                            ref_chosen_logits = ref_all_logits[:batch_size]
                            ref_reject_logits = ref_all_logits[batch_size:]
                            logprob_ref_chosen = calc_masked_log_probs(
                                ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
                            )
                            logprob_ref_reject = calc_masked_log_probs(
                                ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
                            )
                    else:
                        logprob_ref_chosen = None
                        logprob_ref_reject = None
                    # Merge chosen and reject
                    inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
                    attention_mask = torch.stack(
                        [item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
                    )
                    loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
                    logprob_ref = torch.stack(
                        [item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup]
                    )
                    data_iter = iter(
                        [
                            {
                                "input_ids": inputs_ids,
                                "attention_mask": attention_mask,
                                "loss_mask": loss_mask,
                                "logprob_ref": logprob_ref,
                            }
                        ]
                    )
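                    # Same interleave-and-slice pattern as in the pipeline-parallel training path.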
                    rewards = []
                    def _criterion(outputs, inputs):
                        loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
                            calc_masked_log_probs(
                                outputs["logits"][0::2],
                                inputs["input_ids"][0::2],
                                inputs["loss_mask"][0::2][:, 1:],
                                self.length_normalization,
                            ),
                            calc_masked_log_probs(
                                outputs["logits"][1::2],
                                inputs["input_ids"][1::2],
                                inputs["loss_mask"][1::2][:, 1:],
                                self.length_normalization,
                            ),
                            inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
                            inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
                            inputs["loss_mask"][0::2][:, 1:],
                            inputs["loss_mask"][1::2][:, 1:],
                        )
                        rewards.append(chosen_rewards)
                        rewards.append(rejected_rewards)
                        return loss
                    outputs = self.booster.execute_pipeline(
                        data_iter,
                        self.model,
                        criterion=_criterion,
                        optimizer=self.optimizer,
                        return_loss=True,
                    )
                    loss = outputs["loss"]
                    if self.booster.plugin.stage_manager.is_last_stage():
                        chosen_rewards, rejected_rewards = rewards[0], rewards[1]
                        global_loss = all_reduce_mean(loss, self.plugin)
                        chosen_rewards_mean = all_reduce_mean(chosen_rewards, self.plugin)
                        rejected_rewards_mean = all_reduce_mean(rejected_rewards, self.plugin)
                        if dist.get_rank() == dist.get_world_size() - 1:
                            step_bar.set_postfix(
                                {
                                    "eval/loss": global_loss.item(),
                                    "eval/lr": self.actor_scheduler.get_last_lr()[0],
                                    "eval/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
                                    "eval/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
                                }
                            )
                            self.accumulative_meter.add(
                                "chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()
                            )
                            self.accumulative_meter.add(
                                "rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
                            )
                            self.accumulative_meter.add("loss", global_loss.to(torch.float16).item())
                            step_bar.update()
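            # After the evaluation loop, the accumulated metrics live on the last pipeline stage,
            # so the summary is assembled there and printed once by the last rank.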
            if self.booster.plugin.stage_manager.is_last_stage():
                msg = "\nEvaluation Result:\n"
                for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
                    msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
                if dist.get_rank() == dist.get_world_size() - 1:
                    print(msg)
        else:
            step_bar = trange(
                len(self.eval_dataloader),
                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
                disable=not is_rank_0(),
            )
            self.accumulative_meter.reset()
            with torch.no_grad():
                for i, batch in enumerate(self.eval_dataloader):
                    batch = to_device(batch, self.device)
@@ -313,9 +585,6 @@ class DPOTrainer(SLTrainer):
                    logprob_actor_reject = calc_masked_log_probs(
                        actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
                    )
                    self.ref_model.eval()
                    ref_all_logits = self.ref_model(
                        torch.cat([chosen_input_ids, reject_input_ids]),
                        torch.cat([chosen_attention_mask, reject_attention_mask]),
@@ -344,18 +613,29 @@ class DPOTrainer(SLTrainer):
                    rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
                    reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
                    self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
                    self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
                    self.accumulative_meter.add(
                        "rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
                    )
                    self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
                    self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
                    self.accumulative_meter.add(
                        "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
                    )
                    step_bar.set_postfix(
                        {
                            "eval/loss": self.accumulative_meter.get("loss"),
                            "eval/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
                            "eval/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
                            "eval/accuracy": self.accumulative_meter.get("accuracy"),
                        }
                    )
                    step_bar.update()
            msg = "Evaluation Result:\n"
            msg = "\nEvaluation Result:\n"
            for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
                msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
            self.coordinator.print_on_master(msg)
            if self.save_dir is not None:
                os.makedirs(self.save_dir, exist_ok=True)
                with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                    f.write(msg)