2021-12-27 07:04:32 +00:00
|
|
|
from torch import nn
|
|
|
|
from torch.nn.modules.loss import *
|
|
|
|
from torch.nn.modules.loss import _Loss
|
2021-10-28 16:21:23 +00:00
|
|
|
|
2021-12-29 15:32:10 +00:00
|
|
|
from colossalai.nn.layer.utils import get_tensor_parallel_mode
|
2021-12-27 07:04:32 +00:00
|
|
|
from .loss_2d import CrossEntropyLoss2D
|
|
|
|
from .loss_2p5d import CrossEntropyLoss2p5D
|
|
|
|
from .loss_3d import CrossEntropyLoss3D
|
|
|
|
|
|
|
|
# Lookup table from a tensor-parallel mode name to the cross-entropy
# implementation written for that mode. Insertion order is '2d', '2.5d', '3d'.
_parallel_cross_entropy = dict([
    ('2d', CrossEntropyLoss2D),
    ('2.5d', CrossEntropyLoss2p5D),
    ('3d', CrossEntropyLoss3D),
])
|
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropyLoss(_Loss):
    """Cross-entropy loss that picks its implementation from the tensor-parallel mode.

    The mode is read once, at construction time, via ``get_tensor_parallel_mode()``:

    * no tensor parallelism or ``'1d'`` -> ``torch.nn.CrossEntropyLoss``;
    * ``'2d'`` / ``'2.5d'`` / ``'3d'`` -> the matching class registered in
      ``_parallel_cross_entropy``.

    Args:
        reduction: Truthy to reduce the loss, falsy to keep it unreduced. For the
            plain PyTorch loss this is translated to ``'mean'`` / ``'none'``; for
            the parallel implementations the value is forwarded unchanged.
        *args: Extra positional arguments forwarded to the selected loss class.
        **kwargs: Extra keyword arguments forwarded to the selected loss class.

    Raises:
        KeyError: If the tensor-parallel mode is neither a non-parallel mode nor
            one of the keys of ``_parallel_cross_entropy``.
    """

    # Modes served by the plain PyTorch loss. The string 'None' is what the
    # original code checks for; the Python ``None`` is accepted as well so that
    # an unset mode does not crash with ``KeyError: None``.
    # NOTE(review): confirm which of the two get_tensor_parallel_mode() returns.
    _NON_PARALLEL_MODES = (None, 'None', '1d')

    def __init__(self, reduction: bool = True, *args, **kwargs):
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel in self._NON_PARALLEL_MODES:
            # nn.CrossEntropyLoss takes a string reduction mode, not a bool.
            reduction = 'mean' if reduction else 'none'
            self.loss = nn.CrossEntropyLoss(*args, reduction=reduction, **kwargs)
        else:
            try:
                loss_cls = _parallel_cross_entropy[tensor_parallel]
            except KeyError as err:
                # Keep KeyError (what callers saw before) but say what went wrong.
                raise KeyError(
                    f'Unsupported tensor parallel mode {tensor_parallel!r} for '
                    f'CrossEntropyLoss; expected one of '
                    f'{list(self._NON_PARALLEL_MODES) + list(_parallel_cross_entropy)}'
                ) from err
            self.loss = loss_cls(*args, reduction=reduction, **kwargs)

    def forward(self, *args, **kwargs):
        """Compute the loss by delegating all arguments to the wrapped loss module."""
        return self.loss(*args, **kwargs)
|