add example of self-supervised SimCLR training - V2 (#50)

* add example of self-supervised SimCLR training

* simclr v2, replace nvidia dali dataloader

* updated

* sync to latest code writing style

* sync to latest code writing style and modify README

* detail README & standardize dataset path
Xin Zhang 3 years ago committed by GitHub
parent 8f02a88db2
commit 648f806315

NT_Xentloss.py
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.registry import LOSSES
from torch.nn.modules.linear import Linear


@LOSSES.register_module
class NT_Xentloss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss used by SimCLR."""

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2, label):
        # `label` is unused; it is kept so the signature matches the trainer's criterion interface
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        N, Z = z1.shape
        device = z1.device
        representations = torch.cat([z1, z2], dim=0)
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)

        # positives: similarity between the two augmented views of the same image
        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

        # mask out self-similarities and positive pairs to obtain the negatives
        diag = torch.eye(2 * N, dtype=torch.bool, device=device)
        diag[N:, :N] = diag[:N, N:] = diag[:N, :N]
        negatives = similarity_matrix[~diag].view(2 * N, -1)

        # the positive logit sits in column 0, so the target class is 0 for every row
        logits = torch.cat([positives, negatives], dim=1)
        logits /= self.temperature
        labels = torch.zeros(2 * N, device=device, dtype=torch.int64)
        loss = F.cross_entropy(logits, labels, reduction='sum')
        return loss / (2 * N)


if __name__ == '__main__':
    criterion = NT_Xentloss()
    net = Linear(256, 512)
    output = [net(torch.randn(512, 256)), net(torch.randn(512, 256))]
    label = [torch.randn(512)]
    loss = criterion(*output, *label)
    print(loss)
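The least obvious part of the loss above is the boolean mask `diag`: it marks the entries of the similarity matrix that must be excluded from the negatives, namely each view's similarity with itself and with its own positive view. A small stand-alone check (not part of the example code) makes the pattern visible for N = 2:

```python
# stand-alone sanity check of the mask construction used in NT_Xentloss.forward
import torch

N = 2                                         # two images -> four augmented views
diag = torch.eye(2 * N, dtype=torch.bool)     # mask self-similarities
diag[N:, :N] = diag[:N, N:] = diag[:N, :N]    # also mask the positive pairs
print(diag.long())
# tensor([[1, 0, 1, 0],
#         [0, 1, 0, 1],
#         [1, 0, 1, 0],
#         [0, 1, 0, 1]])
```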

README.md
@@ -0,0 +1,30 @@
# Overview
Here is an example of applying [PreAct-ResNet18](https://arxiv.org/abs/1603.05027) to train [SimCLR](https://arxiv.org/abs/2002.05709) on CIFAR10.
SimCLR is a self-supervised representation learning algorithm that learns generic representations of images from an unlabeled dataset. The representations are learned by simultaneously maximizing agreement between differently transformed views of the same image and minimizing agreement between transformed views of different images, a method known as contrastive learning. Updating the parameters of a neural network with this contrastive objective causes representations of corresponding views to “attract” each other, while representations of non-corresponding views “repel” each other. A more detailed description of SimCLR is available [here](https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html).
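Concretely, for a positive pair of views $(i, j)$ drawn from a batch of $N$ images (hence $2N$ augmented views), the contrastive objective implemented in `NT_Xentloss.py` (NT-Xent) can be written as

$$\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)},$$

where $z$ are the projected embeddings, $\mathrm{sim}(\cdot,\cdot)$ is cosine similarity and $\tau$ is the temperature (0.5 by default in this example).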
The training process consists of two phases: (1) self-supervised representation learning, in which the model that acts as a feature extractor is trained exactly as described above; and (2) linear evaluation, in which, to evaluate how well the representations have been learned, a linear classifier is added on top of the feature extractor trained in phase 1. The linear classifier is trained on the labeled dataset in the conventional supervised manner while the parameters of the feature extractor are kept fixed; see the sketch below.
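As a minimal sketch of the linear-evaluation setup (the actual module used in phase 2 is `Linear_eval` in `models/linear_eval.py`):

```python
# sketch only: a frozen pretrained backbone followed by a trainable linear head
import torch.nn as nn

class LinearEvalSketch(nn.Module):
    def __init__(self, pretrained_backbone, feature_dim, num_classes=10):
        super().__init__()
        self.backbone = pretrained_backbone
        self.backbone.requires_grad_(False)             # feature extractor stays frozen
        self.fc = nn.Linear(feature_dim, num_classes)   # only this layer is trained

    def forward(self, x):
        return self.fc(self.backbone(x))
```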
# How to run
The training commands for both phases are collected in `train.sh`; to run them:
```shell
bash train.sh
```
Before running, you can specify the experiment name (folders with the same name will be created in `ckpt` to save checkpoints and in `tb_logs` to save the TensorBoard logs) and other training hyperparameters in `config.py`. By default, the CIFAR10 dataset will be downloaded automatically and saved in `./dataset`. Note that `LOG_NAME` in `le_config.py` should be the same as that in `config.py`.
Besides linear evaluation, you can also visualize the distribution of the learned representations. A script is provided that first extracts the representations and then visualizes them with [t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html). t-SNE is a good tool for visualizing high-dimensional data: it converts similarities between data points to joint probabilities and tries to minimize the Kullback-Leibler divergence between the joint probabilities of the low-dimensional embedding and of the high-dimensional data. Run the script directly (remember to modify `log_name` and `epoch` to specify which experiment folder and which training epoch to load the model from):
```shell
python visualization.py
```
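Under the hood the script runs scikit-learn's t-SNE on the backbone features extracted from the trained model, roughly:

```python
# sketch of the embedding step inside visualization.py; `features` is assumed to be
# a (num_images, feature_dim) NumPy array of extracted backbone features
from sklearn.manifold import TSNE

embedded = TSNE(n_components=2).fit_transform(features)   # shape (num_images, 2)
# `embedded` can then be scattered with matplotlib, coloring points by class label
```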
# Results
The loss curve of SimCLR self-supervised training is as follows:
![SimCLR Loss Curve](./results/ssl_loss.png)
The loss curve of linear evaluation is as follows:
![Linear Evaluation Loss Curve](./results/linear_eval_loss.png)
The accuracy curve of linear evaluation is as follows:
![Linear Evaluation Accuracy](./results/linear_eval_acc.png)
The t-SNE of the training set of CIFAR10 is as follows:
![train tSNE](./results/train_tsne.png)
The t-SNE of the test set of CIFAR10 is as follows:
![test tSNE](./results/test_tsne.png)

augmentation.py
@@ -0,0 +1,32 @@
from torchvision.transforms import transforms


class SimCLRTransform():
    """Returns two independently augmented views of the same image for contrastive training."""

    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=32 // 20 * 2 + 1, sigma=(0.1, 2.0))], p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        ])

    def __call__(self, x):
        x1 = self.transform(x)
        x2 = self.transform(x)
        return x1, x2


class LeTransform():
    """Light augmentation used for the linear-evaluation phase."""

    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        ])

    def __call__(self, x):
        x = self.transform(x)
        return x
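`SimCLRTransform` is meant to be passed directly as the `transform` of the CIFAR10 dataset, so every dataset item yields two augmented views of the same image that the `SimCLR` model encodes in a single forward pass. A hypothetical minimal usage (the real wiring is in `train_simclr.py`):

```python
# hypothetical usage sketch, assuming the dataset path used elsewhere in this example
from torchvision.datasets import CIFAR10
from augmentation import SimCLRTransform

dataset = CIFAR10(root='./dataset', train=True, transform=SimCLRTransform(), download=True)
(x1, x2), label = dataset[0]   # two differently augmented 3x32x32 views of the same image
```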

config.py
@@ -0,0 +1,23 @@
from colossalai.amp import AMP_TYPE

LOG_NAME = 'cifar-simclr'

BATCH_SIZE = 512
NUM_EPOCHS = 801
LEARNING_RATE = 0.03 * BATCH_SIZE / 256   # linear scaling of the learning rate with batch size
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9

fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

dataset = dict(
    root='./dataset',
)

gradient_accumulation = 2
gradient_clipping = 1.0
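Two values in this config interact: the learning rate follows a linear scaling rule (`0.03 * BATCH_SIZE / 256`), and the effective batch size per optimizer step is `BATCH_SIZE * gradient_accumulation * <data-parallel world size>`, which is exactly the number `TotalBatchsizeHook` logs before training. With the defaults above on a single GPU (an assumed world size of 1):

```python
# quick arithmetic check under the assumption of a single data-parallel process
BATCH_SIZE = 512
gradient_accumulation = 2
world_size = 1

effective_batch_size = BATCH_SIZE * gradient_accumulation * world_size   # 1024
learning_rate = 0.03 * BATCH_SIZE / 256                                  # 0.06
print(effective_batch_size, learning_rate)
```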

le_config.py
@@ -0,0 +1,23 @@
from colossalai.amp import AMP_TYPE

LOG_NAME = 'cifar-simclr'   # must match LOG_NAME in config.py
EPOCH = 800                 # epoch of the pretrained SimCLR checkpoint to load

BATCH_SIZE = 512
NUM_EPOCHS = 51
LEARNING_RATE = 0.03 * BATCH_SIZE / 256
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9

fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

dataset = dict(
    root='./dataset',
)

gradient_accumulation = 1
gradient_clipping = 1.0

models/Backbone.py
@@ -0,0 +1,178 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34, resnet50, resnet101, resnet152


def backbone(model, **kwargs):
    """Build a backbone network and replace its classification head with an identity layer."""
    assert model in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], \
        "current version only supports resnet18 ~ resnet152"
    if model == 'resnet18':
        # CIFAR-friendly PreAct-ResNet18 (3x3 stem, no max-pooling)
        net = ResNet(PreActBlock, [2, 2, 2, 2], **kwargs)
    else:
        net = eval(f"{model}(**kwargs)")
    net.output_dim = net.fc.in_features
    net.fc = torch.nn.Identity()
    return net


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, lin=0, lout=5):
        # `lin`/`lout` select which stages to run, allowing partial forward passes
        out = x
        if lin < 1 and lout > -1:
            out = self.conv1(out)
            out = self.bn1(out)
            out = F.relu(out)
        if lin < 2 and lout > 0:
            out = self.layer1(out)
        if lin < 3 and lout > 1:
            out = self.layer2(out)
        if lin < 4 and lout > 2:
            out = self.layer3(out)
        if lin < 5 and lout > 3:
            out = self.layer4(out)
        if lout > 4:
            out = F.avg_pool2d(out, 4)
            out = out.view(out.size(0), -1)
            out = self.fc(out)
        return out


def debug():
    net = backbone('resnet18')   # the custom ResNet does not accept a `pretrained` argument
    x = torch.randn(4, 3, 32, 32)
    y = net(x)
    print(y.size())


if __name__ == '__main__':
    debug()

models/linear_eval.py
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Backbone import backbone


class Linear_eval(nn.Module):
    """Frozen SimCLR backbone followed by a trainable linear classifier."""

    def __init__(self, model='resnet18', class_num=10, **kwargs):
        super().__init__()
        self.backbone = backbone(model, **kwargs)
        self.backbone.requires_grad_(False)   # freeze the feature extractor
        self.fc = nn.Linear(self.backbone.output_dim, class_num)

    def forward(self, x):
        out = self.backbone(x)
        out = self.fc(out)
        return out

models/simclr.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Backbone import backbone


class projection_MLP(nn.Module):
    """Two-layer projection head mapping backbone features to the contrastive embedding space."""

    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        hidden_dim = in_dim
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x


class SimCLR(nn.Module):
    def __init__(self, model='resnet18', **kwargs):
        super().__init__()
        self.backbone = backbone(model, **kwargs)
        self.projector = projection_MLP(self.backbone.output_dim)
        self.encoder = nn.Sequential(
            self.backbone,
            self.projector
        )

    def forward(self, x1, x2):
        # encode both augmented views; the NT-Xent loss compares z1 and z2
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)
        return z1, z2

myhooks.py
@@ -0,0 +1,15 @@
from colossalai.trainer.hooks import BaseHook
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger


class TotalBatchsizeHook(BaseHook):
    """Logs the effective global batch size once, before training starts."""

    def __init__(self, priority: int = 2) -> None:
        super().__init__(priority)
        self.logger = get_dist_logger()

    def before_train(self, trainer):
        total_batch_size = gpc.config.BATCH_SIZE * \
            gpc.config.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
        self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])

(Five binary image files added under `results/` — the SSL loss curve, the linear-evaluation loss and accuracy curves, and the train/test t-SNE plots referenced in the README; 16 KiB to 531 KiB each, not shown.)

train.sh
@@ -0,0 +1,7 @@
#!/usr/bin/env sh
## phase 1: self-supervised training
python -m torch.distributed.launch --nproc_per_node 1 train_simclr.py
## phase 2: linear evaluation
python -m torch.distributed.launch --nproc_per_node 1 train_linear.py
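`train.sh` launches both phases with a single process. To train with several data-parallel GPUs you would raise `--nproc_per_node` accordingly (for example, `python -m torch.distributed.launch --nproc_per_node 2 train_simclr.py`); the effective batch size logged by `TotalBatchsizeHook` then grows with the number of processes.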

train_linear.py
@@ -0,0 +1,106 @@
import torch
import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.utils import get_dataloader, MultiTimer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from torchvision.datasets import CIFAR10

from myhooks import TotalBatchsizeHook
from models.linear_eval import Linear_eval
from augmentation import LeTransform


def build_dataset_train():
    augment = LeTransform()
    train_dataset = CIFAR10(root=gpc.config.dataset.root,
                            transform=augment,
                            train=True)

    return get_dataloader(
        dataset=train_dataset,
        shuffle=True,
        num_workers=1,
        batch_size=gpc.config.BATCH_SIZE,
        pin_memory=True,
    )


def build_dataset_test():
    augment = LeTransform()
    val_dataset = CIFAR10(root=gpc.config.dataset.root,
                          transform=augment,
                          train=False)

    return get_dataloader(
        dataset=val_dataset,
        add_sampler=False,
        num_workers=1,
        batch_size=gpc.config.BATCH_SIZE,
        pin_memory=True,
    )


def main():
    colossalai.launch_from_torch(config='./le_config.py',
                                 host='localhost',
                                 port=29500)

    # get logger
    logger = get_dist_logger()

    # build model
    model = Linear_eval(model='resnet18', class_num=10)

    # build dataloaders
    train_dataloader = build_dataset_train()
    test_dataloader = build_dataset_test()

    # build loss
    criterion = torch.nn.CrossEntropyLoss()

    # build optimizer
    optimizer = colossalai.nn.FusedSGD(model.parameters(),
                                       lr=gpc.config.LEARNING_RATE,
                                       weight_decay=gpc.config.WEIGHT_DECAY,
                                       momentum=gpc.config.MOMENTUM)

    # lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=5, total_steps=gpc.config.NUM_EPOCHS)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model, optimizer, criterion, train_dataloader, test_dataloader
    )
    logger.info("initialized colossalai components", ranks=[0])

    # load the trained self-supervised SimCLR model
    engine.model.load_state_dict(
        torch.load(f'./ckpt/{gpc.config.LOG_NAME}/epoch{gpc.config.EPOCH}-tp0-pp0.pt')['model'],
        strict=False)
    logger.info("pretrained model loaded", ranks=[0])

    # build a timer to measure time
    timer = MultiTimer()

    # build trainer
    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    # build hooks
    hook_list = [
        hooks.LossHook(),
        hooks.AccuracyHook(),
        hooks.LogMetricByEpochHook(logger),
        hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
        TotalBatchsizeHook(),

        # comment out the hooks below if you do not need them
        hooks.SaveCheckpointHook(interval=5, checkpoint_dir=f'./ckpt/{gpc.config.LOG_NAME}-eval'),
        hooks.TensorboardHook(log_dir=f'./tb_logs/{gpc.config.LOG_NAME}-eval', ranks=[0]),
    ]

    # start training
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        hooks=hook_list,
        display_progress=True,
        test_interval=1
    )


if __name__ == '__main__':
    main()

train_simclr.py
@@ -0,0 +1,102 @@
import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.utils import get_dataloader, MultiTimer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from torchvision.datasets import CIFAR10

from NT_Xentloss import NT_Xentloss
from myhooks import TotalBatchsizeHook
from models.simclr import SimCLR
from augmentation import SimCLRTransform


def build_dataset_train():
    augment = SimCLRTransform()
    train_dataset = CIFAR10(root=gpc.config.dataset.root,
                            transform=augment,
                            train=True,
                            download=True)

    return get_dataloader(
        dataset=train_dataset,
        shuffle=True,
        num_workers=1,
        batch_size=gpc.config.BATCH_SIZE,
        pin_memory=True,
    )


def build_dataset_test():
    augment = SimCLRTransform()
    val_dataset = CIFAR10(root=gpc.config.dataset.root,
                          transform=augment,
                          train=False)

    return get_dataloader(
        dataset=val_dataset,
        add_sampler=False,
        num_workers=1,
        batch_size=gpc.config.BATCH_SIZE,
        pin_memory=True,
    )


def main():
    colossalai.launch_from_torch(config='./config.py',
                                 host='localhost',
                                 port=29500)

    # get logger
    logger = get_dist_logger()

    # build model
    model = SimCLR(model='resnet18')

    # build dataloaders
    train_dataloader = build_dataset_train()
    test_dataloader = build_dataset_test()

    # build loss
    criterion = NT_Xentloss()

    # build optimizer
    optimizer = colossalai.nn.FusedSGD(model.parameters(),
                                       lr=gpc.config.LEARNING_RATE,
                                       weight_decay=gpc.config.WEIGHT_DECAY,
                                       momentum=gpc.config.MOMENTUM)

    # lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=10, total_steps=gpc.config.NUM_EPOCHS)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model, optimizer, criterion, train_dataloader, test_dataloader
    )
    logger.info("initialized colossalai components", ranks=[0])

    # build a timer to measure time
    timer = MultiTimer()

    # build trainer
    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    # build hooks
    hook_list = [
        hooks.LossHook(),
        hooks.LogMetricByEpochHook(logger),
        hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
        TotalBatchsizeHook(),

        # comment out the hooks below if you do not need them
        hooks.SaveCheckpointHook(interval=50, checkpoint_dir=f'./ckpt/{gpc.config.LOG_NAME}'),
        hooks.TensorboardHook(log_dir=f'./tb_logs/{gpc.config.LOG_NAME}', ranks=[0]),
    ]

    # start training
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        hooks=hook_list,
        display_progress=True,
        test_interval=1
    )


if __name__ == '__main__':
    main()

visualization.py
@@ -0,0 +1,72 @@
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from models.simclr import SimCLR
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms

# specify which experiment folder and which training epoch to load the model from
log_name = 'cifar-simclr'
epoch = 800

fea_flag = True
tsne_flag = True
plot_flag = True

if fea_flag:
    path = f'ckpt/{log_name}/epoch{epoch}-tp0-pp0.pt'
    net = SimCLR('resnet18').cuda()
    print(net.load_state_dict(torch.load(path)['model']))

    transform_eval = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    train_dataset = CIFAR10(root='./dataset', train=True, transform=transform_eval)
    train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=4)
    test_dataset = CIFAR10(root='./dataset', train=False, transform=transform_eval)
    test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)


def feature_extractor(model, loader):
    """Extract backbone features (before the projection head) for every image in the loader."""
    model.eval()
    all_fea = []
    all_targets = []
    with torch.no_grad():
        for img, target in loader:
            img = img.cuda()
            fea = model.backbone(img)
            all_fea.append(fea.detach().cpu())
            all_targets.append(target)
    all_fea = torch.cat(all_fea)
    all_targets = torch.cat(all_targets)
    return all_fea.numpy(), all_targets.numpy()


if tsne_flag:
    train_fea, train_targets = feature_extractor(net, train_dataloader)
    train_embedded = TSNE(n_components=2).fit_transform(train_fea)
    test_fea, test_targets = feature_extractor(net, test_dataloader)
    test_embedded = TSNE(n_components=2).fit_transform(test_fea)
    np.savez('results/embedding.npz', train_embedded=train_embedded, train_targets=train_targets,
             test_embedded=test_embedded, test_targets=test_targets)

if plot_flag:
    npz = np.load('results/embedding.npz')
    train_embedded = npz['train_embedded']
    train_targets = npz['train_targets']
    test_embedded = npz['test_embedded']
    test_targets = npz['test_targets']

    plt.figure(figsize=(16, 16))
    for i in range(len(np.unique(train_targets))):
        plt.scatter(train_embedded[train_targets == i, 0], train_embedded[train_targets == i, 1], label=i)
    plt.title('train')
    plt.legend()
    plt.savefig('results/train_tsne.png')

    plt.figure(figsize=(16, 16))
    for i in range(len(np.unique(test_targets))):
        plt.scatter(test_embedded[test_targets == i, 0], test_embedded[test_targets == i, 1], label=i)
    plt.title('test')
    plt.legend()
    plt.savefig('results/test_tsne.png')