ColossalAI/colossalai/nn/loss/loss_3d.py

140 lines
5.6 KiB
Python

import torch
import torch.distributed as dist
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
    """
    Cross entropy loss for 3D parallelism

    :param reduction: whether to average the loss, defaults to True
    :param args: Extra positional arguments forwarded to ``cross_entropy``
        (placed after ``logits`` and ``targets``)
    :param kwargs: Extra keyword arguments forwarded to ``cross_entropy``;
        must not contain ``reduction``, which is always passed as ``'none'``

    :type reduction: bool, optional
    """

    def __init__(self, reduction=True, *args, **kwargs):
        super().__init__()
        # Parallel modes are looked up once, at construction time.
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.reduction_mean = reduction
        self.loss_args = args
        self.loss_kwargs = kwargs

    def forward(self, logits, targets):
        """Calculate loss between logits and targets

        :param logits: Output logits of model
        :param targets: True targets from data
        :return: The reduced loss when ``reduction`` is enabled, otherwise the
            unreduced output of ``cross_entropy`` for the local targets
        """
        # Split the targets along dim 0, first over the weight group and then
        # over the input group.
        # NOTE(review): presumably this keeps only the slice matching this
        # rank's logits -- confirm against split_tensor_3d.
        local_targets = targets
        for parallel_mode in (self.weight_parallel_mode, self.input_parallel_mode):
            local_targets = split_tensor_3d(local_targets, 0, parallel_mode)

        per_sample_loss = cross_entropy(logits,
                                        local_targets,
                                        *self.loss_args,
                                        reduction='none',
                                        **self.loss_kwargs)
        if not self.reduction_mean:
            return per_sample_loss

        # Local mean first, then a reduction over the input and weight groups.
        # NOTE(review): the trailing True presumably requests a mean across
        # ranks -- confirm against reduce_by_batch_3d.
        return reduce_by_batch_3d(per_sample_loss.mean(), self.input_parallel_mode, self.weight_parallel_mode, True)
class _VocabParallelCrossEntropy3D(torch.autograd.Function):
    """Cross entropy for logits whose class dimension is partitioned across ranks.

    Adapted from megatron.mpu.cross_entropy. For every sample ``i``::

        loss[i] = -logits[i][targets[i]] + log(sum(exp(logits[i])))

    The max, the target logit and the sum of exponentials are each combined
    across the ranks of ``output_parallel_mode`` with an all-reduce.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, targets, output_parallel_mode):
        """Compute the per-sample loss.

        :param ctx: Autograd context; receives the tensors needed by ``backward``
        :param logits: Local partition of the logits
        :param targets: Target class indices, indexed in the full (unpartitioned) class range
        :param output_parallel_mode: Parallel mode whose group holds the class partitions
        :return: Per-sample loss
        """
        # logits: [b/q^2, c/q]
        # targets: [b/q^2]
        # loss: [b/q^2]
        group = gpc.get_group(output_parallel_mode)

        # Global per-sample maximum, subtracted before exponentiating.
        logits_max = torch.max(logits, dim=-1)[0]
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
        logits = logits - logits_max.unsqueeze(dim=-1)

        # Inclusive range of class indices held by this rank.
        vocab_size_per_partition = logits.size()[-1]
        rank = gpc.get_local_rank(output_parallel_mode)
        vocab_start = rank * vocab_size_per_partition
        vocab_end = (rank + 1) * vocab_size_per_partition - 1

        # Targets outside the local range contribute 0 here; index 0 is only a
        # valid placeholder so the gather below cannot go out of bounds.
        target_mask = (targets < vocab_start) | (targets > vocab_end)
        masked_target = targets - vocab_start
        masked_target[target_mask] = 0

        # Build the row index on the same device as the tensor it indexes.
        arange_1d = torch.arange(start=0, end=logits.size()[0], device=logits.device)
        # Advanced indexing returns a fresh tensor, so in-place edits are safe.
        predicted_logits = logits[arange_1d, masked_target]
        predicted_logits[target_mask] = 0.
        # Masked entries are 0, so the all-reduce yields the target logit from
        # whichever rank holds it.
        dist.all_reduce(predicted_logits, group=group)

        # Loss = log(sum(exp(logits))) - predicted-logit.
        exp_logits = torch.exp(logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        dist.all_reduce(sum_exp_logits, group=group)
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Turn exp_logits into the softmax in place and keep it for backward.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target)
        return loss

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad):
        """Compute the gradient with respect to ``logits``.

        NOTE: the saved softmax tensor is modified in place and returned as the
        gradient, so running backward a second time over the same graph would
        start from already-modified values.

        :param ctx: Autograd context populated by ``forward``
        :param output_grad: Gradient of the per-sample loss
        :return: One entry per ``forward`` input: the gradient for ``logits``,
            then ``None`` for ``targets`` and ``output_parallel_mode``
        """
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target = ctx.saved_tensors

        # All the inputs have softmax as their gradient.
        input_grad = softmax

        # For simplicity, work with the 2D gradient (a view of input_grad).
        partition_vocab_size = softmax.size()[-1]
        grad_2d = input_grad.view(-1, partition_vocab_size)

        # Subtract 1 at the target class, but only where the target lives on
        # this rank (target_mask is True for targets held elsewhere).
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())

        # Chain rule: scale each row by its upstream gradient.
        input_grad.mul_(output_grad.unsqueeze(dim=-1))

        # Exactly one gradient per forward input.
        return input_grad, None, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss3D(_Loss):
    """
    Vocab parallel cross entropy loss for 3D parallelism

    :param reduction: whether to average the loss, defaults to True

    :type reduction: bool, optional
    """

    def __init__(self, reduction=True):
        super().__init__()
        # Parallel modes are looked up once, at construction time.
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
        self.reduction_mean = reduction

    def forward(self, logits, targets):
        """Calculate loss between logits and targets

        :param logits: Output logits of model
        :param targets: True targets from data
        :return: The reduced loss when ``reduction`` is enabled, otherwise the
            per-sample loss for the local targets
        """
        # Split the targets along dim 0, first over the weight group and then
        # over the input group.
        # NOTE(review): presumably this keeps only the slice matching this
        # rank's logits -- confirm against split_tensor_3d.
        local_targets = targets
        for parallel_mode in (self.weight_parallel_mode, self.input_parallel_mode):
            local_targets = split_tensor_3d(local_targets, 0, parallel_mode)

        per_sample_loss = _VocabParallelCrossEntropy3D.apply(logits, local_targets, self.output_parallel_mode)
        if not self.reduction_mean:
            return per_sample_loss

        # Local mean first, then a reduction over the input and weight groups.
        # NOTE(review): the trailing True presumably requests a mean across
        # ranks -- confirm against reduce_by_batch_3d.
        return reduce_by_batch_3d(per_sample_loss.mean(), self.input_parallel_mode, self.weight_parallel_mode, True)