pull/6119/head
Tong Li 2024-11-14 08:53:26 +00:00
parent 9995119c28
commit 797a81a8e2
1 changed files with 21 additions and 0 deletions

View File

@ -280,3 +280,24 @@ class KTOLoss(nn.Module):
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
return losses, chosen_rewards, rejected_rewards, kl
class PRMLoss(nn.Module):
    """Cross-entropy loss computed only at reward-signal token positions.

    Positions whose label is one of ``reward_signal_id`` are selected, the
    logits at those positions are narrowed to the ``reward_signal_id``
    columns, and the labels are remapped to their index inside
    ``reward_signal_id`` so the task becomes a ``len(reward_signal_id)``-way
    classification.

    Args:
        reward_signal_id: Token ids that act as reward signals. Must be a
            non-empty list by the time ``forward`` is called.
    """

    def __init__(self, reward_signal_id: Optional[list[int]] = None):
        super().__init__()
        self.IGNORE_INDEX = -100
        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        self.reward_signal_id = reward_signal_id

    def forward(self, labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        """Return the mean cross-entropy over reward-signal positions.

        Args:
            labels: Integer token ids. Presumably shaped (batch, seq_len) --
                TODO confirm against the caller.
            logits: Scores whose leading dimensions match ``labels`` and whose
                last dimension is indexed by token id (presumably the
                vocabulary axis -- TODO confirm).

        Returns:
            A scalar loss tensor. When no position carries a reward-signal
            label, a zero that is still connected to ``logits`` in the
            autograd graph is returned.

        Raises:
            ValueError: If ``reward_signal_id`` is ``None`` or empty.
        """
        if self.reward_signal_id is None or len(self.reward_signal_id) == 0:
            raise ValueError("PRMLoss requires a non-empty `reward_signal_id` list of token ids.")

        # Build the id tensor directly on the labels' device (no CPU round trip).
        signal_ids = torch.tensor(self.reward_signal_id, dtype=torch.long, device=labels.device)

        # Keep only the positions labelled with one of the reward-signal tokens.
        loss_mask = torch.isin(labels, signal_ids)
        logits = logits[loss_mask]
        labels = labels[loss_mask]

        if labels.numel() == 0:
            # Nothing to score: return a zero that keeps the autograd graph
            # intact so a subsequent backward() call remains valid.
            return logits.sum() * 0.0

        # Restrict the class dimension to the reward-signal tokens only.
        logits = logits[..., signal_ids.to(logits.device)]

        # Vectorized remap of token id -> class index. Every remaining label is
        # guaranteed by the mask to equal some entry of `signal_ids`, so each
        # row of the comparison contains a match and argmax finds its index.
        labels = (labels.unsqueeze(-1) == signal_ids).int().argmax(dim=-1)

        return self.loss(logits, labels)