diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/nn/loss/loss_2p5d.py index ed58c13f8..f8e3324fc 100644 --- a/colossalai/nn/loss/loss_2p5d.py +++ b/colossalai/nn/loss/loss_2p5d.py @@ -30,6 +30,7 @@ class CrossEntropyLoss2p5D(_Loss): More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_. """ + def __init__(self, reduction=True, *args, **kwargs): super().__init__() assert_tesseract_initialization() @@ -127,6 +128,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss): Args: reduction (bool, optional): whether to average the loss, defaults to True. """ + def __init__(self, reduction=True): super().__init__() self.reduction_mean = reduction