update vit example for new API (#98) (#99)

pull/114/head
ver217 2022-01-04 20:35:33 +08:00 committed by GitHub
parent d09a79bad5
commit f03bcb359b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 10 deletions

View File

@ -104,9 +104,9 @@ class DaliDataloader(DALIClassificationIterator):
img = lam * img + (1 - lam) * img[idx, :]
label_a, label_b = label, label[idx]
lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
label = (label_a, label_b, lam)
label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
else:
label = (label, label, torch.ones(
1, device=img.device, dtype=img.dtype))
return (img,), label
return (img,), (label,)
label = {'targets_a': label, 'targets_b': label,
'lam': torch.ones(1, device=img.device, dtype=img.dtype)}
return img, label
return img, label

View File

@ -1,5 +1,7 @@
import torch.nn as nn
from colossalai.registry import LOSSES
import torch
@LOSSES.register_module
class MixupLoss(nn.Module):
@ -7,6 +9,13 @@ class MixupLoss(nn.Module):
super().__init__()
self.loss_fn = loss_fn_cls()
def forward(self, inputs, *args):
targets_a, targets_b, lam = args
def forward(self, inputs, targets_a, targets_b, lam):
return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
class MixupAccuracy(nn.Module):
def forward(self, logits, targets):
targets = targets['targets_a']
preds = torch.argmax(logits, dim=-1)
correct = torch.sum(targets == preds)
return correct

View File

@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR
from dataloader.imagenet_dali_dataloader import DaliDataloader
from mixup import MixupLoss
from mixup import MixupLoss, MixupAccuracy
from timm.models import vit_base_patch16_224
from myhooks import TotalBatchsizeHook
@ -62,7 +62,7 @@ def main():
port=args.port,
backend=args.backend
)
# launch from torch
# launch from torch
# colossalai.launch_from_torch(config=args.config)
# get logger
@ -96,7 +96,7 @@ def main():
# build hooks
hook_list = [
hooks.LossHook(),
hooks.AccuracyHook(accuracy_func=Accuracy()),
hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
hooks.LogMetricByEpochHook(logger),
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
TotalBatchsizeHook(),