update vit example for new API (#98) (#99)

pull/114/head
ver217 2022-01-04 20:35:33 +08:00 committed by GitHub
parent d09a79bad5
commit f03bcb359b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 10 deletions

View File

@ -104,9 +104,9 @@ class DaliDataloader(DALIClassificationIterator):
img = lam * img + (1 - lam) * img[idx, :] img = lam * img + (1 - lam) * img[idx, :]
label_a, label_b = label, label[idx] label_a, label_b = label, label[idx]
lam = torch.tensor([lam], device=img.device, dtype=img.dtype) lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
label = (label_a, label_b, lam) label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
else: else:
label = (label, label, torch.ones( label = {'targets_a': label, 'targets_b': label,
1, device=img.device, dtype=img.dtype)) 'lam': torch.ones(1, device=img.device, dtype=img.dtype)}
return (img,), label return img, label
return (img,), (label,) return img, label

View File

@ -1,5 +1,7 @@
import torch.nn as nn import torch.nn as nn
from colossalai.registry import LOSSES from colossalai.registry import LOSSES
import torch
@LOSSES.register_module @LOSSES.register_module
class MixupLoss(nn.Module): class MixupLoss(nn.Module):
@ -7,6 +9,13 @@ class MixupLoss(nn.Module):
super().__init__() super().__init__()
self.loss_fn = loss_fn_cls() self.loss_fn = loss_fn_cls()
def forward(self, inputs, *args): def forward(self, inputs, targets_a, targets_b, lam):
targets_a, targets_b, lam = args
return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b) return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
class MixupAccuracy(nn.Module):
    """Count correct top-1 predictions for batches labelled in mixup format.

    Only the ``'targets_a'`` entry of a mixup label dict is compared with the
    predictions; ``'targets_b'`` and ``'lam'`` are not consulted.
    """

    def forward(self, logits, targets):
        """Return the number of predictions that match the primary targets.

        Args:
            logits: Score tensor; the predicted class is the argmax over the
                last dimension.
            targets: Either a dict holding the primary labels under
                ``'targets_a'`` or the label tensor itself.

        Returns:
            A scalar tensor holding the count of matching predictions
            (a count, not a ratio).
        """
        # Plain label tensors are passed through unchanged so the metric
        # also works when the labels are not wrapped in the mixup dict.
        if isinstance(targets, dict):
            targets = targets['targets_a']
        preds = torch.argmax(logits, dim=-1)
        return torch.sum(targets == preds)

View File

@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer, hooks from colossalai.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.lr_scheduler import LinearWarmupLR
from dataloader.imagenet_dali_dataloader import DaliDataloader from dataloader.imagenet_dali_dataloader import DaliDataloader
from mixup import MixupLoss from mixup import MixupLoss, MixupAccuracy
from timm.models import vit_base_patch16_224 from timm.models import vit_base_patch16_224
from myhooks import TotalBatchsizeHook from myhooks import TotalBatchsizeHook
@ -62,7 +62,7 @@ def main():
port=args.port, port=args.port,
backend=args.backend backend=args.backend
) )
# launch from torch # launch from torch
# colossalai.launch_from_torch(config=args.config) # colossalai.launch_from_torch(config=args.config)
# get logger # get logger
@ -96,7 +96,7 @@ def main():
# build hooks # build hooks
hook_list = [ hook_list = [
hooks.LossHook(), hooks.LossHook(),
hooks.AccuracyHook(accuracy_func=Accuracy()), hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
hooks.LogMetricByEpochHook(logger), hooks.LogMetricByEpochHook(logger),
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True), hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
TotalBatchsizeHook(), TotalBatchsizeHook(),