mirror of https://github.com/hpcaitech/ColossalAI
add loss
parent
9995119c28
commit
797a81a8e2
|
@ -280,3 +280,24 @@ class KTOLoss(nn.Module):
|
|||
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards, kl
|
||||
|
||||
|
||||
class PRMLoss(nn.Module):
    """Cross-entropy loss restricted to reward-signal tokens.

    Only positions whose label is one of ``reward_signal_id`` contribute to
    the loss. At those positions the logits are narrowed to the columns of
    the reward-signal tokens, and the labels are re-indexed to the matching
    column, so the loss is a ``len(reward_signal_id)``-way classification.

    Args:
        reward_signal_id: Token ids that act as reward signals. It may be
            left as ``None`` at construction time, but it must be a non-empty
            list before ``forward`` is called.
    """

    def __init__(self, reward_signal_id: Optional[list[int]] = None):
        super().__init__()
        self.IGNORE_INDEX = -100
        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        self.reward_signal_id = reward_signal_id

    def forward(self, labels: torch.Tensor, logits: torch.Tensor):
        """Compute the loss over the reward-signal positions.

        Args:
            labels: Integer token ids. Positions whose value is not in
                ``reward_signal_id`` are dropped.
            logits: Unnormalised scores.
                NOTE(review): assumed to have the shape of ``labels`` plus a
                trailing vocabulary dimension -- confirm against the caller.

        Returns:
            Scalar loss tensor. If no label is a reward-signal token, a zero
            that is still connected to ``logits`` in the autograd graph.

        Raises:
            ValueError: If ``reward_signal_id`` is ``None`` or empty.
        """
        if not self.reward_signal_id:
            raise ValueError("PRMLoss requires a non-empty `reward_signal_id` to compute the loss.")

        signal_ids = list(self.reward_signal_id)
        signal_tensor = torch.tensor(signal_ids, device=labels.device)

        # Keep only the positions labelled with a reward-signal token.
        loss_mask = torch.isin(labels, signal_tensor)
        logits = logits[loss_mask]
        labels = labels[loss_mask]

        if labels.numel() == 0:
            # Nothing to supervise: summing the empty selection yields 0 while
            # keeping the result attached to the autograd graph.
            return logits.sum()

        # Restrict the classification to the reward-signal columns.
        logits = logits[..., signal_ids]

        # Map each token id to its column index. With duplicate ids the last
        # occurrence wins, matching a plain dict built with enumerate().
        class_of_token = {token: i for i, token in enumerate(signal_ids)}
        class_per_column = torch.tensor([class_of_token[token] for token in signal_ids], device=labels.device)

        # Vectorised lookup: find a matching column for every label, then
        # translate it to the canonical class index for that token. Every
        # remaining label is guaranteed to match because of the mask above.
        matches = labels.unsqueeze(-1) == signal_tensor
        labels = class_per_column[matches.long().argmax(dim=-1)]

        return self.loss(logits, labels)
|
||||
|
|
Loading…
Reference in New Issue