mirror of https://github.com/hpcaitech/ColossalAI
[example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial * polish code * polish code
pull/3058/head
parent
2bfeb24308
commit
ac18a445fa
|
@ -1,31 +1,35 @@
|
||||||
# Comparison of Large Batch Training Optimization
|
# Comparison of Large Batch Training Optimization
|
||||||
|
|
||||||
|
## Table of contents
|
||||||
|
|
||||||
|
- [Overview](#-overview)
|
||||||
|
- [Quick Start](#-quick-start)
|
||||||
|
|
||||||
|
## 📚 Overview
|
||||||
|
|
||||||
|
This example lets you quickly try out the large batch training optimization provided by Colossal-AI. We use a synthetic dataset to go through the process, so you don't need to prepare any dataset. You can try out the `Lamb` and `Lars` optimizers from Colossal-AI with the following code.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from colossalai.nn.optimizer import Lamb, Lars
|
||||||
|
```
|
||||||
|
|
||||||
## 🚀 Quick Start
|
## 🚀 Quick Start
|
||||||
Run with synthetic data
|
|
||||||
```bash
|
|
||||||
colossalai run --nproc_per_node 4 train.py --config config.py -s
|
|
||||||
```
|
|
||||||
|
|
||||||
|
1. Install PyTorch
|
||||||
|
|
||||||
## Prepare Dataset
|
2. Install the dependencies.
|
||||||
|
|
||||||
We use the CIFAR10 dataset in this example. You should invoke `download_cifar10.py` in the tutorial root directory or directly run `auto_parallel_with_resnet.py`.
|
|
||||||
The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
|
|
||||||
If you wish to use a customized directory for the dataset, you can set the environment variable `DATA` via the following command.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export DATA=/path/to/data
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
If you don't wish to download the `CIFAR10` dataset, you can also use synthetic data for this tutorial by adding the `-s` or `--synthetic` flag to the command.
|
3. Run the training scripts with synthetic data.
|
||||||
|
|
||||||
|
|
||||||
## Run on 2*2 device mesh
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# run with cifar10
|
# run on 4 GPUs
|
||||||
colossalai run --nproc_per_node 4 train.py --config config.py
|
# run with lars
|
||||||
|
colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lars
|
||||||
|
|
||||||
# run with synthetic dataset
|
# run with lamb
|
||||||
colossalai run --nproc_per_node 4 train.py --config config.py -s
|
colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lamb
|
||||||
```
|
```
|
||||||
|
|
|
@ -6,31 +6,11 @@ from colossalai.amp import AMP_TYPE
|
||||||
BATCH_SIZE = 512
|
BATCH_SIZE = 512
|
||||||
LEARNING_RATE = 3e-3
|
LEARNING_RATE = 3e-3
|
||||||
WEIGHT_DECAY = 0.3
|
WEIGHT_DECAY = 0.3
|
||||||
NUM_EPOCHS = 10
|
NUM_EPOCHS = 2
|
||||||
WARMUP_EPOCHS = 3
|
WARMUP_EPOCHS = 1
|
||||||
|
|
||||||
# model config
|
# model config
|
||||||
IMG_SIZE = 224
|
NUM_CLASSES = 10
|
||||||
PATCH_SIZE = 16
|
|
||||||
HIDDEN_SIZE = 512
|
|
||||||
DEPTH = 4
|
|
||||||
NUM_HEADS = 4
|
|
||||||
MLP_RATIO = 2
|
|
||||||
NUM_CLASSES = 1000
|
|
||||||
CHECKPOINT = False
|
|
||||||
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
|
|
||||||
|
|
||||||
# parallel setting
|
|
||||||
TENSOR_PARALLEL_SIZE = 2
|
|
||||||
TENSOR_PARALLEL_MODE = '1d'
|
|
||||||
|
|
||||||
parallel = dict(
|
|
||||||
pipeline=2,
|
|
||||||
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
|
|
||||||
)
|
|
||||||
|
|
||||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||||
clip_grad_norm = 1.0
|
clip_grad_norm = 1.0
|
||||||
|
|
||||||
# pipeline config
|
|
||||||
NUM_MICRO_BATCHES = parallel['pipeline']
|
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
colossalai >= 0.1.12
|
colossalai
|
||||||
torch >= 1.8.1
|
torch
|
||||||
|
titans
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
#!/bin/bash
|
||||||
|
set -euxo pipefail
|
||||||
|
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# run test
|
||||||
|
colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars
|
||||||
|
colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb
|
|
@ -1,19 +1,13 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from titans.dataloader.cifar10 import build_cifar
|
import torch.nn as nn
|
||||||
from titans.model.vit.vit import _create_vit_model
|
from torchvision.models import resnet18
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.context import ParallelMode
|
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn import CrossEntropyLoss
|
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
from colossalai.nn.optimizer import Lamb, Lars
|
from colossalai.nn.optimizer import Lamb, Lars
|
||||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
|
||||||
from colossalai.utils import get_dataloader, is_using_pp
|
|
||||||
|
|
||||||
|
|
||||||
class DummyDataloader():
|
class DummyDataloader():
|
||||||
|
@ -45,7 +39,10 @@ class DummyDataloader():
|
||||||
def main():
|
def main():
|
||||||
# initialize distributed setting
|
# initialize distributed setting
|
||||||
parser = colossalai.get_default_parser()
|
parser = colossalai.get_default_parser()
|
||||||
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
|
parser.add_argument('--optimizer',
|
||||||
|
choices=['lars', 'lamb'],
|
||||||
|
help="Choose your large-batch optimizer",
|
||||||
|
required=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# launch from torch
|
# launch from torch
|
||||||
|
@ -55,59 +52,22 @@ def main():
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
logger.info("initialized distributed environment", ranks=[0])
|
logger.info("initialized distributed environment", ranks=[0])
|
||||||
|
|
||||||
if hasattr(gpc.config, 'LOG_PATH'):
|
# create synthetic dataloaders
|
||||||
if gpc.get_global_rank() == 0:
|
train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
|
||||||
log_path = gpc.config.LOG_PATH
|
test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
|
||||||
if not os.path.exists(log_path):
|
|
||||||
os.mkdir(log_path)
|
|
||||||
logger.log_to_file(log_path)
|
|
||||||
|
|
||||||
use_pipeline = is_using_pp()
|
# build model
|
||||||
|
model = resnet18(num_classes=gpc.config.NUM_CLASSES)
|
||||||
# create model
|
|
||||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
|
|
||||||
patch_size=gpc.config.PATCH_SIZE,
|
|
||||||
hidden_size=gpc.config.HIDDEN_SIZE,
|
|
||||||
depth=gpc.config.DEPTH,
|
|
||||||
num_heads=gpc.config.NUM_HEADS,
|
|
||||||
mlp_ratio=gpc.config.MLP_RATIO,
|
|
||||||
num_classes=10,
|
|
||||||
init_method='jax',
|
|
||||||
checkpoint=gpc.config.CHECKPOINT)
|
|
||||||
|
|
||||||
if use_pipeline:
|
|
||||||
pipelinable = PipelinableContext()
|
|
||||||
with pipelinable:
|
|
||||||
model = _create_vit_model(**model_kwargs)
|
|
||||||
pipelinable.to_layer_list()
|
|
||||||
pipelinable.policy = "uniform"
|
|
||||||
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
|
||||||
else:
|
|
||||||
model = _create_vit_model(**model_kwargs)
|
|
||||||
|
|
||||||
# count number of parameters
|
|
||||||
total_numel = 0
|
|
||||||
for p in model.parameters():
|
|
||||||
total_numel += p.numel()
|
|
||||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
|
||||||
pipeline_stage = 0
|
|
||||||
else:
|
|
||||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
|
||||||
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
|
||||||
|
|
||||||
# create dataloaders
|
|
||||||
root = os.environ.get('DATA', '../data/')
|
|
||||||
if args.synthetic:
|
|
||||||
train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE)
|
|
||||||
test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
|
|
||||||
else:
|
|
||||||
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
|
||||||
|
|
||||||
# create loss function
|
# create loss function
|
||||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
# create optimizer
|
# create optimizer
|
||||||
optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
if args.optimizer == "lars":
|
||||||
|
optim_cls = Lars
|
||||||
|
elif args.optimizer == "lamb":
|
||||||
|
optim_cls = Lamb
|
||||||
|
optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||||
|
|
||||||
# create lr scheduler
|
# create lr scheduler
|
||||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||||
|
|
Loading…
Reference in New Issue