2022-02-14 03:15:02 +00:00
|
|
|
from colossalai.global_variables import tensor_parallel_env as env
|
|
|
|
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
2021-12-27 07:04:32 +00:00
|
|
|
from torch import nn
|
|
|
|
from torch.nn.modules.loss import *
|
|
|
|
from torch.nn.modules.loss import _Loss
|
2021-10-28 16:21:23 +00:00
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
from .loss_1d import VocabParallelCrossEntropyLoss1D
|
|
|
|
from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
|
|
|
|
from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
|
|
|
|
from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
|
2022-01-07 07:08:36 +00:00
|
|
|
from .loss_moe import MoeCrossEntropyLoss, MoeLoss
|
2021-12-27 07:04:32 +00:00
|
|
|
|
|
|
|
# Cross-entropy implementations for tensor-parallel modes where the vocabulary
# is NOT split across ranks, keyed by tensor-parallel mode. There is no '1d'
# entry: CrossEntropyLoss handles '1d' (and no parallelism) with the plain
# torch loss instead of consulting this table.
_parallel_cross_entropy = dict([
    ('2d', CrossEntropyLoss2D),
    ('2.5d', CrossEntropyLoss2p5D),
    ('3d', CrossEntropyLoss3D),
])
|
|
|
|
|
|
|
|
# Vocabulary-parallel cross-entropy implementations, keyed by tensor-parallel
# mode. CrossEntropyLoss consults this table whenever a tensor-parallel mode
# is active and ``env.vocab_parallel`` is set.
_vocab_parallel_cross_entropy = dict([
    ('1d', VocabParallelCrossEntropyLoss1D),
    ('2d', VocabParallelCrossEntropyLoss2D),
    ('2.5d', VocabParallelCrossEntropyLoss2p5D),
    ('3d', VocabParallelCrossEntropyLoss3D),
])
|
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropyLoss(_Loss):
    """Cross-entropy loss that dispatches on the active tensor-parallel mode.

    The concrete loss module is chosen once, at construction time, from
    ``get_tensor_parallel_mode()`` and ``env.vocab_parallel``:

    * a tensor-parallel mode is active and ``env.vocab_parallel`` is true:
      the entry for that mode in ``_vocab_parallel_cross_entropy``;
    * no tensor parallelism, or '1d' without vocab parallelism:
      ``torch.nn.CrossEntropyLoss``;
    * any other mode: the entry for that mode in ``_parallel_cross_entropy``.

    Args:
        reduction: For the parallel losses, forwarded unchanged. For the
            plain torch loss, a bool is mapped to ``'mean'`` (True) or
            ``'none'`` (False); a string (e.g. ``'sum'``) is passed through
            as-is instead of being collapsed by its truthiness.
        *args: Extra positional arguments for the selected loss constructor.
        **kwargs: Extra keyword arguments for the selected loss constructor.

    Raises:
        KeyError: If the active tensor-parallel mode has no matching entry in
            the dispatch table consulted for it.
    """

    def __init__(self, reduction: bool = True, *args, **kwargs):
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel is not None and env.vocab_parallel:
            loss_cls = self._select_loss(_vocab_parallel_cross_entropy, tensor_parallel, 'Vocab-parallel')
            self.loss = loss_cls(*args, reduction=reduction, **kwargs)
        elif tensor_parallel is None or tensor_parallel == '1d':
            # torch expects a reduction string; translate the bool flag but
            # let explicit torch-style strings ('none', 'sum', ...) through,
            # since e.g. 'none' is truthy and would otherwise become 'mean'.
            if not isinstance(reduction, str):
                reduction = 'mean' if reduction else 'none'
            self.loss = nn.CrossEntropyLoss(*args, reduction=reduction, **kwargs)
        else:
            loss_cls = self._select_loss(_parallel_cross_entropy, tensor_parallel, 'Tensor-parallel')
            self.loss = loss_cls(*args, reduction=reduction, **kwargs)

    @staticmethod
    def _select_loss(table, mode, kind):
        """Return ``table[mode]``, raising a descriptive KeyError if it is missing."""
        try:
            return table[mode]
        except KeyError:
            raise KeyError(f'{kind} cross entropy is not supported for tensor parallel mode {mode!r}; '
                           f'supported modes: {sorted(table)}') from None

    def forward(self, *args, **kwargs):
        """Compute the loss by delegating to the module selected in ``__init__``."""
        return self.loss(*args, **kwargs)
|