import torch

def calc_acc(logits, targets):
    preds = torch.argmax(logits, dim=-1)
    correct = torch.sum(targets == preds)
    return correct