Added moe parallel example (#140)
# Overview

MoE (Mixture of Experts) is a technique for enlarging neural networks while keeping training throughput roughly the same. It is designed to improve model performance without an additional time penalty. For now, however, our temporary MoE parallelism incurs a moderate computation overhead and additional communication time. The communication time depends on the network topology of the running environment, so at present MoE parallelism may not meet your expectations. An optimized version of MoE parallelism will come soon.
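
For intuition, here is a minimal sketch of the MoE idea with top-1 routing. This is an illustration only, not ColossalAI's implementation; the `NaiveMoeLayer` class and its parameters are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveMoeLayer(nn.Module):
    """Illustrative top-1 routed MoE layer (not ColossalAI's implementation)."""

    def __init__(self, d_model: int, d_ff: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts)  # router producing expert scores
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_tokens, d_model); each token is dispatched to exactly one expert,
        # so FLOPs per token stay constant no matter how many experts are added.
        scores = F.softmax(self.gate(x), dim=-1)
        top1 = scores.argmax(dim=-1)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = top1 == i
            if mask.any():
                # weight each expert's output by its gate score
                out[mask] = expert(x[mask]) * scores[mask, i].unsqueeze(-1)
        return out
```

Adding experts grows the parameter count, but each token still passes through only one expert, which is why throughput can stay roughly constant.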

This is a simple example of how to run widenet-tiny on cifar10. More information about widenet can be found [here](https://arxiv.org/abs/2107.11817).

# How to run

On a single server, you can use torchrun directly to start pre-training on multiple GPUs in parallel. If you use the script here to train, just follow the instruction below in your terminal, where `nproc_per_node` is the number of processes, which usually equals the number of GPUs.

```shell
torchrun --nnodes=1 --nproc_per_node=4 train.py \
    --config ./config.py
```

If you want to use multiple servers, please check our document about environment initialization.

Make sure to initialize the MoE running environment with `moe_set_seed` before building the model.
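
The order matters: set the seed first, then construct the model, as `train.py` below does:

```python
from colossalai.context.random import moe_set_seed
from model_zoo.moe.models import Widenet

moe_set_seed(42)  # prepare MoE parallel seeds before any parameters are created
model = Widenet(num_experts=4, capacity_factor=1.2,
                img_size=32, patch_size=4, num_classes=10,
                depth=6, d_model=512, num_heads=2, d_kv=128, d_ff=2048)
```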

# Result

The test accuracy of widenet-tiny trained on cifar10 from scratch is 89.93%. Since MoE makes the model larger than other vit-tiny models, mixup and rand augmentation are needed.

The configuration used by this example (`config.py`):

```python
BATCH_SIZE = 512              # global batch size, split across data-parallel ranks
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

WORLD_SIZE = 4
MOE_MODEL_PARALLEL_SIZE = 4   # size of the MoE model-parallel group

parallel = dict(
    moe=dict(size=MOE_MODEL_PARALLEL_SIZE)
)

LOG_PATH = "./cifar10_moe"
```
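
After launch, these fields are read through the global context, as `train.py` does with `gpc.config.BATCH_SIZE`. A minimal sketch (run under torchrun, as above):

```python
import colossalai
from colossalai.core import global_context as gpc

# launch_from_torch loads the config file and sets up the distributed environment
colossalai.launch_from_torch(config='./config.py')
print(gpc.config.BATCH_SIZE, gpc.config.NUM_EPOCHS)  # 512 200
```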

The full training script (`train.py`):

```python
import os

import colossalai
import torch
import torchvision
from torchvision import transforms
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
                                      LogMetricByEpochHook,
                                      LogMetricByStepHook,
                                      LogTimingByEpochHook, LossHook,
                                      LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.nn.loss import MoeCrossEntropyLoss
from model_zoo.moe.models import Widenet
from colossalai.context.random import moe_set_seed

DATASET_PATH = str(os.environ['DATA'])  # The directory of your dataset


def build_cifar(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH,
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transform_test)
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=batch_size,
                                      num_workers=4,
                                      pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
    return train_dataloader, test_dataloader


def train_cifar():
    args = colossalai.get_default_parser().parse_args()
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    # Initialize the MoE running environment before building the model.
    moe_set_seed(42)
    model = Widenet(
        num_experts=4,
        capacity_factor=1.2,
        img_size=32,
        patch_size=4,
        num_classes=10,
        depth=6,
        d_model=512,
        num_heads=2,
        d_kv=128,
        d_ff=2048
    )

    # Each data-parallel rank gets an equal share of the global batch.
    train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
    criterion = MoeCrossEntropyLoss(aux_weight=0.01, label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE,
                                  weight_decay=gpc.config.WEIGHT_DECAY)

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS)

    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
                                                                                    optimizer=optimizer,
                                                                                    criterion=criterion,
                                                                                    train_dataloader=train_dataloader,
                                                                                    test_dataloader=test_dataloader,
                                                                                    lr_scheduler=lr_scheduler)

    logger.info("Engine is built", ranks=[0])

    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("Trainer is built", ranks=[0])

    hooks = [
        LogMetricByEpochHook(logger=logger),
        LogMetricByStepHook(),
        AccuracyHook(accuracy_func=Accuracy()),
        LossHook(),
        ThroughputHook(),
        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
    ]

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                hooks=hooks,
                display_progress=True,
                test_interval=1)


if __name__ == '__main__':
    train_cifar()
```