Set examples as submodule (#162)
* remove examples folder * added examples as submodule * update .gitmodulespull/165/head
|
@ -2,3 +2,7 @@
|
|||
path = benchmark
|
||||
url = https://github.com/hpcaitech/ColossalAI-Benchmark.git
|
||||
branch = main
|
||||
[submodule "examples"]
|
||||
path = examples
|
||||
url = https://github.com/FrankLeeeee/ColossalAI-Examples.git
|
||||
branch = main
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 217ac4600172ddbc020596587a0fe1af5e1287e8
|
|
@ -1,50 +0,0 @@
|
|||
# Train ResNet34 on CIFAR10
|
||||
|
||||
## Prepare Dataset
|
||||
|
||||
In the script, we used CIFAR10 dataset provided by the `torchvision` library. The code snippet is shown below:
|
||||
|
||||
```python
|
||||
train_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.RandomCrop(size=32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
|
||||
0.2023, 0.1994, 0.2010]),
|
||||
]
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
Firstly, you need to specify where you want to store your CIFAR10 dataset by setting the environment variable `DATA`.
|
||||
|
||||
```bash
|
||||
export DATA=/path/to/data
|
||||
|
||||
# example
|
||||
# this will store the data in the current directory
|
||||
export DATA=$PWD/data
|
||||
```
|
||||
|
||||
The `torchvison` module will download the data automatically for you into the specified directory.
|
||||
|
||||
|
||||
## Run training
|
||||
|
||||
We provide two examples of training resnet 34 on the CIFAR10 dataset. One example is with engine and the other is
|
||||
with the trainer. You can invoke the training script by the following command. This batch size and learning rate
|
||||
are for a single GPU. Thus, in the following command, `nproc_per_node` is 1, which means there is only one process
|
||||
invoked. If you change `nproc_per_node`, you will have to change the learning rate accordingly as the global batch
|
||||
size has changed.
|
||||
|
||||
```bash
|
||||
# with engine
|
||||
python -m torch.distributed.launch --nproc_per_node 1 run_resnet_cifar10_with_engine.py
|
||||
|
||||
# with trainer
|
||||
python -m torch.distributed.launch --nproc_per_node 1 run_resnet_cifar10_with_trainer.py
|
||||
```
|
|
@ -1,10 +0,0 @@
|
|||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 128
|
||||
NUM_EPOCHS = 200
|
||||
|
||||
CONFIG = dict(
|
||||
fp16=dict(
|
||||
mode=AMP_TYPE.TORCH
|
||||
)
|
||||
)
|
|
@ -1,116 +0,0 @@
|
|||
from pathlib import Path
|
||||
from colossalai.logging import get_dist_logger
|
||||
import colossalai
|
||||
import torch
|
||||
import os
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_dataloader
|
||||
from torchvision import transforms
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingLR
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.models import resnet34
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./config.py')
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
# build resnet
|
||||
model = resnet34(num_classes=10)
|
||||
|
||||
# build dataloaders
|
||||
train_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.RandomCrop(size=32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
|
||||
0.2023, 0.1994, 0.2010]),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
test_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
train=False,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
|
||||
0.2023, 0.1994, 0.2010]),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
train_dataloader = get_dataloader(dataset=train_dataset,
|
||||
shuffle=True,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
test_dataloader = get_dataloader(dataset=test_dataset,
|
||||
add_sampler=False,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# build criterion
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# optimizer
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
# lr_scheduler
|
||||
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
|
||||
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model,
|
||||
optimizer,
|
||||
criterion,
|
||||
train_dataloader,
|
||||
test_dataloader,
|
||||
)
|
||||
|
||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||
engine.train()
|
||||
if gpc.get_global_rank() == 0:
|
||||
train_dl = tqdm(train_dataloader)
|
||||
else:
|
||||
train_dl = train_dataloader
|
||||
for img, label in train_dl:
|
||||
img = img.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
engine.zero_grad()
|
||||
output = engine(img)
|
||||
train_loss = engine.criterion(output, label)
|
||||
engine.backward(train_loss)
|
||||
engine.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
engine.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
for img, label in test_dataloader:
|
||||
img = img.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
output = engine(img)
|
||||
test_loss = engine.criterion(output, label)
|
||||
pred = torch.argmax(output, dim=-1)
|
||||
correct += torch.sum(pred == label)
|
||||
total += img.size(0)
|
||||
|
||||
logger.info(
|
||||
f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,118 +0,0 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CosineAnnealingLR
|
||||
from colossalai.nn.metric import Accuracy
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.models import resnet34
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./config.py')
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
# build resnet
|
||||
model = resnet34(num_classes=10)
|
||||
|
||||
# build dataloaders
|
||||
train_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.RandomCrop(size=32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
|
||||
0.2023, 0.1994, 0.2010]),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
test_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
train=False,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
|
||||
0.2023, 0.1994, 0.2010]),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
train_dataloader = get_dataloader(dataset=train_dataset,
|
||||
shuffle=True,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
test_dataloader = get_dataloader(dataset=test_dataset,
|
||||
add_sampler=False,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# build criterion
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# optimizer
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
# lr_scheduler
|
||||
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
|
||||
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model,
|
||||
optimizer,
|
||||
criterion,
|
||||
train_dataloader,
|
||||
test_dataloader,
|
||||
)
|
||||
# build a timer to measure time
|
||||
timer = MultiTimer()
|
||||
|
||||
# create a trainer object
|
||||
trainer = Trainer(
|
||||
engine=engine,
|
||||
timer=timer,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
# define the hooks to attach to the trainer
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
|
||||
hooks.AccuracyHook(accuracy_func=Accuracy()),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LogMemoryByEpochHook(logger),
|
||||
hooks.LogTimingByEpochHook(timer, logger),
|
||||
|
||||
# you can uncomment these lines if you wish to use them
|
||||
# hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
|
||||
# hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
|
||||
]
|
||||
|
||||
# start training
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
test_dataloader=test_dataloader,
|
||||
test_interval=1,
|
||||
hooks=hook_list,
|
||||
display_progress=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,45 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.registry import LOSSES
|
||||
from torch.nn.modules.linear import Linear
|
||||
|
||||
@LOSSES.register_module
|
||||
class NT_Xentloss(nn.Module):
|
||||
def __init__(self, temperature=0.5):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
|
||||
def forward(self, z1, z2, label):
|
||||
z1 = F.normalize(z1, dim=1)
|
||||
z2 = F.normalize(z2, dim=1)
|
||||
N, Z = z1.shape
|
||||
device = z1.device
|
||||
representations = torch.cat([z1, z2], dim=0)
|
||||
similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)
|
||||
l_pos = torch.diag(similarity_matrix, N)
|
||||
r_pos = torch.diag(similarity_matrix, -N)
|
||||
positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)
|
||||
diag = torch.eye(2*N, dtype=torch.bool, device=device)
|
||||
diag[N:,:N] = diag[:N,N:] = diag[:N,:N]
|
||||
|
||||
negatives = similarity_matrix[~diag].view(2*N, -1)
|
||||
|
||||
logits = torch.cat([positives, negatives], dim=1)
|
||||
logits /= self.temperature
|
||||
|
||||
labels = torch.zeros(2*N, device=device, dtype=torch.int64)
|
||||
|
||||
loss = F.cross_entropy(logits, labels, reduction='sum')
|
||||
return loss / (2 * N)
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
criterion = NT_Xentloss()
|
||||
net = Linear(256,512)
|
||||
output = [net(torch.randn(512,256)), net(torch.randn(512,256))]
|
||||
label = [torch.randn(512)]
|
||||
loss = criterion(*output, *label)
|
||||
print(loss)
|
||||
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Overview
|
||||
|
||||
Here is an example of applying [PreAct-ResNet18](https://arxiv.org/abs/1603.05027) to train [SimCLR](https://arxiv.org/abs/2002.05709) on CIFAR10.
|
||||
SimCLR is a kind of self-supervised representation learning algorithm which learns generic representations of images on an unlabeled dataset. The generic representations are learned by simultaneously maximizing agreement between differently transformed views of the same image and minimizing agreement between transformed views of different images, following a method called contrastive learning. Updating the parameters of a neural network using this contrastive objective causes representations of corresponding views to “attract” each other, while representations of non-corresponding views “repel” each other. A more detailed description of SimCLR is available [here](https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html).
|
||||
|
||||
The training process consists of two phases: (1) self-supervised representation learning: the model which acts as a feature extractor is trained exactly as described above; and (2) linear evaluation: to evaluate how well representations are learned, generally a linear classifier is added on top of the trained feature extractor in phase 1. The linear classifier is trained with a labeled dataset in a conventional supervised manner, while parameters of the feature extractor keep fixed. This process is called linear evaluation.
|
||||
|
||||
# How to run
|
||||
The training commands are specified in:
|
||||
```shell
|
||||
bash train.sh
|
||||
```
|
||||
Before running, you can specify the experiment name (folders with the same name will be created in `ckpt` to save checkpoints and in `tb_logs` to save the tensorboard file) and other training hyperparameters in `config.py`. By default CIFAR10 dataset will be downloaded automatically and saved in `./dataset`. Note that `LOG_NAME` in `le_config.py` should be the same as that in `config.py`.
|
||||
|
||||
Besides linear evaluation, you can also visualize the distribution of learned representations. A script is provided which first extracts representations and then visualizes them with [t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html). t-SNE is a good tool to visualize high-dimensional data. It converts similarities between data points to joint probabilities and tries to minimize the Kullback-Leibler divergence between the joint probabilities of the low-dimensional embedding and the high-dimensional data. You can directly run the script by (remember modifying `log_name` and `epoch` to specify the model in which experiment folder and of which training epoch to load):
|
||||
```python
|
||||
python visualization.py
|
||||
```
|
||||
|
||||
# Results
|
||||
The loss curve of SimCLR self-supervised training is as follows:
|
||||
data:image/s3,"s3://crabby-images/cd6ae/cd6ae89566b551353b9749a6bb508138e97f34e7" alt="SimCLR Loss Curve"
|
||||
The loss curve of linear evaluation is as follows:
|
||||
data:image/s3,"s3://crabby-images/194a7/194a7111a2e207dbfff1c3f78c67f35752ab4e8c" alt="Linear Evaluation Loss Curve"
|
||||
The accuracy curve of linear evaluation is as follows:
|
||||
data:image/s3,"s3://crabby-images/38251/38251dd84b09f8d486ccb2b96662a00ddcd6bee4" alt="Linear Evaluation Accuracy"
|
||||
The t-SNE of the training set of CIFAR10 is as follows:
|
||||
data:image/s3,"s3://crabby-images/a87e0/a87e0bbf4a14f9006a2e8ee4ae820dae9f7edbc1" alt="train tSNE"
|
||||
The t-SNE of the test set of CIFAR10 is as follows:
|
||||
data:image/s3,"s3://crabby-images/0c521/0c521b4f8e2979fd0523db288ead112cccc3ca38" alt="test tSNE"
|
|
@ -1,32 +0,0 @@
|
|||
from torchvision.transforms import transforms
|
||||
|
||||
class SimCLRTransform():
|
||||
def __init__(self):
|
||||
self.transform = transforms.Compose([
|
||||
transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
transforms.RandomApply([transforms.GaussianBlur(kernel_size=32//20*2+1, sigma=(0.1, 2.0))], p=0.5),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
|
||||
])
|
||||
|
||||
def __call__(self, x):
|
||||
x1 = self.transform(x)
|
||||
x2 = self.transform(x)
|
||||
return x1, x2
|
||||
|
||||
|
||||
class LeTransform():
|
||||
def __init__(self):
|
||||
self.transform = transforms.Compose([
|
||||
transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
|
||||
])
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.transform(x)
|
||||
return x
|
|
@ -1,23 +0,0 @@
|
|||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
|
||||
LOG_NAME = 'cifar-simclr'
|
||||
|
||||
BATCH_SIZE = 512
|
||||
NUM_EPOCHS = 801
|
||||
LEARNING_RATE = 0.03*BATCH_SIZE/256
|
||||
WEIGHT_DECAY = 0.0005
|
||||
MOMENTUM = 0.9
|
||||
|
||||
|
||||
fp16 = dict(
|
||||
mode=AMP_TYPE.TORCH,
|
||||
)
|
||||
|
||||
dataset = dict(
|
||||
root='./dataset',
|
||||
)
|
||||
|
||||
gradient_accumulation=2
|
||||
clip_grad_norm=1.0
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
|
||||
LOG_NAME = 'cifar-simclr'
|
||||
EPOCH = 800
|
||||
|
||||
BATCH_SIZE = 512
|
||||
NUM_EPOCHS = 51
|
||||
LEARNING_RATE = 0.03*BATCH_SIZE/256
|
||||
WEIGHT_DECAY = 0.0005
|
||||
MOMENTUM = 0.9
|
||||
|
||||
|
||||
fp16 = dict(
|
||||
mode=AMP_TYPE.TORCH,
|
||||
)
|
||||
|
||||
dataset = dict(
|
||||
root='./dataset',
|
||||
)
|
||||
|
||||
gradient_accumulation=1
|
||||
clip_grad_norm=1.0
|
|
@ -1,178 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.models import resnet34, resnet50, resnet101, resnet152
|
||||
|
||||
|
||||
def backbone(model, **kwargs):
|
||||
assert model in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "current version only support resnet18 ~ resnet152"
|
||||
if model == 'resnet18':
|
||||
net = ResNet(PreActBlock, [2,2,2,2], **kwargs)
|
||||
else:
|
||||
net = eval(f"{model}(**kwargs)")
|
||||
net.output_dim = net.fc.in_features
|
||||
net.fc = torch.nn.Identity()
|
||||
return net
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(in_planes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class PreActBlock(nn.Module):
|
||||
'''Pre-activation version of the BasicBlock.'''
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(PreActBlock, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = conv3x3(in_planes, planes, stride)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(x))
|
||||
shortcut = self.shortcut(out)
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out += shortcut
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class PreActBottleneck(nn.Module):
|
||||
'''Pre-activation version of the original Bottleneck module.'''
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(PreActBottleneck, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(x))
|
||||
shortcut = self.shortcut(out)
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out = self.conv3(F.relu(self.bn3(out)))
|
||||
out += shortcut
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = conv3x3(3,64)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
self.fc = nn.Linear(512*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x, lin=0, lout=5):
|
||||
out = x
|
||||
if lin < 1 and lout > -1:
|
||||
out = self.conv1(out)
|
||||
out = self.bn1(out)
|
||||
out = F.relu(out)
|
||||
if lin < 2 and lout > 0:
|
||||
out = self.layer1(out)
|
||||
if lin < 3 and lout > 1:
|
||||
out = self.layer2(out)
|
||||
if lin < 4 and lout > 2:
|
||||
out = self.layer3(out)
|
||||
if lin < 5 and lout > 3:
|
||||
out = self.layer4(out)
|
||||
if lout > 4:
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
def debug():
|
||||
net = backbone('resnet18', pretrained=True)
|
||||
x = torch.randn(4,3,32,32)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
if __name__ == '__main__':
|
||||
debug()
|
|
@ -1,19 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .Backbone import backbone
|
||||
|
||||
class Linear_eval(nn.Module):
|
||||
|
||||
def __init__(self, model='resnet18', class_num=10, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.backbone = backbone(model, **kwargs)
|
||||
self.backbone.requires_grad_(False)
|
||||
self.fc = nn.Linear(self.backbone.output_dim, class_num)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
out = self.backbone(x)
|
||||
out = self.fc(out)
|
||||
return out
|
|
@ -1,36 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .Backbone import backbone
|
||||
|
||||
class projection_MLP(nn.Module):
|
||||
def __init__(self, in_dim, out_dim=256):
|
||||
super().__init__()
|
||||
hidden_dim = in_dim
|
||||
self.layer1 = nn.Sequential(
|
||||
nn.Linear(in_dim, hidden_dim),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.layer2 = nn.Linear(hidden_dim, out_dim)
|
||||
def forward(self, x):
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
class SimCLR(nn.Module):
|
||||
|
||||
def __init__(self, model='resnet18', **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.backbone = backbone(model, **kwargs)
|
||||
self.projector = projection_MLP(self.backbone.output_dim)
|
||||
self.encoder = nn.Sequential(
|
||||
self.backbone,
|
||||
self.projector
|
||||
)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
|
||||
z1 = self.encoder(x1)
|
||||
z2 = self.encoder(x2)
|
||||
return z1, z2
|
|
@ -1,15 +0,0 @@
|
|||
from colossalai.trainer.hooks import BaseHook
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
class TotalBatchsizeHook(BaseHook):
|
||||
def __init__(self, priority: int = 2) -> None:
|
||||
super().__init__(priority)
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
def before_train(self, trainer):
|
||||
total_batch_size = gpc.config.BATCH_SIZE * \
|
||||
gpc.config.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
|
||||
self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])
|
Before Width: | Height: | Size: 16 KiB |
Before Width: | Height: | Size: 26 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 405 KiB |
Before Width: | Height: | Size: 531 KiB |
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
## phase 1: self-supervised training
|
||||
python -m torch.distributed.launch --nproc_per_node 1 train_simclr.py
|
||||
|
||||
## phase 2: linear evaluation
|
||||
python -m torch.distributed.launch --nproc_per_node 1 train_linear.py
|
|
@ -1,105 +0,0 @@
|
|||
from colossalai.nn.metric import Accuracy
|
||||
import torch
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import get_dataloader, MultiTimer
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
|
||||
from torchvision.datasets import CIFAR10
|
||||
from myhooks import TotalBatchsizeHook
|
||||
from models.linear_eval import Linear_eval
|
||||
from augmentation import LeTransform
|
||||
|
||||
def build_dataset_train():
|
||||
augment = LeTransform()
|
||||
train_dataset = CIFAR10(root=gpc.config.dataset.root,
|
||||
transform=augment,
|
||||
train=True)
|
||||
|
||||
return get_dataloader(
|
||||
dataset=train_dataset,
|
||||
shuffle=True,
|
||||
num_workers = 1,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
def build_dataset_test():
|
||||
augment = LeTransform()
|
||||
val_dataset = CIFAR10(root=gpc.config.dataset.root,
|
||||
transform=augment,
|
||||
train=False)
|
||||
|
||||
return get_dataloader(
|
||||
dataset=val_dataset,
|
||||
add_sampler=False,
|
||||
num_workers = 1,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./le_config.py')
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
|
||||
## build model
|
||||
model = Linear_eval(model='resnet18', class_num=10)
|
||||
|
||||
# build dataloader
|
||||
train_dataloader = build_dataset_train()
|
||||
test_dataloader = build_dataset_test()
|
||||
|
||||
# build loss
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# build optimizer
|
||||
optimizer = colossalai.nn.FusedSGD(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY, momentum=gpc.config.MOMENTUM)
|
||||
|
||||
# lr_scheduelr
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=5, total_steps=gpc.config.NUM_EPOCHS)
|
||||
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
|
||||
model, optimizer, criterion, train_dataloader, test_dataloader
|
||||
)
|
||||
logger.info("initialized colossalai components", ranks=[0])
|
||||
|
||||
## Load trained self-supervised SimCLR model
|
||||
engine.model.load_state_dict(torch.load(f'./ckpt/{gpc.config.LOG_NAME}/epoch{gpc.config.EPOCH}-tp0-pp0.pt')['model'], strict=False)
|
||||
logger.info("pretrained model loaded", ranks=[0])
|
||||
|
||||
# build a timer to measure time
|
||||
timer = MultiTimer()
|
||||
|
||||
# build trainer
|
||||
trainer = Trainer(engine=engine, logger=logger, timer=timer)
|
||||
|
||||
# build hooks
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.AccuracyHook(accuracy_func=Accuracy()),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
|
||||
TotalBatchsizeHook(),
|
||||
|
||||
# comment if you do not need to use the hooks below
|
||||
hooks.SaveCheckpointHook(interval=5, checkpoint_dir=f'./ckpt/{gpc.config.LOG_NAME}-eval'),
|
||||
hooks.TensorboardHook(log_dir=f'./tb_logs/{gpc.config.LOG_NAME}-eval', ranks=[0]),
|
||||
]
|
||||
|
||||
# start training
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
hooks=hook_list,
|
||||
display_progress=True,
|
||||
test_interval=1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,100 +0,0 @@
|
|||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import get_dataloader, MultiTimer
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
|
||||
from torchvision.datasets import CIFAR10
|
||||
from NT_Xentloss import NT_Xentloss
|
||||
from myhooks import TotalBatchsizeHook
|
||||
from models.simclr import SimCLR
|
||||
from augmentation import SimCLRTransform
|
||||
|
||||
def build_dataset_train():
|
||||
augment = SimCLRTransform()
|
||||
train_dataset = CIFAR10(root=gpc.config.dataset.root,
|
||||
transform=augment,
|
||||
train=True,
|
||||
download=True)
|
||||
|
||||
return get_dataloader(
|
||||
dataset=train_dataset,
|
||||
shuffle=True,
|
||||
num_workers = 1,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
def build_dataset_test():
|
||||
augment = SimCLRTransform()
|
||||
val_dataset = CIFAR10(root=gpc.config.dataset.root,
|
||||
transform=augment,
|
||||
train=False)
|
||||
|
||||
return get_dataloader(
|
||||
dataset=val_dataset,
|
||||
add_sampler=False,
|
||||
num_workers = 1,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./config.py')
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
|
||||
## build model
|
||||
model = SimCLR(model='resnet18')
|
||||
|
||||
# build dataloader
|
||||
train_dataloader = build_dataset_train()
|
||||
test_dataloader = build_dataset_test()
|
||||
|
||||
# build loss
|
||||
criterion = NT_Xentloss()
|
||||
|
||||
# build optimizer
|
||||
optimizer = colossalai.nn.FusedSGD(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY, momentum=gpc.config.MOMENTUM)
|
||||
|
||||
# lr_scheduelr
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=10, total_steps=gpc.config.NUM_EPOCHS)
|
||||
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
|
||||
model, optimizer, criterion, train_dataloader, test_dataloader
|
||||
)
|
||||
logger.info("initialized colossalai components", ranks=[0])
|
||||
|
||||
# build a timer to measure time
|
||||
timer = MultiTimer()
|
||||
|
||||
# build trainer
|
||||
trainer = Trainer(engine=engine, logger=logger, timer=timer)
|
||||
|
||||
# build hooks
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
|
||||
TotalBatchsizeHook(),
|
||||
|
||||
# comment if you do not need to use the hooks below
|
||||
hooks.SaveCheckpointHook(interval=50, checkpoint_dir=f'./ckpt/{gpc.config.LOG_NAME}'),
|
||||
hooks.TensorboardHook(log_dir=f'./tb_logs/{gpc.config.LOG_NAME}', ranks=[0]),
|
||||
]
|
||||
|
||||
# start training
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
hooks=hook_list,
|
||||
display_progress=True,
|
||||
test_interval=1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,72 +0,0 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from sklearn.manifold import TSNE
|
||||
import matplotlib.pyplot as plt
|
||||
from models.simclr import SimCLR
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
log_name = 'cifar-simclr'
|
||||
epoch = 800
|
||||
|
||||
fea_flag = True
|
||||
tsne_flag = True
|
||||
plot_flag = True
|
||||
|
||||
if fea_flag:
|
||||
path = f'ckpt/{log_name}/epoch{epoch}-tp0-pp0.pt'
|
||||
net = SimCLR('resnet18').cuda()
|
||||
print(net.load_state_dict(torch.load(path)['model']))
|
||||
|
||||
transform_eval = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
|
||||
])
|
||||
|
||||
train_dataset = CIFAR10(root='./dataset', train=True, transform=transform_eval)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=4)
|
||||
|
||||
test_dataset = CIFAR10(root='./dataset', train=False, transform=transform_eval)
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)
|
||||
|
||||
def feature_extractor(model, loader):
|
||||
model.eval()
|
||||
all_fea = []
|
||||
all_targets = []
|
||||
for img, target in loader:
|
||||
img = img.cuda()
|
||||
fea = model.backbone(img)
|
||||
all_fea.append(fea.detach().cpu())
|
||||
all_targets.append(target)
|
||||
all_fea = torch.cat(all_fea)
|
||||
all_targets = torch.cat(all_targets)
|
||||
return all_fea.numpy(), all_targets.numpy()
|
||||
|
||||
if tsne_flag:
|
||||
train_fea, train_targets = feature_extractor(net, train_dataloader)
|
||||
train_embedded = TSNE(n_components=2).fit_transform(train_fea)
|
||||
test_fea, test_targets = feature_extractor(net, test_dataloader)
|
||||
test_embedded = TSNE(n_components=2).fit_transform(test_fea)
|
||||
np.savez('results/embedding.npz', train_embedded=train_embedded, train_targets=train_targets, test_embedded=test_embedded, test_targets=test_targets)
|
||||
|
||||
if plot_flag:
|
||||
npz = np.load('embedding.npz')
|
||||
train_embedded = npz['train_embedded']
|
||||
train_targets = npz['train_targets']
|
||||
test_embedded = npz['test_embedded']
|
||||
test_targets = npz['test_targets']
|
||||
|
||||
plt.figure(figsize=(16,16))
|
||||
for i in range(len(np.unique(train_targets))):
|
||||
plt.scatter(train_embedded[train_targets==i,0], train_embedded[train_targets==i,1], label=i)
|
||||
plt.title('train')
|
||||
plt.legend()
|
||||
plt.savefig('results/train_tsne.png')
|
||||
|
||||
plt.figure(figsize=(16,16))
|
||||
for i in range(len(np.unique(test_targets))):
|
||||
plt.scatter(test_embedded[test_targets==i,0], test_embedded[test_targets==i,1], label=i)
|
||||
plt.title('test')
|
||||
plt.legend()
|
||||
plt.savefig('results/test_tsne.png')
|
|
@ -1,86 +0,0 @@
|
|||
# Overview
|
||||
|
||||
A common way to speed up AI model training is to implement large-batch training with the help of data parallelism, but this requires expensive supercomputer clusters. In this example, we used a small server with only 4 GPUs to reproduce the large-scale pre-training of Vision Transformer (ViT) on ImageNet-1K in 14 hours.
|
||||
|
||||
# How to run
|
||||
|
||||
On a single server, you can directly use torch.distributed to start pre-training on multiple GPUs in parallel. In Colossal-AI, we provided several launch methods to init the distributed backend. You can use `colossalai.launch` and `colossalai.get_default_parser` to pass the parameters via command line. If you happen to use launchers such as SLURM, OpenMPI and PyTorch launch utility, you can use `colossalai.launch_from_<torch/slurm/openmpi>` to read rank and world size from the environment variables directly for convenience. In this example, we use `launch_from_slurm` for demo purpose. You can check out more information about SLURM [here](https://slurm.schedmd.com/documentation.html).
|
||||
|
||||
```shell
|
||||
HOST=<node name> srun bash ./scripts/train_slurm.sh
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
If you are using `colossalai.launch`, do this:
|
||||
In your training script:
|
||||
```python
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
colossalai.launch(config=args.config,
|
||||
rank=args.rank,
|
||||
world_size=args.world_size,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
backend=args.backend
|
||||
)
|
||||
```
|
||||
|
||||
In your terminal:
|
||||
```shell
|
||||
<some_launcher> python train.py --config ./config.py --rank <rank> --world_size <world_size> --host <node name> --port 29500
|
||||
```
|
||||
---
|
||||
If you are using `colossalai.launch_from_torch`, do this:
|
||||
In your training script:
|
||||
|
||||
```python
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
```
|
||||
|
||||
In your terminal
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node <world_size> train.py --config ./config.py --host <node name> --port 29500
|
||||
```
|
||||
|
||||
# Experiments
|
||||
To facilitate more people to reproduce the experiments with large-scale data parallel, we pre-trained ViT-Base/32 in only 14.58 hours on a small server with 4 NVIDIA A100 GPUs using ImageNet-1K dataset with batch size 32K for 300 epochs maintaining accuracy. For more complex pre-training of ViT-Base/16 and ViT-Large/32, it also takes only 78.58 hours and 37.83 hours to complete. Since the server used in this example is not a standard NVIDIA DGX A100 supercomputing unit, perhaps a better acceleration can be obtained on more professional hardware.
|
||||
|
||||
data:image/s3,"s3://crabby-images/b18cb/b18cb0a1396c894a77da763af305cf00ce5e565a" alt="Loss Curve"
|
||||
data:image/s3,"s3://crabby-images/b7d5c/b7d5ce0db3a69a7dbcc4df80260f9e574788f8ce" alt="Accuracy"
|
||||
|
||||
As can be seen from the above figure, the ViT model eventually converges well after training 300 epochs. It is worth noting that, unlike the common small-batch training convergence process, the model performance has a temporary decline in the middle of the large-batch training process. This is due to the difficulty of convergence in large-batch training. As the number of iterations is reduced, a larger learning rate is needed to ensure the final convergence. Since we did not carefully adjust the parameters, perhaps other parameter settings could get better convergence.
|
||||
|
||||
# Details
|
||||
`config.py`
|
||||
|
||||
This is a [configuration file](https://colossalai.org/config.html) that defines hyperparameters and trainign scheme (fp16, gradient accumulation, etc.). The config content can be accessed through `gpc.config` in the program.
|
||||
|
||||
In this example, we trained ViT-Base/16 for 300 epochs on the ImageNet-1K dataset. The batch size is expanded to 32K through data parallelism. Since only 4 A100 GPUs on one small server are used, and the GPU memory is limited, the batch size of 32K cannot be used directly. Therefore, the batch size used on each GPU is only 256, and the 256 batch size is equivalently expanded to 8K through gradient accumulation 32 times. Finally, data parallelism is used between 4 GPUs to achieve an equivalent batch size of 32K.
|
||||
|
||||
Since the batch size of 32K far exceeds the use range of common optimizers and is difficult to train, we use the large-batch optimizer [LAMB](https://arxiv.org/abs/1904.00962) provided by Colossal-AI to achieve a better convergence. The learning rate and weight decay of [LAMB](https://arxiv.org/abs/1904.00962) are set to 1.8e-2 and 0.1, respectively. The learning rate scheduler uses a linear warmup strategy of 150 epochs. We also used FP16 mixed precision to speed up the training process, and introduced gradient clipping to help convergence. For simplicity and speed, we only use [Mixup](https://arxiv.org/abs/1710.09412) instead of `RandAug` in data augmentation.
|
||||
|
||||
By tuning the parallelism, this example can be quickly deployed to a single server with several GPUs or to a large cluster with lots of nodes and GPUs. If there are enough computing resources to allow data parallel to be directly extended to hundreds or even thousands of GPUs, the training process of several days on a single A100 GPU can be shortened to less than half an hour.
|
||||
|
||||
`imagenet_dali_dataloader.py`
|
||||
|
||||
To accelerate the training process, we use [DALI](https://github.com/NVIDIA/DALI) to read data and require the dataset to be in TFRecord format, which avoids directly reading a large number of raw image files and being limited by the efficiency of the file system.
|
||||
|
||||
`train.py`
|
||||
|
||||
We call DALI in this file to read data and start the training process using Colossal-AI.
|
||||
|
||||
`mixup.py`
|
||||
|
||||
Since Mixup is used as data augmentation, we define the loss function of Mixup here.
|
||||
|
||||
`myhooks.py`
|
||||
We define hook functions that record running information to help debugging.
|
||||
|
||||
# How to build TFRecords dataset
|
||||
|
||||
As we use [DALI](https://github.com/NVIDIA/DALI) to read data, we use the TFRecords dataset instead of raw Imagenet dataset. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one.
|
|
@ -1,21 +0,0 @@
|
|||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
|
||||
# ViT Base
|
||||
BATCH_SIZE = 256
|
||||
DROP_RATE = 0.1
|
||||
NUM_EPOCHS = 300
|
||||
|
||||
fp16 = dict(
|
||||
mode=AMP_TYPE.TORCH,
|
||||
)
|
||||
|
||||
gradient_accumulation = 16
|
||||
clip_grad_norm = 1.0
|
||||
|
||||
dali = dict(
|
||||
# root='./dataset/ILSVRC2012_1k',
|
||||
root='/project/scratch/p200012/dataset/ILSVRC2012_1k',
|
||||
gpu_aug=True,
|
||||
mixup_alpha=0.2
|
||||
)
|
|
@ -1,110 +0,0 @@
|
|||
from nvidia.dali.pipeline import Pipeline
|
||||
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
|
||||
import nvidia.dali.fn as fn
|
||||
import nvidia.dali.types as types
|
||||
import nvidia.dali.tfrecord as tfrec
|
||||
import torch
|
||||
import numpy as np
|
||||
from .rand_augment import RandAugment
|
||||
|
||||
|
||||
class DaliDataloader(DALIClassificationIterator):
|
||||
def __init__(self,
|
||||
tfrec_filenames,
|
||||
tfrec_idx_filenames,
|
||||
shard_id=0,
|
||||
num_shards=1,
|
||||
batch_size=128,
|
||||
num_threads=4,
|
||||
resize=256,
|
||||
crop=224,
|
||||
prefetch=2,
|
||||
training=True,
|
||||
gpu_aug=False,
|
||||
cuda=True,
|
||||
mixup_alpha=0.0,
|
||||
randaug_magnitude=10,
|
||||
randaug_num_layers=0):
|
||||
self.mixup_alpha = mixup_alpha
|
||||
self.training = training
|
||||
self.randaug_magnitude = randaug_magnitude
|
||||
self.randaug_num_layers = randaug_num_layers
|
||||
pipe = Pipeline(batch_size=batch_size,
|
||||
num_threads=num_threads,
|
||||
device_id=torch.cuda.current_device() if cuda else None,
|
||||
seed=42)
|
||||
with pipe:
|
||||
inputs = fn.readers.tfrecord(
|
||||
path=tfrec_filenames,
|
||||
index_path=tfrec_idx_filenames,
|
||||
random_shuffle=training,
|
||||
shard_id=shard_id,
|
||||
num_shards=num_shards,
|
||||
initial_fill=10000,
|
||||
read_ahead=True,
|
||||
prefetch_queue_depth=prefetch,
|
||||
name='Reader',
|
||||
features={
|
||||
'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
|
||||
'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
|
||||
})
|
||||
images = inputs["image/encoded"]
|
||||
images = fn.decoders.image(images,
|
||||
device='mixed' if gpu_aug else 'cpu',
|
||||
output_type=types.RGB)
|
||||
if training:
|
||||
images = fn.random_resized_crop(images,
|
||||
size=crop,
|
||||
device='gpu' if gpu_aug else 'cpu')
|
||||
if randaug_num_layers == 0:
|
||||
flip_lr = fn.random.coin_flip(probability=0.5)
|
||||
images = fn.flip(images, horizontal=flip_lr)
|
||||
else:
|
||||
images = fn.resize(images,
|
||||
device='gpu' if gpu_aug else 'cpu',
|
||||
resize_x=resize,
|
||||
resize_y=resize,
|
||||
dtype=types.FLOAT,
|
||||
interp_type=types.INTERP_TRIANGULAR)
|
||||
images = fn.crop(images,
|
||||
dtype=types.FLOAT,
|
||||
crop=(crop, crop))
|
||||
label = inputs["image/class/label"] - 1 # 0-999
|
||||
if cuda: # transfer data to gpu
|
||||
pipe.set_outputs(images.gpu(), label.gpu())
|
||||
else:
|
||||
pipe.set_outputs(images, label)
|
||||
|
||||
pipe.build()
|
||||
last_batch_policy = 'DROP' if training else 'PARTIAL'
|
||||
super().__init__(pipe, reader_name="Reader",
|
||||
auto_reset=True,
|
||||
last_batch_policy=last_batch_policy)
|
||||
|
||||
def __iter__(self):
|
||||
# if not reset (after an epoch), reset; if just initialize, ignore
|
||||
if self._counter >= self._size or self._size < 0:
|
||||
self.reset()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
data = super().__next__()
|
||||
img, label = data[0]['data'], data[0]['label']
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
if self.randaug_num_layers > 0 and self.training:
|
||||
img = RandAugment(img, num_layers=self.randaug_num_layers, magnitude=self.randaug_magnitude)
|
||||
img = (img - 127.5) / 127.5
|
||||
label = label.squeeze()
|
||||
if self.mixup_alpha > 0.0:
|
||||
if self.training:
|
||||
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
||||
idx = torch.randperm(img.size(0)).to(img.device)
|
||||
img = lam * img + (1 - lam) * img[idx, :]
|
||||
label_a, label_b = label, label[idx]
|
||||
lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
|
||||
label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
|
||||
else:
|
||||
label = {'targets_a': label, 'targets_b': label, 'lam': torch.ones(
|
||||
1, device=img.device, dtype=img.dtype)}
|
||||
return img, label
|
||||
return img, label
|
|
@ -1,209 +0,0 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
_MAX_LEVEL = 10
|
||||
|
||||
_HPARAMS = {
|
||||
'cutout_const': 40,
|
||||
'translate_const': 40,
|
||||
}
|
||||
|
||||
_FILL = tuple([128, 128, 128])
|
||||
# RGB
|
||||
|
||||
|
||||
def blend(image0, image1, factor):
|
||||
# blend image0 with image1
|
||||
# we only use this function in the 'color' function
|
||||
if factor == 0.0:
|
||||
return image0
|
||||
if factor == 1.0:
|
||||
return image1
|
||||
image0 = image0.type(torch.float32)
|
||||
image1 = image1.type(torch.float32)
|
||||
scaled = (image1 - image0) * factor
|
||||
image = image0 + scaled
|
||||
|
||||
if factor > 0.0 and factor < 1.0:
|
||||
return image.type(torch.uint8)
|
||||
|
||||
image = torch.clamp(image, 0, 255).type(torch.uint8)
|
||||
return image
|
||||
|
||||
|
||||
def autocontrast(image):
|
||||
image = TF.autocontrast(image)
|
||||
return image
|
||||
|
||||
|
||||
def equalize(image):
|
||||
image = TF.equalize(image)
|
||||
return image
|
||||
|
||||
|
||||
def rotate(image, degree, fill=_FILL):
|
||||
image = TF.rotate(image, angle=degree, fill=fill)
|
||||
return image
|
||||
|
||||
|
||||
def posterize(image, bits):
|
||||
image = TF.posterize(image, bits)
|
||||
return image
|
||||
|
||||
|
||||
def sharpness(image, factor):
|
||||
image = TF.adjust_sharpness(image, sharpness_factor=factor)
|
||||
return image
|
||||
|
||||
|
||||
def contrast(image, factor):
|
||||
image = TF.adjust_contrast(image, factor)
|
||||
return image
|
||||
|
||||
|
||||
def brightness(image, factor):
|
||||
image = TF.adjust_brightness(image, factor)
|
||||
return image
|
||||
|
||||
|
||||
def invert(image):
|
||||
return 255-image
|
||||
|
||||
|
||||
def solarize(image, threshold=128):
|
||||
return torch.where(image < threshold, image, 255-image)
|
||||
|
||||
|
||||
def solarize_add(image, addition=0, threshold=128):
|
||||
add_image = image.long() + addition
|
||||
add_image = torch.clamp(add_image, 0, 255).type(torch.uint8)
|
||||
return torch.where(image < threshold, add_image, image)
|
||||
|
||||
|
||||
def color(image, factor):
|
||||
new_image = TF.rgb_to_grayscale(image, num_output_channels=3)
|
||||
return blend(new_image, image, factor=factor)
|
||||
|
||||
|
||||
def shear_x(image, level, fill=_FILL):
|
||||
image = TF.affine(image, 0, [0, 0], 1.0, [level, 0], fill=fill)
|
||||
return image
|
||||
|
||||
|
||||
def shear_y(image, level, fill=_FILL):
|
||||
image = TF.affine(image, 0, [0, 0], 1.0, [0, level], fill=fill)
|
||||
return image
|
||||
|
||||
|
||||
def translate_x(image, level, fill=_FILL):
|
||||
image = TF.affine(image, 0, [level, 0], 1.0, [0, 0], fill=fill)
|
||||
return image
|
||||
|
||||
|
||||
def translate_y(image, level, fill=_FILL):
|
||||
image = TF.affine(image, 0, [0, level], 1.0, [0, 0], fill=fill)
|
||||
return image
|
||||
|
||||
|
||||
def cutout(image, pad_size, fill=_FILL):
|
||||
b, c, h, w = image.shape
|
||||
mask = torch.ones((b, c, h, w), dtype=torch.uint8).cuda()
|
||||
y = np.random.randint(pad_size, h-pad_size)
|
||||
x = np.random.randint(pad_size, w-pad_size)
|
||||
for i in range(c):
|
||||
mask[:, i, (y-pad_size): (y+pad_size), (x-pad_size): (x+pad_size)] = fill[i]
|
||||
image = torch.where(mask == 1, image, mask)
|
||||
return image
|
||||
|
||||
|
||||
def _randomly_negate_tensor(level):
|
||||
# With 50% prob turn the tensor negative.
|
||||
flip = np.random.randint(0, 2)
|
||||
final_level = -level if flip else level
|
||||
return final_level
|
||||
|
||||
|
||||
def _rotate_level_to_arg(level):
|
||||
level = (level/_MAX_LEVEL) * 30.
|
||||
level = _randomly_negate_tensor(level)
|
||||
return level
|
||||
|
||||
|
||||
def _shear_level_to_arg(level):
|
||||
level = (level/_MAX_LEVEL) * 0.3
|
||||
# Flip level to negative with 50% chance.
|
||||
level = _randomly_negate_tensor(level)
|
||||
return level
|
||||
|
||||
|
||||
def _translate_level_to_arg(level, translate_const):
|
||||
level = (level/_MAX_LEVEL) * float(translate_const)
|
||||
# Flip level to negative with 50% chance.
|
||||
level = _randomly_negate_tensor(level)
|
||||
return level
|
||||
|
||||
|
||||
def level(hparams):
|
||||
return {
|
||||
'AutoContrast': lambda level: None,
|
||||
'Equalize': lambda level: None,
|
||||
'Invert': lambda level: None,
|
||||
'Rotate': _rotate_level_to_arg,
|
||||
'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4)),
|
||||
'Solarize': lambda level: (int((level/_MAX_LEVEL) * 200)),
|
||||
'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110)),
|
||||
'Color': lambda level: ((level/_MAX_LEVEL) * 1.8 + 0.1),
|
||||
'Contrast': lambda level: ((level/_MAX_LEVEL) * 1.8 + 0.1),
|
||||
'Brightness': lambda level: ((level/_MAX_LEVEL) * 1.8 + 0.1),
|
||||
'Sharpness': lambda level: ((level/_MAX_LEVEL) * 1.8 + 0.1),
|
||||
'ShearX': _shear_level_to_arg,
|
||||
'ShearY': _shear_level_to_arg,
|
||||
'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams['cutout_const'])),
|
||||
'TranslateX': lambda level: _translate_level_to_arg(level, hparams['translate_const']),
|
||||
'TranslateY': lambda level: _translate_level_to_arg(level, hparams['translate_const']),
|
||||
}
|
||||
|
||||
|
||||
AUGMENTS = {
|
||||
'AutoContrast': autocontrast,
|
||||
'Equalize': equalize,
|
||||
'Invert': invert,
|
||||
'Rotate': rotate,
|
||||
'Posterize': posterize,
|
||||
'Solarize': solarize,
|
||||
'SolarizeAdd': solarize_add,
|
||||
'Color': color,
|
||||
'Contrast': contrast,
|
||||
'Brightness': brightness,
|
||||
'Sharpness': sharpness,
|
||||
'ShearX': shear_x,
|
||||
'ShearY': shear_y,
|
||||
'TranslateX': translate_x,
|
||||
'TranslateY': translate_y,
|
||||
'Cutout': cutout,
|
||||
}
|
||||
|
||||
|
||||
def RandAugment(image, num_layers=2, magnitude=_MAX_LEVEL, augments=AUGMENTS):
|
||||
"""Random Augment for images, followed google randaug and the paper(https://arxiv.org/abs/2106.10270)
|
||||
:param image: the input image, in tensor format with shape of C, H, W
|
||||
:type image: uint8 Tensor
|
||||
:num_layers: how many layers will the randaug do, default=2
|
||||
:type num_layers: int
|
||||
:param magnitude: the magnitude of random augment, default=10
|
||||
:type magnitude: int
|
||||
"""
|
||||
if np.random.random() < 0.5:
|
||||
return image
|
||||
Choice_Augment = np.random.choice(a=list(augments.keys()),
|
||||
size=num_layers,
|
||||
replace=False)
|
||||
magnitude = float(magnitude)
|
||||
for i in range(num_layers):
|
||||
arg = level(_HPARAMS)[Choice_Augment[i]](magnitude)
|
||||
if arg is None:
|
||||
image = augments[Choice_Augment[i]](image)
|
||||
else:
|
||||
image = augments[Choice_Augment[i]](image, arg)
|
||||
return image
|
|
@ -1,21 +0,0 @@
|
|||
import torch.nn as nn
|
||||
from colossalai.registry import LOSSES
|
||||
import torch
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class MixupLoss(nn.Module):
|
||||
def __init__(self, loss_fn_cls):
|
||||
super().__init__()
|
||||
self.loss_fn = loss_fn_cls()
|
||||
|
||||
def forward(self, inputs, targets_a, targets_b, lam):
|
||||
return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
|
||||
|
||||
|
||||
class MixupAccuracy(nn.Module):
|
||||
def forward(self, logits, targets):
|
||||
targets = targets['targets_a']
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(targets == preds)
|
||||
return correct
|
|
@ -1,15 +0,0 @@
|
|||
from colossalai.trainer.hooks import BaseHook
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
class TotalBatchsizeHook(BaseHook):
|
||||
def __init__(self, priority: int = 2) -> None:
|
||||
super().__init__(priority)
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
def before_train(self, trainer):
|
||||
total_batch_size = gpc.config.BATCH_SIZE * \
|
||||
gpc.config.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
|
||||
self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])
|
Before Width: | Height: | Size: 19 KiB |
Before Width: | Height: | Size: 22 KiB |
|
@ -1,3 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
python train.py --host $HOST --config ./config.py --port 29500
|
|
@ -1,121 +0,0 @@
|
|||
import glob
|
||||
from math import log
|
||||
import os
|
||||
import colossalai
|
||||
from colossalai.nn.metric import Accuracy
|
||||
import torch
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.nn.lr_scheduler import LinearWarmupLR
|
||||
from dataloader.imagenet_dali_dataloader import DaliDataloader
|
||||
from mixup import MixupLoss, MixupAccuracy
|
||||
from timm.models import vit_base_patch16_224
|
||||
from myhooks import TotalBatchsizeHook
|
||||
|
||||
|
||||
def build_dali_train():
|
||||
root = gpc.config.dali.root
|
||||
train_pat = os.path.join(root, 'train/*')
|
||||
train_idx_pat = os.path.join(root, 'idx_files/train/*')
|
||||
return DaliDataloader(
|
||||
sorted(glob.glob(train_pat)),
|
||||
sorted(glob.glob(train_idx_pat)),
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
shard_id=gpc.get_local_rank(ParallelMode.DATA),
|
||||
num_shards=gpc.get_world_size(ParallelMode.DATA),
|
||||
gpu_aug=gpc.config.dali.gpu_aug,
|
||||
cuda=True,
|
||||
mixup_alpha=gpc.config.dali.mixup_alpha,
|
||||
randaug_num_layers=2
|
||||
)
|
||||
|
||||
|
||||
def build_dali_test():
|
||||
root = gpc.config.dali.root
|
||||
val_pat = os.path.join(root, 'validation/*')
|
||||
val_idx_pat = os.path.join(root, 'idx_files/validation/*')
|
||||
return DaliDataloader(
|
||||
sorted(glob.glob(val_pat)),
|
||||
sorted(glob.glob(val_idx_pat)),
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
shard_id=gpc.get_local_rank(ParallelMode.DATA),
|
||||
num_shards=gpc.get_world_size(ParallelMode.DATA),
|
||||
training=False,
|
||||
# gpu_aug=gpc.config.dali.gpu_aug,
|
||||
gpu_aug=False,
|
||||
cuda=True,
|
||||
mixup_alpha=gpc.config.dali.mixup_alpha
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from slurm batch job
|
||||
colossalai.launch_from_slurm(config=args.config,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
backend=args.backend
|
||||
)
|
||||
# launch from torch
|
||||
# colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
logger.info("initialized distributed environment", ranks=[0])
|
||||
|
||||
# build model
|
||||
model = vit_base_patch16_224(drop_rate=0.1)
|
||||
|
||||
# build dataloader
|
||||
train_dataloader = build_dali_train()
|
||||
test_dataloader = build_dali_test()
|
||||
|
||||
# build optimizer
|
||||
optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1)
|
||||
|
||||
# build loss
|
||||
criterion = MixupLoss(loss_fn_cls=torch.nn.CrossEntropyLoss)
|
||||
|
||||
# lr_scheduelr
|
||||
lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS)
|
||||
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
|
||||
model, optimizer, criterion, train_dataloader, test_dataloader
|
||||
)
|
||||
logger.info("initialized colossalai components", ranks=[0])
|
||||
|
||||
# build trainer
|
||||
trainer = Trainer(engine=engine, logger=logger)
|
||||
|
||||
# build hooks
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
|
||||
TotalBatchsizeHook(),
|
||||
|
||||
# comment if you do not need to use the hooks below
|
||||
hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'),
|
||||
hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
|
||||
]
|
||||
|
||||
# start training
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
hooks=hook_list,
|
||||
display_progress=True,
|
||||
test_interval=1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,30 +0,0 @@
|
|||
# Overview
|
||||
|
||||
MoE is a new technique to enlarge neural networks while keeping the same throughput in our training.
|
||||
It is designed to improve the performance of our models without any additional time penalty. But now using
|
||||
our temporary moe parallelism will cause a moderate computation overhead and additoinal communication time.
|
||||
The communication time depends on the topology of network in running environment. At present, moe parallelism
|
||||
may not meet what you want. Optimized version of moe parallelism will come soon.
|
||||
|
||||
This is a simple example about how to run widenet-tiny on cifar10. More information about widenet can be
|
||||
found [here](https://arxiv.org/abs/2107.11817).
|
||||
|
||||
# How to run
|
||||
|
||||
On a single server, you can directly use torchrun to start pre-training on multiple GPUs in parallel.
|
||||
If you use the script here to train, just use follow instruction in your terminal. `n_proc` is the
|
||||
number of processes which commonly equals to the number GPUs.
|
||||
|
||||
```shell
|
||||
torchrun --nnodes=1 --nproc_per_node=4 train.py \
|
||||
--config ./config.py
|
||||
```
|
||||
|
||||
If you want to use multi servers, please check our document about environment initialization.
|
||||
|
||||
Make sure to initialize moe running environment by `moe_set_seed` before building the model.
|
||||
|
||||
# Result
|
||||
|
||||
The result of training widenet-tiny on cifar10 from scratch is 89.93%. Since moe makes the model larger
|
||||
than other vit-tiny models, mixup and rand augmentation is needed.
|
|
@ -1,15 +0,0 @@
|
|||
BATCH_SIZE = 512
|
||||
LEARNING_RATE = 2e-3
|
||||
WEIGHT_DECAY = 3e-2
|
||||
|
||||
NUM_EPOCHS = 200
|
||||
WARMUP_EPOCHS = 40
|
||||
|
||||
WORLD_SIZE = 4
|
||||
MOE_MODEL_PARALLEL_SIZE = 4
|
||||
|
||||
parallel = dict(
|
||||
moe=dict(size=MOE_MODEL_PARALLEL_SIZE)
|
||||
)
|
||||
|
||||
LOG_PATH = f"./cifar10_moe"
|
|
@ -1,119 +0,0 @@
|
|||
import os
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import Accuracy
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.trainer import Trainer
|
||||
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
|
||||
LogMetricByEpochHook,
|
||||
LogMetricByStepHook,
|
||||
LogTimingByEpochHook, LossHook,
|
||||
LRSchedulerHook, ThroughputHook)
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from colossalai.nn.loss import MoeCrossEntropyLoss
|
||||
from model_zoo.moe.models import Widenet
|
||||
from colossalai.context.random import moe_set_seed
|
||||
|
||||
DATASET_PATH = str(os.environ['DATA']) # The directory of your dataset
|
||||
|
||||
|
||||
def build_cifar(batch_size):
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
transform_test = transforms.Compose([
|
||||
transforms.Resize(32),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
|
||||
train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH,
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transform_train)
|
||||
test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transform_test)
|
||||
train_dataloader = get_dataloader(dataset=train_dataset,
|
||||
shuffle=True,
|
||||
batch_size=batch_size,
|
||||
num_workers=4,
|
||||
pin_memory=True)
|
||||
test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
|
||||
return train_dataloader, test_dataloader
|
||||
|
||||
|
||||
def train_cifar():
|
||||
args = colossalai.get_default_parser().parse_args()
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
logger = get_dist_logger()
|
||||
if hasattr(gpc.config, 'LOG_PATH'):
|
||||
if gpc.get_global_rank() == 0:
|
||||
log_path = gpc.config.LOG_PATH
|
||||
if not os.path.exists(log_path):
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
moe_set_seed(42)
|
||||
model = Widenet(
|
||||
num_experts=4,
|
||||
capacity_factor=1.2,
|
||||
img_size=32,
|
||||
patch_size=4,
|
||||
num_classes=10,
|
||||
depth=6,
|
||||
d_model=512,
|
||||
num_heads=2,
|
||||
d_kv=128,
|
||||
d_ff=2048
|
||||
)
|
||||
|
||||
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
criterion = MoeCrossEntropyLoss(aux_weight=0.01, label_smoothing=0.1)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE,
|
||||
weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||
total_steps=gpc.config.NUM_EPOCHS,
|
||||
warmup_steps=gpc.config.WARMUP_EPOCHS)
|
||||
|
||||
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
|
||||
logger.info("Engine is built", ranks=[0])
|
||||
|
||||
timer = MultiTimer()
|
||||
trainer = Trainer(engine=engine, logger=logger, timer=timer)
|
||||
logger.info("Trainer is built", ranks=[0])
|
||||
|
||||
hooks = [
|
||||
LogMetricByEpochHook(logger=logger),
|
||||
LogMetricByStepHook(),
|
||||
AccuracyHook(accuracy_func=Accuracy()),
|
||||
LossHook(),
|
||||
ThroughputHook(),
|
||||
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
|
||||
]
|
||||
|
||||
logger.info("Train start", ranks=[0])
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
hooks=hooks,
|
||||
display_progress=True,
|
||||
test_interval=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_cifar()
|