mirror of https://github.com/hpcaitech/ColossalAI

fix orpo cross entropy loss

parent 115c4cc5a4
commit b3594d4d68

@@ -10,7 +10,6 @@ from coati.models.loss import OddsRatioLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.nn import CrossEntropyLoss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

@@ -63,7 +62,6 @@ class ORPOTrainer(SLTrainer):
        self.actor_scheduler = actor_lr_scheduler
        self.tokenizer = tokenizer
        self.odds_ratio_loss_fn = OddsRatioLoss()
        self.sft_loss_fn = CrossEntropyLoss()
        self.save_interval = save_interval
        self.coordinator = coordinator
        self.save_dir = save_dir

@@ -136,6 +134,9 @@ class ORPOTrainer(SLTrainer):
            actor_out = self.model(
                input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
                labels=torch.cat(
                    [chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
                ),
            )
            torch.autograd.set_detect_anomaly(True)
            actor_all_logits = actor_out["logits"].to(torch.float32)

@@ -144,13 +145,8 @@ class ORPOTrainer(SLTrainer):
            logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])

            logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
            chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
            label_chosen = chosen_input_ids[:, 1:].contiguous()
            label_chosen_masked = (
                label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
            )
            # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
            chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
            chosen_nll = actor_out["loss"]
            odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
                logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
            )

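Editor's note on the hunks above: per the commit message and the hunk counts, the hand-rolled CrossEntropyLoss path (chosen_logits / label_chosen_masked / sft_loss_fn) is replaced by the model's built-in loss. The chosen and rejected sequences are concatenated along the batch dimension, every rejected-token label is overwritten with -100, and an HF-style causal LM then computes its internal loss with CrossEntropyLoss(ignore_index=-100) after shifting, so actor_out["loss"] is the negative log-likelihood of the chosen half only. (Unlike the removed label_chosen_masked path, these labels do not re-apply chosen_loss_mask, so all chosen tokens contribute.) A minimal standalone sketch of that equivalence, using hypothetical shapes and random tensors rather than coati's actual code:

import torch
from torch.nn import CrossEntropyLoss

# Hypothetical shapes: 2 chosen + 2 rejected sequences of equal length 5, vocab size 11.
vocab_size = 11
chosen_input_ids = torch.randint(0, vocab_size, (2, 5))
reject_input_ids = torch.randint(0, vocab_size, (2, 5))
logits = torch.randn(4, 5, vocab_size)  # stand-in for the model output on the concatenated batch

# Labels as built in the diff: chosen ids kept, every rejected position set to -100.
labels = torch.cat([chosen_input_ids, torch.ones_like(reject_input_ids) * -100])

# What an HF-style causal LM does internally: shift, then cross-entropy with ignore_index=-100.
shift_logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[:, 1:].contiguous().view(-1)
builtin_loss = CrossEntropyLoss(ignore_index=-100)(shift_logits, shift_labels)

# The same NLL computed manually over the chosen half only.
chosen_logits = logits[:2, :-1, :].contiguous().view(-1, vocab_size)
chosen_labels = chosen_input_ids[:, 1:].contiguous().view(-1)
manual_nll = CrossEntropyLoss()(chosen_logits, chosen_labels)

assert torch.allclose(builtin_loss, manual_nll)
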
@@ -271,6 +267,9 @@ class ORPOTrainer(SLTrainer):
                actor_out = self.model(
                    input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                    attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
                    labels=torch.cat(
                        [chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
                    ),
                )
                torch.autograd.set_detect_anomaly(True)
                actor_all_logits = actor_out["logits"].to(torch.float32)

@@ -283,13 +282,7 @@ class ORPOTrainer(SLTrainer):
                logprob_actor_reject = calc_masked_log_probs(
                    actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
                )
                chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
                label_chosen = chosen_input_ids[:, 1:].contiguous()
                label_chosen_masked = (
                    label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
                )
                # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
                chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
                chosen_nll = actor_out["loss"]
                odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
                    logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
                )

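Editor's note: the evaluation hunk mirrors the training change, and in both places the chosen-sequence NLL is then combined with the odds-ratio preference term computed from the masked log-probabilities of the chosen and rejected responses. Below is a minimal sketch of the standard odds-ratio term from the ORPO paper, written as a free function over per-sequence summed log-probabilities; coati's OddsRatioLoss takes per-token log-probs and masks and may differ in reduction, so treat this as an assumed reference form, not the library's implementation:

import torch
import torch.nn.functional as F

def odds_ratio_term(logp_chosen_sum: torch.Tensor, logp_reject_sum: torch.Tensor) -> torch.Tensor:
    # Assumed standard ORPO form: odds(p) = p / (1 - p), so
    # log odds = log p - log(1 - p), with p = exp(summed log-prob).
    log_odds = (logp_chosen_sum - logp_reject_sum) - (
        torch.log1p(-torch.exp(logp_chosen_sum)) - torch.log1p(-torch.exp(logp_reject_sum))
    )
    # -log sigmoid(log odds_chosen - log odds_rejected), averaged over the batch.
    return -F.logsigmoid(log_odds).mean()

# Hypothetical usage (lam is an assumed weighting hyperparameter, not read from this diff):
# loss = chosen_nll + lam * odds_ratio_term(logprob_actor_chosen.sum(-1), logprob_actor_reject.sum(-1))
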
@@ -189,6 +189,8 @@ def train(args):
            collate_fn=eval_data_collator,
            distributed_sampler_cls=StatefulDistributedSampler,
        )
    else:
        logger.warning("No evaluation dataset is provided, skip evaluation")

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    if args.warmup_steps is None:

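Editor's note: the training scripts touched below derive the scheduler's step budget from the dataloader length, which is why num_update_steps_per_epoch and the warmup default sit next to the eval-dataloader hunks. A hedged sketch of that arithmetic with hypothetical numbers; the 0.025 warmup fraction is illustrative only and not taken from this diff:

import math

# Hypothetical values for illustration.
len_train_dataloader = 1000   # batches per epoch
accumulation_steps = 4
max_epochs = 3

num_update_steps_per_epoch = len_train_dataloader // accumulation_steps   # 250 optimizer steps per epoch
total_update_steps = math.ceil(max_epochs * num_update_steps_per_epoch)   # 750 optimizer steps in total

# Assumed default when --warmup_steps is not given: a small fraction of total steps.
warmup_steps = int(0.025 * total_update_steps)                             # 18
print(num_update_steps_per_epoch, total_update_steps, warmup_steps)
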
@@ -176,6 +176,8 @@ def train(args):
            collate_fn=eval_data_collator,
            distributed_sampler_cls=StatefulDistributedSampler,
        )
    else:
        logger.warning("No evaluation dataset is provided, skip evaluation")

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    if args.warmup_steps is None:

@@ -16,10 +16,13 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy

logger = get_dist_logger()


def train(args):
    # check lora compatibility

@@ -186,6 +189,8 @@ def train(args):
            collate_fn=eval_data_collator,
            distributed_sampler_cls=StatefulDistributedSampler,
        )
    else:
        logger.warning("No evaluation dataset is provided, skip evaluation")

    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    math.ceil(args.max_epochs * num_update_steps_per_epoch)

@@ -187,6 +187,8 @@ def train(args):
            collate_fn=eval_data_collator,
            distributed_sampler_cls=StatefulDistributedSampler,
        )
    else:
        logger.warning("No evaluation dataset is provided, skip evaluation")

    coordinator.print_on_master(
        f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"