ColossalAI/examples/vit_b16_imagenet_data_parallel/mixup.py

13 lines
394 B
Python

import torch.nn as nn
from colossalai.registry import LOSSES
@LOSSES.register_module
class MixupLoss(nn.Module):
    """Weighted combination of a base loss over two target sets.

    Intended for mixup-style training (per the class/file name): the returned
    value is

        lam * loss_fn(inputs, targets_a) + (1 - lam) * loss_fn(inputs, targets_b)

    where ``loss_fn`` is built once from ``loss_fn_cls``.

    Args:
        loss_fn_cls: Callable that constructs the base loss, e.g.
            ``nn.CrossEntropyLoss``. The constructed object must be callable
            as ``loss_fn(inputs, targets)``.
        **loss_fn_kwargs: Optional keyword arguments forwarded to
            ``loss_fn_cls`` at construction time (e.g. ``reduction``).
            Omitting them reproduces the previous ``loss_fn_cls()`` behavior.
    """

    def __init__(self, loss_fn_cls, **loss_fn_kwargs):
        super().__init__()
        # If loss_fn_cls builds an nn.Module, assigning it as an attribute
        # registers it as a submodule of this loss.
        self.loss_fn = loss_fn_cls(**loss_fn_kwargs)

    def forward(self, inputs, *args):
        """Return the ``lam``-weighted sum of the base loss on both target sets.

        Args:
            inputs: Model outputs, passed unchanged to the base loss for both
                terms.
            *args: Exactly three values ``(targets_a, targets_b, lam)``.
                ``lam`` weights the ``targets_a`` term and ``1 - lam`` weights
                the ``targets_b`` term.

        Returns:
            ``lam * loss_fn(inputs, targets_a) + (1 - lam) * loss_fn(inputs, targets_b)``.

        Raises:
            ValueError: If ``args`` does not contain exactly three values.
        """
        # Check arity explicitly so a misconfigured caller gets a descriptive
        # error instead of a bare tuple-unpacking failure.
        if len(args) != 3:
            raise ValueError(
                "MixupLoss.forward expects (targets_a, targets_b, lam) after "
                f"inputs, got {len(args)} extra argument(s)"
            )
        targets_a, targets_b, lam = args
        return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)