mirror of https://github.com/hpcaitech/ColossalAI
13 lines
394 B
Python
13 lines
394 B
Python
import torch.nn as nn
|
|
from colossalai.registry import LOSSES
|
|
|
|
@LOSSES.register_module
class MixupLoss(nn.Module):
    """Blend one base loss evaluated against two sets of targets.

    The result is ``lam * loss(inputs, targets_a)
    + (1 - lam) * loss(inputs, targets_b)``.

    Args:
        loss_fn_cls: Callable invoked with no arguments to build the base
            loss; the object it returns is called as
            ``loss(inputs, targets)``.
    """

    def __init__(self, loss_fn_cls):
        super().__init__()
        # The factory is called without arguments, so any configuration of
        # the base loss has to be baked into ``loss_fn_cls`` itself.
        self.loss_fn = loss_fn_cls()

    def forward(self, inputs, *args):
        """Return the weighted combination of the two base-loss values.

        Args:
            inputs: Forwarded unchanged as the first argument of the base
                loss (presumably model predictions -- confirm with callers).
            *args: Exactly three positional values, in order
                ``(targets_a, targets_b, lam)``.

        Raises:
            ValueError: If ``args`` does not hold exactly three values
                (raised by the tuple unpacking).
        """
        first_targets, second_targets, mix_weight = args
        criterion = self.loss_fn
        # Keep the original evaluation order: the ``targets_a`` term is
        # computed before the ``targets_b`` term.
        first_term = mix_weight * criterion(inputs, first_targets)
        second_term = (1 - mix_weight) * criterion(inputs, second_targets)
        return first_term + second_term
|