[NFC] polish colossalai/nn/loss/loss_2p5d.py code style (#1553)

pull/1550/head
shenggan 2022-09-08 15:06:04 +08:00 committed by Frank Lee
parent bd2d789832
commit 8edb777cc2
1 changed files with 2 additions and 0 deletions

View File

@ -30,6 +30,7 @@ class CrossEntropyLoss2p5D(_Loss):
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
"""
def __init__(self, reduction=True, *args, **kwargs):
super().__init__()
assert_tesseract_initialization()
@ -127,6 +128,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss):
Args:
reduction (bool, optional): whether to average the loss, defaults to True.
"""
def __init__(self, reduction=True):
super().__init__()
self.reduction_mean = reduction