2022-11-18 06:04:49 +00:00
|
|
|
import torch
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
# Public API of this module (names exported by ``from <module> import *``).
__all__ = ["LossForPretraining"]
|
2022-11-18 06:04:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
class LossForPretraining(torch.nn.Module):
    """Masked-language-model loss used during pretraining.

    Wraps ``torch.nn.CrossEntropyLoss`` and applies it to the flattened
    prediction scores and flattened labels. Only the masked-LM term is
    computed; no next-sentence-prediction term is included.

    Args:
        vocab_size: Size of the output vocabulary. ``prediction_scores`` is
            flattened to shape ``(-1, vocab_size)`` before the loss is taken.
        ignore_index: Label value whose positions are excluded from the loss.
            Keyword-only; defaults to ``-1``, the value previously hard-coded.
    """

    def __init__(self, vocab_size, *, ignore_index=-1):
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.vocab_size = vocab_size

    def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None):
        """Return the masked-LM cross-entropy loss.

        Args:
            prediction_scores: Logits whose total element count is a multiple
                of ``vocab_size``; presumably shaped ``(..., vocab_size)``,
                e.g. ``(batch, seq, vocab_size)`` -- confirm against callers.
            masked_lm_labels: Integer target indices with one entry per row of
                the flattened ``prediction_scores``. Entries equal to the
                configured ``ignore_index`` do not contribute to the loss.
            next_sentence_labels: Accepted for call-site compatibility and
                ignored; no next-sentence loss is computed.

        Returns:
            Scalar tensor produced by ``CrossEntropyLoss`` with its default
            ``'mean'`` reduction over non-ignored positions (PyTorch yields
            NaN if every position is ignored).
        """
        # reshape rather than view: view raises on non-contiguous tensors
        # (e.g. permuted/transposed logits); reshape copies only when needed.
        flat_scores = prediction_scores.reshape(-1, self.vocab_size)
        flat_labels = masked_lm_labels.reshape(-1)
        return self.loss_fn(flat_scores, flat_labels)
|