mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/loss/loss_2p5d.py code style (#1553)
parent
bd2d789832
commit
8edb777cc2
|
@ -30,6 +30,7 @@ class CrossEntropyLoss2p5D(_Loss):
|
|||
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
|
||||
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, reduction=True, *args, **kwargs):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
|
@ -127,6 +128,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss):
|
|||
Args:
|
||||
reduction (bool, optional): whether to average the loss, defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, reduction=True):
|
||||
super().__init__()
|
||||
self.reduction_mean = reduction
|
||||
|
|
Loading…
Reference in New Issue