fix orpo cross entropy loss

pull/5901/head
YeAnbang 2024-07-15 02:12:05 +00:00
parent 115c4cc5a4
commit b3594d4d68
5 changed files with 19 additions and 15 deletions

View File

@ -10,7 +10,6 @@ from coati.models.loss import OddsRatioLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.nn import CrossEntropyLoss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
@ -63,7 +62,6 @@ class ORPOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.odds_ratio_loss_fn = OddsRatioLoss()
self.sft_loss_fn = CrossEntropyLoss()
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@ -136,6 +134,9 @@ class ORPOTrainer(SLTrainer):
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
)
torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32)
@ -144,13 +145,8 @@ class ORPOTrainer(SLTrainer):
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
label_chosen = chosen_input_ids[:, 1:].contiguous()
label_chosen_masked = (
label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
)
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
chosen_nll = actor_out["loss"]
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
)
@ -271,6 +267,9 @@ class ORPOTrainer(SLTrainer):
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
)
torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32)
@ -283,13 +282,7 @@ class ORPOTrainer(SLTrainer):
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
)
chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
label_chosen = chosen_input_ids[:, 1:].contiguous()
label_chosen_masked = (
label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
)
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
chosen_nll = actor_out["loss"]
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
)

View File

@ -189,6 +189,8 @@ def train(args):
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:

View File

@ -176,6 +176,8 @@ def train(args):
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:

View File

@ -16,10 +16,13 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
def train(args):
# check lora compatibility
@ -186,6 +189,8 @@ def train(args):
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
math.ceil(args.max_epochs * num_update_steps_per_epoch)

View File

@ -187,6 +187,8 @@ def train(args):
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"