Added moe parallel example (#140)

pull/152/head
HELSON 2022-01-17 15:34:04 +08:00 committed by GitHub
parent f68eddfb3d
commit 1ff5be36c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 168 additions and 3 deletions

View File

@ -0,0 +1,30 @@
# Overview
MoE is a new technique to enlarge neural networks while keeping the same throughput in our training.
It is designed to improve the performance of our models without any additional time penalty. But now using
our temporary moe parallelism will cause a moderate computation overhead and additoinal communication time.
The communication time depends on the topology of network in running environment. At present, moe parallelism
may not meet what you want. Optimized version of moe parallelism will come soon.
This is a simple example about how to run widenet-tiny on cifar10. More information about widenet can be
found [here](https://arxiv.org/abs/2107.11817).
# How to run
On a single server, you can directly use torchrun to start pre-training on multiple GPUs in parallel.
If you use the script here to train, just use follow instruction in your terminal. `n_proc` is the
number of processes which commonly equals to the number GPUs.
```shell
torchrun --nnodes=1 --nproc_per_node=4 train.py \
--config ./config.py
```
If you want to use multi servers, please check our document about environment initialization.
Make sure to initialize moe running environment by `moe_set_seed` before building the model.
# Result
The result of training widenet-tiny on cifar10 from scratch is 89.93%. Since moe makes the model larger
than other vit-tiny models, mixup and rand augmentation is needed.

View File

@ -0,0 +1,15 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
NUM_EPOCHS = 200
WARMUP_EPOCHS = 40
WORLD_SIZE = 4
MOE_MODEL_PARALLEL_SIZE = 4
parallel = dict(
moe=dict(size=MOE_MODEL_PARALLEL_SIZE)
)
LOG_PATH = f"./cifar10_moe"

View File

@ -0,0 +1,119 @@
import os
import colossalai
import torch
import torchvision
from torchvision import transforms
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
LogMetricByEpochHook,
LogMetricByStepHook,
LogTimingByEpochHook, LossHook,
LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.nn.loss import MoeCrossEntropyLoss
from model_zoo.moe.models import Widenet
from colossalai.context.random import moe_set_seed
DATASET_PATH = str(os.environ['DATA']) # The directory of your dataset
def build_cifar(batch_size):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH,
train=True,
download=True,
transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transform_test)
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
batch_size=batch_size,
num_workers=4,
pin_memory=True)
test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
return train_dataloader, test_dataloader
def train_cifar():
args = colossalai.get_default_parser().parse_args()
colossalai.launch_from_torch(config=args.config)
logger = get_dist_logger()
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
moe_set_seed(42)
model = Widenet(
num_experts=4,
capacity_factor=1.2,
img_size=32,
patch_size=4,
num_classes=10,
depth=6,
d_model=512,
num_heads=2,
d_kv=128,
d_ff=2048
)
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
criterion = MoeCrossEntropyLoss(aux_weight=0.01, label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE,
weight_decay=gpc.config.WEIGHT_DECAY)
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
lr_scheduler=lr_scheduler)
logger.info("Engine is built", ranks=[0])
timer = MultiTimer()
trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0])
hooks = [
LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(),
AccuracyHook(accuracy_func=Accuracy()),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
]
logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS,
hooks=hooks,
display_progress=True,
test_interval=1)
if __name__ == '__main__':
train_cifar()

View File

@ -14,6 +14,7 @@ from colossalai.utils import get_current_device
class VanillaSelfAttention(nn.Module):
"""Standard ViT self attention.
"""
def __init__(self,
d_model: int,
n_heads: int,
@ -57,6 +58,7 @@ class VanillaSelfAttention(nn.Module):
class VanillaFFN(nn.Module):
"""FFN composed with two linear layers, also called MLP.
"""
def __init__(self,
d_model: int,
d_ff: int,
@ -72,8 +74,7 @@ class VanillaFFN(nn.Module):
drop1 = nn.Dropout(drop_rate) if dropout1 is None else dropout1
drop2 = nn.Dropout(drop_rate) if dropout2 is None else dropout2
self.ffn = nn.Sequential(
dense1, act, drop1,dense2, drop2)
self.ffn = nn.Sequential(dense1, act, drop1, dense2, drop2)
def forward(self, x):
return self.ffn(x)
@ -91,7 +92,7 @@ class Widenet(nn.Module):
d_model: int = 768,
num_heads: int = 12,
d_kv: int = 64,
d_ff: int = 3072,
d_ff: int = 4096,
attention_drop: float = 0.,
drop_rate: float = 0.1,
drop_path: float = 0.):