From b3594d4d68458d5d7add9f323ead33c470a1dac1 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Mon, 15 Jul 2024 02:12:05 +0000
Subject: [PATCH] fix orpo cross entropy loss

---
 .../ColossalChat/coati/trainer/orpo.py        | 23 +++++++------------
 .../examples/training_scripts/train_dpo.py    |  2 ++
 .../examples/training_scripts/train_orpo.py   |  2 ++
 .../examples/training_scripts/train_rm.py     |  5 ++++
 .../examples/training_scripts/train_sft.py    |  2 ++
 5 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py
index 3a751690d..495bb332b 100644
--- a/applications/ColossalChat/coati/trainer/orpo.py
+++ b/applications/ColossalChat/coati/trainer/orpo.py
@@ -10,7 +10,6 @@ from coati.models.loss import OddsRatioLoss
 from coati.models.utils import calc_masked_log_probs
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
-from torch.nn import CrossEntropyLoss
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
@@ -63,7 +62,6 @@ class ORPOTrainer(SLTrainer):
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.odds_ratio_loss_fn = OddsRatioLoss()
-        self.sft_loss_fn = CrossEntropyLoss()
         self.save_interval = save_interval
         self.coordinator = coordinator
         self.save_dir = save_dir
@@ -136,6 +134,9 @@ class ORPOTrainer(SLTrainer):
             actor_out = self.model(
                 input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                 attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+                labels=torch.cat(
+                    [chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
+                ),
             )
             torch.autograd.set_detect_anomaly(True)
             actor_all_logits = actor_out["logits"].to(torch.float32)
@@ -144,13 +145,8 @@ class ORPOTrainer(SLTrainer):
 
             logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
 
             logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
-            chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
-            label_chosen = chosen_input_ids[:, 1:].contiguous()
-            label_chosen_masked = (
-                label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
-            )  # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
-            chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
+            chosen_nll = actor_out["loss"]
             odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
                 logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
             )
@@ -271,6 +267,9 @@ class ORPOTrainer(SLTrainer):
                 actor_out = self.model(
                     input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                     attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+                    labels=torch.cat(
+                        [chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
+                    ),
                 )
                 torch.autograd.set_detect_anomaly(True)
                 actor_all_logits = actor_out["logits"].to(torch.float32)
@@ -283,13 +282,7 @@ class ORPOTrainer(SLTrainer):
                 logprob_actor_chosen = calc_masked_log_probs(
                     actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
                 )
                 logprob_actor_reject = calc_masked_log_probs(
                     actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
                 )
-                chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
-                label_chosen = chosen_input_ids[:, 1:].contiguous()
-                label_chosen_masked = (
-                    label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
-                )
-                # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
-                chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
+                chosen_nll = actor_out["loss"]
                 odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
                     logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
                 )
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py
index 2765a39cb..06d7133ca 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -189,6 +189,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")
 
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     if args.warmup_steps is None:
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py
index 4451e9c60..886aa39dd 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py
@@ -176,6 +176,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")
 
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     if args.warmup_steps is None:
diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py
index 978b936dc..f8e894e7e 100755
--- a/applications/ColossalChat/examples/training_scripts/train_rm.py
+++ b/applications/ColossalChat/examples/training_scripts/train_rm.py
@@ -16,10 +16,13 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
 
+logger = get_dist_logger()
+
 
 def train(args):
     # check lora compatibility
@@ -186,6 +189,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")
 
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     math.ceil(args.max_epochs * num_update_steps_per_epoch)
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index ccd8a5868..fe1506559 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -187,6 +187,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")
 
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
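Editor's note (not part of the patch): the core change drops the trainer's hand-rolled CrossEntropyLoss over manually shifted logits and instead relies on the Hugging Face causal-LM convention: when `labels` is passed, the model shifts labels internally and ignores positions set to -100. Concatenating `chosen_input_ids` with an all -100 block for the rejected half therefore makes `actor_out["loss"]` the NLL over the chosen sequences only. Below is a minimal, self-contained sketch of that convention; the `gpt2` checkpoint and the two example strings are placeholders for illustration, not part of the ColossalChat code.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tok(
    ["a chosen response", "a rejected response"],
    return_tensors="pt",
    padding=True,
)
input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]

# Row 0 plays the role of the chosen sequence, row 1 the rejected one.
# Setting the rejected row (and any padding) to -100 removes it from the
# loss, mirroring the patch's
# torch.cat([chosen_input_ids, torch.ones_like(reject_input_ids) * -100]).
labels = input_ids.clone()
labels[1, :] = -100
labels[attention_mask == 0] = -100

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

# Manual equivalent of what the model computes internally: shift by one
# token, then token-mean cross entropy with ignore_index=-100.
shift_logits = out.logits[:, :-1, :]
shift_labels = labels[:, 1:]
manual_nll = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
    ignore_index=-100,
)
assert torch.allclose(out.loss, manual_nll, atol=1e-5)
```

One behavioral difference visible in the diff itself: the removed code additionally masked positions where `chosen_loss_mask` is 0 (e.g. prompt tokens) before computing the loss, whereas the new `labels` tensor passes `chosen_input_ids` unmasked, so the built-in loss averages over all chosen tokens.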