#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc


class FlashGPTLMLoss(nn.Module):
    """
    Loss function for the flash GPT language model.

    Args:
        parallel_output (bool): whether the model output is still split across
            tensor-parallel ranks. If True, FlashCrossEntropyLoss is used over
            the TENSOR process group; otherwise an ordinary nn.CrossEntropyLoss
            is used.
        label_smoothing (float): label-smoothing factor forwarded to the
            underlying cross-entropy loss; None is treated as 0.
    """

    def __init__(self, parallel_output=True, label_smoothing=0):
        super().__init__()

        if label_smoothing is not None:
            if label_smoothing != 0:
                if gpc.is_rank_for_log():
                    print(f"use label_smoothing: {label_smoothing}")
        else:
            label_smoothing = 0
        self.label_smoothing = label_smoothing

        if parallel_output:
            # This choice of loss is tied to the gather_output setting of
            # VocabParallelClassifier1D: the logits stay split across
            # tensor-parallel ranks, so the loss must reduce over the TENSOR
            # process group.
            self.loss_fn = FlashCrossEntropyLoss(
                reduction="mean",
                inplace_backward=True,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                label_smoothing=label_smoothing,
            )
        else:
            # The output is already gathered in the model (gather_output is set),
            # so an ordinary cross entropy is sufficient.
            self.loss_fn = nn.CrossEntropyLoss(reduction="mean", label_smoothing=label_smoothing)

    def forward(self, *args):
        if len(args) == 3:
            # The extra residual input is accepted to match the prenorm model output.
            logits, _, labels = args
        elif len(args) == 2:
            # Postnorm model output: only logits and labels are passed.
            logits, labels = args
        else:
            raise RuntimeError(f"Unexpected number of criterion inputs: {len(args)}")
        shift_logits = logits.contiguous().view(-1, logits.size(-1))
        shift_labels = labels.contiguous().view(-1)
        # No special handling of ignore_index is needed here: labels set to -100
        # lie outside the valid class range and are ignored by the loss.
        loss = self.loss_fn(shift_logits, shift_labels)

        return loss
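

# Minimal usage sketch: demonstrates both call signatures accepted by forward().
# Assumes flash_attn is installed (required by the import at the top of this
# file); parallel_output=False keeps the example on the ordinary
# nn.CrossEntropyLoss path, so no parallel context (gpc) needs to be initialized.
# The tensor shapes below are illustrative, not taken from the original file.
if __name__ == "__main__":
    import torch

    vocab_size, num_tokens = 128, 16
    criterion = FlashGPTLMLoss(parallel_output=False, label_smoothing=0)

    logits = torch.randn(num_tokens, vocab_size)           # (tokens, vocab)
    labels = torch.randint(0, vocab_size, (num_tokens,))   # (tokens,)
    labels[-1] = -100                                       # padded position, ignored by the loss

    loss_postnorm = criterion(logits, labels)       # two inputs (postnorm)
    loss_prenorm = criterion(logits, None, labels)  # three inputs (prenorm: logits, residual, labels)
    print(loss_postnorm.item(), loss_prenorm.item())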