fix orpo cross entropy loss

pull/5901/head
YeAnbang 2024-07-15 02:12:05 +00:00
parent 115c4cc5a4
commit b3594d4d68
5 changed files with 19 additions and 15 deletions

View File

@ -10,7 +10,6 @@ from coati.models.loss import OddsRatioLoss
from coati.models.utils import calc_masked_log_probs from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.nn import CrossEntropyLoss
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -63,7 +62,6 @@ class ORPOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.odds_ratio_loss_fn = OddsRatioLoss() self.odds_ratio_loss_fn = OddsRatioLoss()
self.sft_loss_fn = CrossEntropyLoss()
self.save_interval = save_interval self.save_interval = save_interval
self.coordinator = coordinator self.coordinator = coordinator
self.save_dir = save_dir self.save_dir = save_dir
@ -136,6 +134,9 @@ class ORPOTrainer(SLTrainer):
actor_out = self.model( actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]), input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
) )
torch.autograd.set_detect_anomaly(True) torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32) actor_all_logits = actor_out["logits"].to(torch.float32)
@ -144,13 +145,8 @@ class ORPOTrainer(SLTrainer):
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]) logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]) logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
label_chosen = chosen_input_ids[:, 1:].contiguous()
label_chosen_masked = (
label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
)
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100 # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16) chosen_nll = actor_out["loss"]
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn( odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:] logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
) )
@ -271,6 +267,9 @@ class ORPOTrainer(SLTrainer):
actor_out = self.model( actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]), input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
) )
torch.autograd.set_detect_anomaly(True) torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32) actor_all_logits = actor_out["logits"].to(torch.float32)
@ -283,13 +282,7 @@ class ORPOTrainer(SLTrainer):
logprob_actor_reject = calc_masked_log_probs( logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:] actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
) )
chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1)) chosen_nll = actor_out["loss"]
label_chosen = chosen_input_ids[:, 1:].contiguous()
label_chosen_masked = (
label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
)
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn( odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:] logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
) )

View File

@ -189,6 +189,8 @@ def train(args):
collate_fn=eval_data_collator, collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler, distributed_sampler_cls=StatefulDistributedSampler,
) )
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None: if args.warmup_steps is None:

View File

@ -176,6 +176,8 @@ def train(args):
collate_fn=eval_data_collator, collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler, distributed_sampler_cls=StatefulDistributedSampler,
) )
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None: if args.warmup_steps is None:

View File

@ -16,10 +16,13 @@ import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
def train(args): def train(args):
# check lora compatibility # check lora compatibility
@ -186,6 +189,8 @@ def train(args):
collate_fn=eval_data_collator, collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler, distributed_sampler_cls=StatefulDistributedSampler,
) )
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
math.ceil(args.max_epochs * num_update_steps_per_epoch) math.ceil(args.max_epochs * num_update_steps_per_epoch)

View File

@ -187,6 +187,8 @@ def train(args):
collate_fn=eval_data_collator, collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler, distributed_sampler_cls=StatefulDistributedSampler,
) )
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
coordinator.print_on_master( coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"