2023-06-01 08:21:02 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torch.autograd import Function
|
2023-06-16 07:00:26 +00:00
|
|
|
from torch.distributed import ProcessGroup
|
2023-06-01 08:21:02 +00:00
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
# Public names exported by a star-import of this module.
__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
|
2023-06-21 06:30:06 +00:00
|
|
|
|
2023-06-01 08:21:02 +00:00
|
|
|
|
|
|
|
class DistCrossEntropy(Function):
    r"""
    Cross entropy over logits whose last (vocabulary) dimension is sharded across a process group.

    Each rank only sees its own slice of the vocabulary, so the pieces that need the full
    vocabulary (the row maximum, the target logit and the softmax denominator) are combined
    with ``dist.all_reduce`` instead of gathering the logits.
    """

    @staticmethod
    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
        r"""
        Calculate the cross entropy loss before gather, the origin loss function is as follows:
            loss = -log(exp(x[class]) / sum(exp(x[i])))
        and can be rewritten as:
            loss = log(sum(exp(x[i]))) - x[class]

        To avoid overflow in ``sum(exp(x[i]))`` the global maximum of ``x`` is subtracted first.

        Args:
            vocab_logits (:class:`torch.Tensor`): Local shard of the logits; the last dimension
                is this rank's slice of the vocabulary, e.g. [batch_size, seq_len, vocab_size / world_size].
            target (:class:`torch.Tensor`): Global class indices with the shape of ``vocab_logits``
                minus its last dimension, e.g. [batch_size, seq_len].
            ignore_index (int): Target value that contributes neither to the loss nor to the gradient.
            process_group (ProcessGroup): Group over which the vocabulary is sharded; it is passed
                unchanged as ``group=`` to the ``torch.distributed`` calls.

        Returns:
            :class:`torch.Tensor`: Scalar loss, averaged over the targets that differ from
            ``ignore_index``. If every target is ignored the result is 0 / 0 (NaN).
        """
        # Global per-row maximum (all_reduce works in place on the fresh tensor from torch.max).
        logits_max = torch.max(vocab_logits, dim=-1)[0]
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)

        # Out-of-place subtraction: from here on `vocab_logits` is a private copy that is
        # safe to overwrite, the caller's tensor is left untouched.
        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)

        # This rank owns the global class ids in [down_threshold, up_threshold).
        partition_vocab_size = vocab_logits.size()[-1]
        rank = dist.get_rank(group=process_group)
        down_threshold = rank * partition_vocab_size
        up_threshold = down_threshold + partition_vocab_size

        # mask == True: the target lives on another rank or is ignored. Folding `ignored`
        # into the mask keeps a non-negative ignore_index (e.g. a pad token id) from being
        # treated as a real class on the rank that owns that id.
        ignored = target == ignore_index
        mask = (target < down_threshold) | (target >= up_threshold) | ignored
        masked_target = (target - down_threshold).masked_fill_(mask, 0)

        # Flatten to [num_tokens, partition_vocab_size] / [num_tokens].
        logits_2d = vocab_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)

        # Pick x[class] where this rank owns the class, 0 elsewhere, then sum across ranks
        # so every rank ends up with the true target logit.
        row_index = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
        pred_logits = logits_2d[row_index, masked_target_1d].view_as(target)
        pred_logits.masked_fill_(mask, 0.0)
        dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)

        # exp() is written into the shifted logits buffer to avoid another full-size allocation.
        exp_logits = vocab_logits
        torch.exp(vocab_logits, out=exp_logits)
        sum_exp_logits = torch.sum(exp_logits, dim=-1)
        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

        # loss = log(sum(exp(x[i]))) - x[class], averaged over the non-ignored targets.
        # The divisor counts targets, not non-zero losses, so a token whose loss happens
        # to be exactly 0 is still part of the mean.
        loss = torch.where(ignored, 0.0, torch.log(sum_exp_logits) - pred_logits)
        num_valid = torch.sum(~ignored)
        loss = torch.sum(loss).div_(num_valid)

        # Local slice of the softmax, reused as the gradient buffer in backward.
        # Ignored rows are zeroed so they receive no gradient.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        exp_logits.masked_fill_(ignored.unsqueeze(dim=-1), 0.0)
        ctx.save_for_backward(exp_logits, mask, masked_target_1d, num_valid)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        r"""
        Gradient of the mean loss w.r.t. the local logits shard:
            (softmax - one_hot(target)) * grad_output / num_valid

        The saved softmax buffer is modified in place and returned as the gradient, so the
        saved tensor no longer holds the softmax after this call.

        Returns:
            One entry per ``forward`` argument: the gradient for ``vocab_logits`` followed by
            ``None`` for ``target``, ``ignore_index`` and ``process_group``.
        """
        exp_logits, mask, masked_target_1d, num_valid = ctx.saved_tensors

        # Reuse the softmax buffer as the gradient; the 2-D view shares its storage.
        grad_logits = exp_logits
        partition_vocab_size = grad_logits.shape[-1]
        grad_logits_2d = grad_logits.view(-1, partition_vocab_size)

        # Subtract 1 at the target column, but only for targets owned by this rank
        # (mask == False); masked rows were redirected to column 0 and get 0 subtracted.
        update = (~mask).reshape(-1).to(dtype=grad_logits_2d.dtype)
        row_index = torch.arange(0, grad_logits_2d.shape[0], device=grad_logits_2d.device)
        grad_logits_2d[row_index, masked_target_1d] -= update

        # The forward pass averages over num_valid targets, so the gradient carries 1 / num_valid.
        grad_logits.mul_((grad_output / num_valid).unsqueeze(dim=-1))

        # autograd requires exactly one gradient per forward input (4 of them).
        return grad_logits, None, None, None
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def cross_entropy_1d(
    vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
) -> torch.Tensor:
    """
    Functional entry point for :class:`DistCrossEntropy`.

    Args:
        vocab_logits (:class:`torch.Tensor`): Logits forwarded as the first argument of
            ``DistCrossEntropy.apply``.
        labels (:class:`torch.Tensor`): Target class indices forwarded as ``target``.
        ignore_index (int): Target value to be ignored, defaults to -100.
        process_group (ProcessGroup): Process group forwarded unchanged, defaults to ``None``.

    Returns:
        :class:`torch.Tensor`: The loss returned by ``DistCrossEntropy.apply``.
    """
    loss = DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
    return loss
|