import argparse

import torch
import torchvision
import torchvision.transforms as transforms

# ==============================
# Parse Arguments
# ==============================
# The checkpoint evaluated is <checkpoint>/model_<epoch>.pth (see the model
# loading section below). This script only evaluates; it never resumes training.
parser = argparse.ArgumentParser(description="Evaluate a ResNet-18 checkpoint on the CIFAR-10 test set.")
parser.add_argument(
    "-e",
    "--epoch",
    type=int,
    default=80,
    help="epoch of the checkpoint to evaluate (loads <checkpoint>/model_<epoch>.pth)",
)
parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory")
args = parser.parse_args()

# ==============================
# Prepare Test Dataset
# ==============================
# CIFAR-10 test split. Images are only converted to float tensors by ToTensor;
# NOTE(review): this must match the preprocessing used at training time (e.g.
# any Normalize transform) — confirm against the training script.
# download=True fetches the dataset into ./data/ when it is missing; if it is
# already present it is not downloaded again. Without it, a fresh checkout fails.
test_dataset = torchvision.datasets.CIFAR10(
    root="./data/",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Data loader: evaluation does not need shuffling, so batches come in dataset order.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

# ==============================
# Load Model
# ==============================
# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-18 with a 10-way classifier head, one output per CIFAR-10 class.
model = torchvision.models.resnet18(num_classes=10).to(device)

# map_location remaps every stored tensor onto `device`; without it a
# checkpoint saved on a different device (e.g. cuda:1, or a CPU-only run)
# can fail to deserialize here.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code;
# only load checkpoints from trusted sources (consider weights_only=True on
# torch >= 1.13).
state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth", map_location=device)
model.load_state_dict(state_dict)

# ==============================
# Run Evaluation
# ==============================
# Inference mode for layers such as BatchNorm (use stored running statistics).
model.eval()

# Send each batch to wherever the model's weights actually live rather than
# hard-coding CUDA, so inputs and parameters are always on the same device.
eval_device = next(model.parameters()).device

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(eval_device)
        labels = labels.to(eval_device)
        outputs = model(images)
        # Predicted class = index of the largest logit in each row
        # (first index wins on ties, same as torch.max).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty loader instead of failing with ZeroDivisionError.
if total == 0:
    raise RuntimeError("test loader yielded no samples; cannot compute accuracy")

print("Accuracy of the model on the test images: {} %".format(100 * correct / total))