mirror of https://github.com/hpcaitech/ColossalAI
fix orpo cross entropy loss
parent 115c4cc5a4
commit b3594d4d68
@@ -10,7 +10,6 @@ from coati.models.loss import OddsRatioLoss
 from coati.models.utils import calc_masked_log_probs
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
-from torch.nn import CrossEntropyLoss
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
@@ -63,7 +62,6 @@ class ORPOTrainer(SLTrainer):
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.odds_ratio_loss_fn = OddsRatioLoss()
-        self.sft_loss_fn = CrossEntropyLoss()
         self.save_interval = save_interval
         self.coordinator = coordinator
         self.save_dir = save_dir
@@ -136,6 +134,9 @@ class ORPOTrainer(SLTrainer):
             actor_out = self.model(
                 input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                 attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+                labels=torch.cat(
+                    [chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
+                ),
             )
             torch.autograd.set_detect_anomaly(True)
             actor_all_logits = actor_out["logits"].to(torch.float32)
@@ -144,13 +145,8 @@ class ORPOTrainer(SLTrainer):
             logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])

             logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
-            chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
-            label_chosen = chosen_input_ids[:, 1:].contiguous()
-            label_chosen_masked = (
-                label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
-            )
             # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
-            chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
+            chosen_nll = actor_out["loss"]
             odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
                 logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
             )
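Note on the change above: the manual CrossEntropyLoss over shifted logits is dropped in favour of the loss the model already computes when `labels` are passed, with the rejected half of the batch set to -100 so it is ignored. A minimal, self-contained sketch of that mechanism follows; the toy tensors and the gpt2 checkpoint are assumptions for illustration, not the trainer's code.

# Minimal sketch, not ColossalAI code: -100 labels make a Hugging Face-style
# causal LM's built-in cross-entropy ignore the rejected half of the batch.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # assumed example model
vocab_size = model.config.vocab_size

chosen = torch.randint(0, vocab_size, (2, 8))  # toy "chosen" token ids
reject = torch.randint(0, vocab_size, (2, 8))  # toy "rejected" token ids

input_ids = torch.cat([chosen, reject])
# Rejected rows get label -100, the ignore_index of torch.nn.CrossEntropyLoss,
# so only the chosen rows contribute to the returned loss.
labels = torch.cat([chosen, torch.ones_like(reject) * -100])

out = model(input_ids=input_ids, labels=labels)

# Cross-check against a manual shifted cross-entropy over the chosen rows only.
logits = out.logits[:2, :-1]   # chosen rows, drop the last position
targets = chosen[:, 1:]        # next-token targets
manual = torch.nn.functional.cross_entropy(
    logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
)
print(out.loss, manual)        # the two values should agree up to float precision

One visible difference from the removed code: the labels now cover every token of the chosen sequences, whereas the old path masked positions where chosen_loss_mask was zero before applying the cross-entropy.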
@@ -271,6 +267,9 @@ class ORPOTrainer(SLTrainer):
                 actor_out = self.model(
                     input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                     attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+                    labels=torch.cat(
+                        [chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
+                    ),
                 )
                 torch.autograd.set_detect_anomaly(True)
                 actor_all_logits = actor_out["logits"].to(torch.float32)
@@ -283,13 +282,7 @@ class ORPOTrainer(SLTrainer):
                 logprob_actor_reject = calc_masked_log_probs(
                     actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
                 )
-                chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
-                label_chosen = chosen_input_ids[:, 1:].contiguous()
-                label_chosen_masked = (
-                    label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
-                )
-                # label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
-                chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
+                chosen_nll = actor_out["loss"]
                 odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
                     logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
                 )
@@ -189,6 +189,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")

     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     if args.warmup_steps is None:
@@ -176,6 +176,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")

     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     if args.warmup_steps is None:
@@ -16,10 +16,13 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.shardformer.policies.auto_policy import get_autopolicy

+logger = get_dist_logger()
+

 def train(args):
     # check lora compatibility
@@ -186,6 +189,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")

     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     math.ceil(args.max_epochs * num_update_steps_per_epoch)
@@ -187,6 +187,8 @@ def train(args):
             collate_fn=eval_data_collator,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
+    else:
+        logger.warning("No evaluation dataset is provided, skip evaluation")

     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
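The training-script hunks above all add the same guard: when no evaluation dataset is configured, the script logs a warning through the distributed logger and continues instead of building an eval dataloader. A minimal sketch of that pattern follows; the function and argument names are assumptions, only get_dist_logger, the warning text, and the keyword arguments shown in the diff come from the real scripts.

# Minimal sketch with assumed names, not the scripts' exact code.
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

def maybe_build_eval_dataloader(args, plugin, eval_dataset, eval_data_collator):
    # Build the eval dataloader only when an eval dataset was provided.
    if eval_dataset is not None:
        return plugin.prepare_dataloader(
            dataset=eval_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=eval_data_collator,
        )
    # Otherwise warn once and skip evaluation entirely.
    logger.warning("No evaluation dataset is provided, skip evaluation")
    return None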