# Mirror of https://github.com/hpcaitech/ColossalAI
# (source listing metadata: 49 lines, 1.5 KiB, Python)
import argparse
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# ==============================
# Parse Arguments
# ==============================
# Command-line options selecting which saved checkpoint to evaluate.
parser = argparse.ArgumentParser(description="Evaluate a saved ResNet-18 checkpoint on the CIFAR-10 test set.")
# The help text describes what this evaluation script actually does with the
# epoch: pick the checkpoint file <checkpoint>/model_<epoch>.pth to load.
parser.add_argument('-e', '--epoch', type=int, default=80,
                    help="epoch whose checkpoint to evaluate (loads <checkpoint>/model_<epoch>.pth)")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
args = parser.parse_args()

# ==============================
# Prepare Test Dataset
# ==============================
# CIFAR-10 test split read from ./data/. No download is requested, so the
# dataset files must already be present there. ToTensor converts each image
# into a float tensor with values in [0, 1].
_to_tensor = transforms.ToTensor()
test_dataset = torchvision.datasets.CIFAR10(
    root='./data/',
    train=False,
    transform=_to_tensor,
)

# Batches of 128 samples, iterated in the dataset's stored order (no shuffling).
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
)

# ==============================
# Load Model
# ==============================
# Prefer the GPU, but fall back to the CPU so the checkpoint can still be
# loaded on a machine without CUDA (on a CUDA machine this matches `.cuda()`).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ResNet-18 with a 10-way classification head.
model = torchvision.models.resnet18(num_classes=10).to(device)

# The checkpoint is expected at <checkpoint>/model_<epoch>.pth; os.path.join
# copes with a trailing separator in --checkpoint.
checkpoint_path = os.path.join(args.checkpoint, f'model_{args.epoch}.pth')

# map_location='cpu' deserializes every tensor on the CPU regardless of the
# device it was saved from (e.g. a GPU index that does not exist here).
# load_state_dict then copies the weights into the model's parameters on
# `device`.
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)

# ==============================
# Run Evaluation
# ==============================
# Put the network in inference mode (affects layers such as BatchNorm/Dropout).
model.eval()

# Send each batch to wherever the model's weights live instead of hard-coding
# CUDA, so the loop works for whichever device the model was placed on.
device = next(model.parameters()).device

# Count correct top-1 predictions over the whole test set without tracking gradients.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest output along dim 1 (the class dim).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard the division below: an empty loader would otherwise raise an opaque
# ZeroDivisionError.
if total == 0:
    raise RuntimeError('The test loader produced no samples; cannot compute accuracy.')

print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))