2023-09-11 08:24:28 +00:00
|
|
|
from torch import nn
|
|
|
|
from torch.nn.modules.loss import *
|
|
|
|
from torch.nn.modules.loss import _Loss
|
|
|
|
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.legacy.global_variables import tensor_parallel_env as env
|
2023-09-11 08:24:28 +00:00
|
|
|
from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode
|
|
|
|
|
|
|
|
from .loss_1d import VocabParallelCrossEntropyLoss1D
|
|
|
|
from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
|
|
|
|
from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
|
|
|
|
from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
|
|
|
|
|
|
|
|
_parallel_cross_entropy = {
|
2023-09-19 06:20:26 +00:00
|
|
|
"2d": CrossEntropyLoss2D,
|
|
|
|
"2.5d": CrossEntropyLoss2p5D,
|
|
|
|
"3d": CrossEntropyLoss3D,
|
2023-09-11 08:24:28 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
_vocab_parallel_cross_entropy = {
|
2023-09-19 06:20:26 +00:00
|
|
|
"1d": VocabParallelCrossEntropyLoss1D,
|
|
|
|
"2d": VocabParallelCrossEntropyLoss2D,
|
|
|
|
"2.5d": VocabParallelCrossEntropyLoss2p5D,
|
|
|
|
"3d": VocabParallelCrossEntropyLoss3D,
|
2023-09-11 08:24:28 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropyLoss(_Loss):
|
|
|
|
def __init__(self, reduction: bool = True, *args, **kwargs):
|
|
|
|
super().__init__()
|
|
|
|
tensor_parallel = get_tensor_parallel_mode()
|
|
|
|
if tensor_parallel is not None and env.vocab_parallel:
|
|
|
|
self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
|
2023-09-19 06:20:26 +00:00
|
|
|
elif tensor_parallel is None or tensor_parallel == "1d":
|
|
|
|
reduction = "mean" if reduction else "none"
|
2023-09-11 08:24:28 +00:00
|
|
|
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
|
|
|
|
else:
|
|
|
|
self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
|
|
|
|
|
|
|
|
def forward(self, *args):
|
|
|
|
return self.loss(*args)
|