diff --git a/examples/simclr_cifar10_data_parallel/NT_Xentloss.py b/examples/simclr_cifar10_data_parallel/NT_Xentloss.py new file mode 100644 index 000000000..5437b2abf --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/NT_Xentloss.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from colossalai.registry import LOSSES +from torch.nn.modules.linear import Linear + +@LOSSES.register_module +class NT_Xentloss(nn.Module): + def __init__(self, temperature=0.5): + super().__init__() + self.temperature = temperature + + def forward(self, z1, z2, label): + z1 = F.normalize(z1, dim=1) + z2 = F.normalize(z2, dim=1) + N, Z = z1.shape + device = z1.device + representations = torch.cat([z1, z2], dim=0) + similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) + l_pos = torch.diag(similarity_matrix, N) + r_pos = torch.diag(similarity_matrix, -N) + positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) + diag = torch.eye(2*N, dtype=torch.bool, device=device) + diag[N:,:N] = diag[:N,N:] = diag[:N,:N] + + negatives = similarity_matrix[~diag].view(2*N, -1) + + logits = torch.cat([positives, negatives], dim=1) + logits /= self.temperature + + labels = torch.zeros(2*N, device=device, dtype=torch.int64) + + loss = F.cross_entropy(logits, labels, reduction='sum') + return loss / (2 * N) + + +if __name__=='__main__': + criterion = NT_Xentloss() + net = Linear(256,512) + output = [net(torch.randn(512,256)), net(torch.randn(512,256))] + label = [torch.randn(512)] + loss = criterion(*output, *label) + print(loss) + + \ No newline at end of file diff --git a/examples/simclr_cifar10_data_parallel/README.md b/examples/simclr_cifar10_data_parallel/README.md new file mode 100644 index 000000000..7b21858bf --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/README.md @@ -0,0 +1,30 @@ +# Overview + +Here is an example of applying [PreAct-ResNet18](https://arxiv.org/abs/1603.05027) to train [SimCLR](https://arxiv.org/abs/2002.05709) on CIFAR10. +SimCLR is a kind of self-supervised representation learning algorithm which learns generic representations of images on an unlabeled dataset. The generic representations are learned by simultaneously maximizing agreement between differently transformed views of the same image and minimizing agreement between transformed views of different images, following a method called contrastive learning. Updating the parameters of a neural network using this contrastive objective causes representations of corresponding views to “attract” each other, while representations of non-corresponding views “repel” each other. A more detailed description of SimCLR is available [here](https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html). + +The training process consists of two phases: (1) self-supervised representation learning: the model which acts as a feature extractor is trained exactly as described above; and (2) linear evaluation: to evaluate how well representations are learned, generally a linear classifier is added on top of the trained feature extractor in phase 1. The linear classifier is trained with a labeled dataset in a conventional supervised manner, while parameters of the feature extractor keep fixed. This process is called linear evaluation. + +# How to run +The training commands are specified in: +```shell +bash train.sh +``` +Before running, you can specify the experiment name (folders with the same name will be created in `ckpt` to save checkpoints and in `tb_logs` to save the tensorboard file) and other training hyperparameters in `config.py`. By default CIFAR10 dataset will be downloaded automatically and saved in `./dataset`. Note that `LOG_NAME` in `le_config.py` should be the same as that in `config.py`. + +Besides linear evaluation, you can also visualize the distribution of learned representations. A script is provided which first extracts representations and then visualizes them with [t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html). t-SNE is a good tool to visualize high-dimensional data. It converts similarities between data points to joint probabilities and tries to minimize the Kullback-Leibler divergence between the joint probabilities of the low-dimensional embedding and the high-dimensional data. You can directly run the script by (remember modifying `log_name` and `epoch` to specify the model in which experiment folder and of which training epoch to load): +```python +python visualization.py +``` + +# Results +The loss curve of SimCLR self-supervised training is as follows: +![SimCLR Loss Curve](./results/ssl_loss.png) +The loss curve of linear evaluation is as follows: +![Linear Evaluation Loss Curve](./results/linear_eval_loss.png) +The accuracy curve of linear evaluation is as follows: +![Linear Evaluation Accuracy](./results/linear_eval_acc.png) +The t-SNE of the training set of CIFAR10 is as follows: +![train tSNE](./results/train_tsne.png) +The t-SNE of the test set of CIFAR10 is as follows: +![test tSNE](./results/test_tsne.png) \ No newline at end of file diff --git a/examples/simclr_cifar10_data_parallel/augmentation.py b/examples/simclr_cifar10_data_parallel/augmentation.py new file mode 100644 index 000000000..9636f992c --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/augmentation.py @@ -0,0 +1,32 @@ +from torchvision.transforms import transforms + +class SimCLRTransform(): + def __init__(self): + self.transform = transforms.Compose([ + transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([transforms.GaussianBlur(kernel_size=32//20*2+1, sigma=(0.1, 2.0))], p=0.5), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) + ]) + + def __call__(self, x): + x1 = self.transform(x) + x2 = self.transform(x) + return x1, x2 + + +class LeTransform(): + def __init__(self): + self.transform = transforms.Compose([ + transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) + ]) + + def __call__(self, x): + x = self.transform(x) + return x \ No newline at end of file diff --git a/examples/simclr_cifar10_data_parallel/config.py b/examples/simclr_cifar10_data_parallel/config.py new file mode 100755 index 000000000..a4f220859 --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/config.py @@ -0,0 +1,23 @@ +from colossalai.amp import AMP_TYPE + + +LOG_NAME = 'cifar-simclr' + +BATCH_SIZE = 512 +NUM_EPOCHS = 801 +LEARNING_RATE = 0.03*BATCH_SIZE/256 +WEIGHT_DECAY = 0.0005 +MOMENTUM = 0.9 + + +fp16 = dict( + mode=AMP_TYPE.TORCH, +) + +dataset = dict( + root='./dataset', +) + +gradient_accumulation=2 +gradient_clipping=1.0 + diff --git a/examples/simclr_cifar10_data_parallel/le_config.py b/examples/simclr_cifar10_data_parallel/le_config.py new file mode 100755 index 000000000..fc3a0ed92 --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/le_config.py @@ -0,0 +1,23 @@ +from colossalai.amp import AMP_TYPE + + +LOG_NAME = 'cifar-simclr' +EPOCH = 800 + +BATCH_SIZE = 512 +NUM_EPOCHS = 51 +LEARNING_RATE = 0.03*BATCH_SIZE/256 +WEIGHT_DECAY = 0.0005 +MOMENTUM = 0.9 + + +fp16 = dict( + mode=AMP_TYPE.TORCH, +) + +dataset = dict( + root='./dataset', +) + +gradient_accumulation=1 +gradient_clipping=1.0 diff --git a/examples/simclr_cifar10_data_parallel/models/Backbone.py b/examples/simclr_cifar10_data_parallel/models/Backbone.py new file mode 100644 index 000000000..2160d5cde --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/models/Backbone.py @@ -0,0 +1,178 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import resnet34, resnet50, resnet101, resnet152 + + +def backbone(model, **kwargs): + assert model in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "current version only support resnet18 ~ resnet152" + if model == 'resnet18': + net = ResNet(PreActBlock, [2,2,2,2], **kwargs) + else: + net = eval(f"{model}(**kwargs)") + net.output_dim = net.fc.in_features + net.fc = torch.nn.Identity() + return net + + +def conv3x3(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class PreActBlock(nn.Module): + '''Pre-activation version of the BasicBlock.''' + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(PreActBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = conv3x3(in_planes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) + ) + + def forward(self, x): + out = F.relu(self.bn1(x)) + shortcut = self.shortcut(out) + out = self.conv1(out) + out = self.conv2(F.relu(self.bn2(out))) + out += shortcut + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class PreActBottleneck(nn.Module): + '''Pre-activation version of the original Bottleneck module.''' + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(PreActBottleneck, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) + ) + + def forward(self, x): + out = F.relu(self.bn1(x)) + shortcut = self.shortcut(out) + out = self.conv1(out) + out = self.conv2(F.relu(self.bn2(out))) + out = self.conv3(F.relu(self.bn3(out))) + out += shortcut + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = conv3x3(3,64) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.fc = nn.Linear(512*block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x, lin=0, lout=5): + out = x + if lin < 1 and lout > -1: + out = self.conv1(out) + out = self.bn1(out) + out = F.relu(out) + if lin < 2 and lout > 0: + out = self.layer1(out) + if lin < 3 and lout > 1: + out = self.layer2(out) + if lin < 4 and lout > 2: + out = self.layer3(out) + if lin < 5 and lout > 3: + out = self.layer4(out) + if lout > 4: + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.fc(out) + return out + +def debug(): + net = backbone('resnet18', pretrained=True) + x = torch.randn(4,3,32,32) + y = net(x) + print(y.size()) + +if __name__ == '__main__': + debug() \ No newline at end of file diff --git a/examples/simclr_cifar10_data_parallel/models/linear_eval.py b/examples/simclr_cifar10_data_parallel/models/linear_eval.py new file mode 100644 index 000000000..588476cca --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/models/linear_eval.py @@ -0,0 +1,19 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .Backbone import backbone + +class Linear_eval(nn.Module): + + def __init__(self, model='resnet18', class_num=10, **kwargs): + super().__init__() + + self.backbone = backbone(model, **kwargs) + self.backbone.requires_grad_(False) + self.fc = nn.Linear(self.backbone.output_dim, class_num) + + def forward(self, x): + + out = self.backbone(x) + out = self.fc(out) + return out diff --git a/examples/simclr_cifar10_data_parallel/models/simclr.py b/examples/simclr_cifar10_data_parallel/models/simclr.py new file mode 100644 index 000000000..27c4d47a5 --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/models/simclr.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .Backbone import backbone + +class projection_MLP(nn.Module): + def __init__(self, in_dim, out_dim=256): + super().__init__() + hidden_dim = in_dim + self.layer1 = nn.Sequential( + nn.Linear(in_dim, hidden_dim), + nn.ReLU(inplace=True) + ) + self.layer2 = nn.Linear(hidden_dim, out_dim) + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + return x + +class SimCLR(nn.Module): + + def __init__(self, model='resnet18', **kwargs): + super().__init__() + + self.backbone = backbone(model, **kwargs) + self.projector = projection_MLP(self.backbone.output_dim) + self.encoder = nn.Sequential( + self.backbone, + self.projector + ) + + def forward(self, x1, x2): + + z1 = self.encoder(x1) + z2 = self.encoder(x2) + return z1, z2 \ No newline at end of file diff --git a/examples/simclr_cifar10_data_parallel/myhooks.py b/examples/simclr_cifar10_data_parallel/myhooks.py new file mode 100644 index 000000000..002e466d3 --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/myhooks.py @@ -0,0 +1,15 @@ +from colossalai.trainer.hooks import BaseHook +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +from colossalai.logging import get_dist_logger + + +class TotalBatchsizeHook(BaseHook): + def __init__(self, priority: int = 2) -> None: + super().__init__(priority) + self.logger = get_dist_logger() + + def before_train(self, trainer): + total_batch_size = gpc.config.BATCH_SIZE * \ + gpc.config.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA) + self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0]) \ No newline at end of file diff --git a/examples/simclr_cifar10_data_parallel/results/embedding.npz b/examples/simclr_cifar10_data_parallel/results/embedding.npz new file mode 100644 index 000000000..160a8e191 Binary files /dev/null and b/examples/simclr_cifar10_data_parallel/results/embedding.npz differ diff --git a/examples/simclr_cifar10_data_parallel/results/linear_eval_acc.png b/examples/simclr_cifar10_data_parallel/results/linear_eval_acc.png new file mode 100644 index 000000000..864dbb9dc Binary files /dev/null and b/examples/simclr_cifar10_data_parallel/results/linear_eval_acc.png differ diff --git a/examples/simclr_cifar10_data_parallel/results/linear_eval_loss.png b/examples/simclr_cifar10_data_parallel/results/linear_eval_loss.png new file mode 100644 index 000000000..21885f3e6 Binary files /dev/null and b/examples/simclr_cifar10_data_parallel/results/linear_eval_loss.png differ diff --git a/examples/simclr_cifar10_data_parallel/results/ssl_loss.png b/examples/simclr_cifar10_data_parallel/results/ssl_loss.png new file mode 100644 index 000000000..e5ce0a479 Binary files /dev/null and b/examples/simclr_cifar10_data_parallel/results/ssl_loss.png differ diff --git a/examples/simclr_cifar10_data_parallel/results/test_tsne.png b/examples/simclr_cifar10_data_parallel/results/test_tsne.png new file mode 100644 index 000000000..75af0860b Binary files /dev/null and b/examples/simclr_cifar10_data_parallel/results/test_tsne.png differ diff --git a/examples/simclr_cifar10_data_parallel/results/train_tsne.png b/examples/simclr_cifar10_data_parallel/results/train_tsne.png new file mode 100644 index 000000000..b229e7186 Binary files /dev/null and b/examples/simclr_cifar10_data_parallel/results/train_tsne.png differ diff --git a/examples/simclr_cifar10_data_parallel/train.sh b/examples/simclr_cifar10_data_parallel/train.sh new file mode 100644 index 000000000..bb0a0dfd8 --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/train.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +## phase 1: self-supervised training +python -m torch.distributed.launch --nproc_per_node 1 train_simclr.py + +## phase 2: linear evaluation +python -m torch.distributed.launch --nproc_per_node 1 train_linear.py \ No newline at end of file diff --git a/examples/simclr_cifar10_data_parallel/train_linear.py b/examples/simclr_cifar10_data_parallel/train_linear.py new file mode 100644 index 000000000..92eb0cc6d --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/train_linear.py @@ -0,0 +1,106 @@ +import torch +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.trainer import Trainer, hooks +from colossalai.utils import get_dataloader, MultiTimer +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR + +from torchvision.datasets import CIFAR10 +from myhooks import TotalBatchsizeHook +from models.linear_eval import Linear_eval +from augmentation import LeTransform + +def build_dataset_train(): + augment = LeTransform() + train_dataset = CIFAR10(root=gpc.config.dataset.root, + transform=augment, + train=True) + + return get_dataloader( + dataset=train_dataset, + shuffle=True, + num_workers = 1, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + ) + +def build_dataset_test(): + augment = LeTransform() + val_dataset = CIFAR10(root=gpc.config.dataset.root, + transform=augment, + train=False) + + return get_dataloader( + dataset=val_dataset, + add_sampler=False, + num_workers = 1, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + ) + +def main(): + colossalai.launch_from_torch(config='./le_config.py', + host='localhost', + port=29500) + + # get logger + logger = get_dist_logger() + + ## build model + model = Linear_eval(model='resnet18', class_num=10) + + # build dataloader + train_dataloader = build_dataset_train() + test_dataloader = build_dataset_test() + + # build loss + criterion = torch.nn.CrossEntropyLoss() + + # build optimizer + optimizer = colossalai.nn.FusedSGD(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY, momentum=gpc.config.MOMENTUM) + + # lr_scheduelr + lr_scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=5, total_steps=gpc.config.NUM_EPOCHS) + + engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model, optimizer, criterion, train_dataloader, test_dataloader + ) + logger.info("initialized colossalai components", ranks=[0]) + + ## Load trained self-supervised SimCLR model + engine.model.load_state_dict(torch.load(f'./ckpt/{gpc.config.LOG_NAME}/epoch{gpc.config.EPOCH}-tp0-pp0.pt')['model'], strict=False) + logger.info("pretrained model loaded", ranks=[0]) + + # build a timer to measure time + timer = MultiTimer() + + # build trainer + trainer = Trainer(engine=engine, logger=logger, timer=timer) + + # build hooks + hook_list = [ + hooks.LossHook(), + hooks.AccuracyHook(), + hooks.LogMetricByEpochHook(logger), + hooks.LRSchedulerHook(lr_scheduler, by_epoch=True), + TotalBatchsizeHook(), + + # comment if you do not need to use the hooks below + hooks.SaveCheckpointHook(interval=5, checkpoint_dir=f'./ckpt/{gpc.config.LOG_NAME}-eval'), + hooks.TensorboardHook(log_dir=f'./tb_logs/{gpc.config.LOG_NAME}-eval', ranks=[0]), + ] + + # start training + trainer.fit( + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=gpc.config.NUM_EPOCHS, + hooks=hook_list, + display_progress=True, + test_interval=1 + ) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/simclr_cifar10_data_parallel/train_simclr.py b/examples/simclr_cifar10_data_parallel/train_simclr.py new file mode 100644 index 000000000..1ab504c7e --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/train_simclr.py @@ -0,0 +1,102 @@ +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.trainer import Trainer, hooks +from colossalai.utils import get_dataloader, MultiTimer +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR + +from torchvision.datasets import CIFAR10 +from NT_Xentloss import NT_Xentloss +from myhooks import TotalBatchsizeHook +from models.simclr import SimCLR +from augmentation import SimCLRTransform + +def build_dataset_train(): + augment = SimCLRTransform() + train_dataset = CIFAR10(root=gpc.config.dataset.root, + transform=augment, + train=True, + download=True) + + return get_dataloader( + dataset=train_dataset, + shuffle=True, + num_workers = 1, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + ) + +def build_dataset_test(): + augment = SimCLRTransform() + val_dataset = CIFAR10(root=gpc.config.dataset.root, + transform=augment, + train=False) + + return get_dataloader( + dataset=val_dataset, + add_sampler=False, + num_workers = 1, + batch_size=gpc.config.BATCH_SIZE, + pin_memory=True, + ) + +def main(): + colossalai.launch_from_torch(config='./config.py', + host='localhost', + port=29500) + + # get logger + logger = get_dist_logger() + + ## build model + model = SimCLR(model='resnet18') + + # build dataloader + train_dataloader = build_dataset_train() + test_dataloader = build_dataset_test() + + # build loss + criterion = NT_Xentloss() + + # build optimizer + optimizer = colossalai.nn.FusedSGD(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY, momentum=gpc.config.MOMENTUM) + + # lr_scheduelr + lr_scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=10, total_steps=gpc.config.NUM_EPOCHS) + + engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model, optimizer, criterion, train_dataloader, test_dataloader + ) + logger.info("initialized colossalai components", ranks=[0]) + + # build a timer to measure time + timer = MultiTimer() + + # build trainer + trainer = Trainer(engine=engine, logger=logger, timer=timer) + + # build hooks + hook_list = [ + hooks.LossHook(), + hooks.LogMetricByEpochHook(logger), + hooks.LRSchedulerHook(lr_scheduler, by_epoch=True), + TotalBatchsizeHook(), + + # comment if you do not need to use the hooks below + hooks.SaveCheckpointHook(interval=50, checkpoint_dir=f'./ckpt/{gpc.config.LOG_NAME}'), + hooks.TensorboardHook(log_dir=f'./tb_logs/{gpc.config.LOG_NAME}', ranks=[0]), + ] + + # start training + trainer.fit( + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=gpc.config.NUM_EPOCHS, + hooks=hook_list, + display_progress=True, + test_interval=1 + ) + + +if __name__ == '__main__': + main() diff --git a/examples/simclr_cifar10_data_parallel/visualization.py b/examples/simclr_cifar10_data_parallel/visualization.py new file mode 100644 index 000000000..f449b1f7c --- /dev/null +++ b/examples/simclr_cifar10_data_parallel/visualization.py @@ -0,0 +1,72 @@ +import torch +import numpy as np +from sklearn.manifold import TSNE +import matplotlib.pyplot as plt +from models.simclr import SimCLR +from torchvision.datasets import CIFAR10 +from torch.utils.data import DataLoader +from torchvision import transforms + +log_name = 'cifar-simclr' +epoch = 800 + +fea_flag = True +tsne_flag = True +plot_flag = True + +if fea_flag: + path = f'ckpt/{log_name}/epoch{epoch}-tp0-pp0.pt' + net = SimCLR('resnet18').cuda() + print(net.load_state_dict(torch.load(path)['model'])) + + transform_eval = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) + ]) + + train_dataset = CIFAR10(root='./dataset', train=True, transform=transform_eval) + train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=4) + + test_dataset = CIFAR10(root='./dataset', train=False, transform=transform_eval) + test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4) + +def feature_extractor(model, loader): + model.eval() + all_fea = [] + all_targets = [] + for img, target in loader: + img = img.cuda() + fea = model.backbone(img) + all_fea.append(fea.detach().cpu()) + all_targets.append(target) + all_fea = torch.cat(all_fea) + all_targets = torch.cat(all_targets) + return all_fea.numpy(), all_targets.numpy() + +if tsne_flag: + train_fea, train_targets = feature_extractor(net, train_dataloader) + train_embedded = TSNE(n_components=2).fit_transform(train_fea) + test_fea, test_targets = feature_extractor(net, test_dataloader) + test_embedded = TSNE(n_components=2).fit_transform(test_fea) + np.savez('results/embedding.npz', train_embedded=train_embedded, train_targets=train_targets, test_embedded=test_embedded, test_targets=test_targets) + +if plot_flag: + npz = np.load('embedding.npz') + train_embedded = npz['train_embedded'] + train_targets = npz['train_targets'] + test_embedded = npz['test_embedded'] + test_targets = npz['test_targets'] + + plt.figure(figsize=(16,16)) + for i in range(len(np.unique(train_targets))): + plt.scatter(train_embedded[train_targets==i,0], train_embedded[train_targets==i,1], label=i) + plt.title('train') + plt.legend() + plt.savefig('results/train_tsne.png') + + plt.figure(figsize=(16,16)) + for i in range(len(np.unique(test_targets))): + plt.scatter(test_embedded[test_targets==i,0], test_embedded[test_targets==i,1], label=i) + plt.title('test') + plt.legend() + plt.savefig('results/test_tsne.png')