mirror of https://github.com/hpcaitech/ColossalAI
parent
d09a79bad5
commit
f03bcb359b
|
@ -104,9 +104,9 @@ class DaliDataloader(DALIClassificationIterator):
|
|||
img = lam * img + (1 - lam) * img[idx, :]
|
||||
label_a, label_b = label, label[idx]
|
||||
lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
|
||||
label = (label_a, label_b, lam)
|
||||
label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
|
||||
else:
|
||||
label = (label, label, torch.ones(
|
||||
1, device=img.device, dtype=img.dtype))
|
||||
return (img,), label
|
||||
return (img,), (label,)
|
||||
label = {'targets_a': label, 'targets_b': label,
|
||||
'lam': torch.ones(1, device=img.device, dtype=img.dtype)}
|
||||
return img, label
|
||||
return img, label
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import torch.nn as nn
|
||||
from colossalai.registry import LOSSES
|
||||
import torch
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class MixupLoss(nn.Module):
|
||||
|
@ -7,6 +9,13 @@ class MixupLoss(nn.Module):
|
|||
super().__init__()
|
||||
self.loss_fn = loss_fn_cls()
|
||||
|
||||
def forward(self, inputs, targets_a, targets_b, lam):
    """Blend the wrapped loss over both label sets of a mixup batch.

    The three label parts are named parameters rather than a collected
    ``*args`` tuple, so the method can be called positionally with three
    extra arguments or with a label dict unpacked as keywords
    (``targets_a``, ``targets_b``, ``lam``).

    Args:
        inputs: Model outputs handed to ``self.loss_fn``.
        targets_a: First set of targets.
        targets_b: Second set of targets.
        lam: Weight applied to the ``targets_a`` loss; the ``targets_b``
            loss is weighted by ``1 - lam``.

    Returns:
        ``lam * loss_fn(inputs, targets_a) + (1 - lam) * loss_fn(inputs, targets_b)``.
    """
    loss_a = self.loss_fn(inputs, targets_a)
    loss_b = self.loss_fn(inputs, targets_b)
    return lam * loss_a + (1 - lam) * loss_b
|
||||
|
||||
|
||||
class MixupAccuracy(nn.Module):
    """Count correct top-1 predictions for a batch with mixup-style labels.

    When the labels arrive as a dict, predictions are compared against its
    ``'targets_a'`` entry only; any other entries are ignored.
    """

    def forward(self, logits, targets):
        """Return how many samples have their arg-max class equal to the target.

        Args:
            logits: Tensor of per-class scores; the class axis is the last one.
            targets: Either a dict holding the labels under ``'targets_a'``,
                or the label tensor itself.

        Returns:
            A zero-dimensional tensor holding the number of matching predictions.
        """
        # A bare label tensor is used as-is so the metric also works when the
        # labels are not wrapped in the mixup dict.
        if isinstance(targets, dict):
            targets = targets['targets_a']
        preds = torch.argmax(logits, dim=-1)
        return torch.sum(targets == preds)
|
||||
|
|
|
@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.nn.lr_scheduler import LinearWarmupLR
|
||||
from dataloader.imagenet_dali_dataloader import DaliDataloader
|
||||
from mixup import MixupLoss
|
||||
from mixup import MixupLoss, MixupAccuracy
|
||||
from timm.models import vit_base_patch16_224
|
||||
from myhooks import TotalBatchsizeHook
|
||||
|
||||
|
@ -62,7 +62,7 @@ def main():
|
|||
port=args.port,
|
||||
backend=args.backend
|
||||
)
|
||||
# launch from torch
|
||||
# launch from torch
|
||||
# colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
|
@ -96,7 +96,7 @@ def main():
|
|||
# build hooks
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.AccuracyHook(accuracy_func=Accuracy()),
|
||||
hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
|
||||
TotalBatchsizeHook(),
|
||||
|
|
Loading…
Reference in New Issue