ColossalAI/examples/simclr_cifar10_data_parallel/visualization.py

73 lines
2.5 KiB
Python

import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from models.simclr import SimCLR
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms
log_name = 'cifar-simclr'
epoch = 800
fea_flag = True
tsne_flag = True
plot_flag = True
if fea_flag:
path = f'ckpt/{log_name}/epoch{epoch}-tp0-pp0.pt'
net = SimCLR('resnet18').cuda()
print(net.load_state_dict(torch.load(path)['model']))
transform_eval = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])
train_dataset = CIFAR10(root='./dataset', train=True, transform=transform_eval)
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=4)
test_dataset = CIFAR10(root='./dataset', train=False, transform=transform_eval)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)
def feature_extractor(model, loader):
model.eval()
all_fea = []
all_targets = []
for img, target in loader:
img = img.cuda()
fea = model.backbone(img)
all_fea.append(fea.detach().cpu())
all_targets.append(target)
all_fea = torch.cat(all_fea)
all_targets = torch.cat(all_targets)
return all_fea.numpy(), all_targets.numpy()
if tsne_flag:
train_fea, train_targets = feature_extractor(net, train_dataloader)
train_embedded = TSNE(n_components=2).fit_transform(train_fea)
test_fea, test_targets = feature_extractor(net, test_dataloader)
test_embedded = TSNE(n_components=2).fit_transform(test_fea)
np.savez('results/embedding.npz', train_embedded=train_embedded, train_targets=train_targets, test_embedded=test_embedded, test_targets=test_targets)
if plot_flag:
npz = np.load('embedding.npz')
train_embedded = npz['train_embedded']
train_targets = npz['train_targets']
test_embedded = npz['test_embedded']
test_targets = npz['test_targets']
plt.figure(figsize=(16,16))
for i in range(len(np.unique(train_targets))):
plt.scatter(train_embedded[train_targets==i,0], train_embedded[train_targets==i,1], label=i)
plt.title('train')
plt.legend()
plt.savefig('results/train_tsne.png')
plt.figure(figsize=(16,16))
for i in range(len(np.unique(test_targets))):
plt.scatter(test_embedded[test_targets==i,0], test_embedded[test_targets==i,1], label=i)
plt.title('test')
plt.legend()
plt.savefig('results/test_tsne.png')