[example] updated the hybrid parallel tutorial (#2444)

* [example] updated the hybrid parallel tutorial

* polish code
pull/2451/head
Frank Lee 2023-01-11 15:17:17 +08:00 committed by GitHub
parent 5521af7877
commit 39163417a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 82 additions and 65 deletions

View File

@ -1,13 +1,16 @@
import click
import sys
import os import os
import torch import sys
from colossalai.context import Config
from .multinode_runner import MultiNodeRunner
from .hostinfo import HostInfo, HostInfoList
from typing import List from typing import List
import click
import torch
from packaging import version from packaging import version
from colossalai.context import Config
from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner
# Constants that define our syntax # Constants that define our syntax
NODE_SEP = ',' NODE_SEP = ','
@ -276,6 +279,33 @@ def launch_multi_processes(args: Config) -> None:
extra_launch_args=args.extra_launch_args) extra_launch_args=args.extra_launch_args)
runner.send(hostinfo=hostinfo, cmd=cmd) runner.send(hostinfo=hostinfo, cmd=cmd)
runner.recv_from_all() # start training
msg_from_node = runner.recv_from_all()
has_error = False
# print node status
click.echo("\n====== Training on All Nodes =====")
for hostname, msg in msg_from_node.items():
click.echo(f"{hostname}: {msg}")
# check if a process failed
if msg == "failure":
has_error = True
# stop all nodes
runner.stop_all() runner.stop_all()
runner.recv_from_all()
# receive the stop status
msg_from_node = runner.recv_from_all()
# printe node status
click.echo("\n====== Stopping All Nodes =====")
for hostname, msg in msg_from_node.items():
click.echo(f"{hostname}: {msg}")
# give the process an exit code
# so that it behaves like a normal process
if has_error:
sys.exit(1)
else:
sys.exit(0)

View File

@ -1,45 +1,40 @@
# Multi-dimensional Parallelism with Colossal-AI # Multi-dimensional Parallelism with Colossal-AI
## Table of contents
## 🚀Quick Start - [Overview](#-overview)
1. Install our model zoo. - [Quick Start](#-quick-start)
```bash
pip install titans ## 📚 Overview
```
2. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag. This example lets you to quickly try out the hybrid parallelism provided by Colossal-AI.
```bash You can change the parameters below to try out different settings in the `config.py`.
colossalai run --nproc_per_node 4 train.py --config config.py -s
```python
# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
parallel = dict(
pipeline=2,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
``` ```
3. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs. ## 🚀 Quick Start
1. Install PyTorch
## Install Titans Model Zoo 2. Install the dependencies.
```bash ```bash
pip install titans pip install -r requirements.txt
``` ```
3. Run the training scripts with synthetic data.
## Prepare Dataset
We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
```bash ```bash
export DATA=/path/to/data
```
## Run on 2*2 device mesh
Current configuration setting on `config.py` is TP=2, PP=2.
```bash
# train with cifar10
colossalai run --nproc_per_node 4 train.py --config config.py colossalai run --nproc_per_node 4 train.py --config config.py
# train with synthetic data
colossalai run --nproc_per_node 4 train.py --config config.py -s
``` ```
4. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs.

View File

@ -3,7 +3,7 @@ from colossalai.amp import AMP_TYPE
# hyperparameters # hyperparameters
# BATCH_SIZE is as per GPU # BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size # global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256 BATCH_SIZE = 4
LEARNING_RATE = 3e-3 LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3 WEIGHT_DECAY = 0.3
NUM_EPOCHS = 2 NUM_EPOCHS = 2
@ -12,11 +12,11 @@ WARMUP_EPOCHS = 1
# model config # model config
IMG_SIZE = 224 IMG_SIZE = 224
PATCH_SIZE = 16 PATCH_SIZE = 16
HIDDEN_SIZE = 512 HIDDEN_SIZE = 128
DEPTH = 4 DEPTH = 4
NUM_HEADS = 4 NUM_HEADS = 4
MLP_RATIO = 2 MLP_RATIO = 2
NUM_CLASSES = 1000 NUM_CLASSES = 10
CHECKPOINT = False CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token

View File

@ -1,3 +1,3 @@
colossalai >= 0.1.12 torch
torch >= 1.8.1 colossalai
titans titans

View File

@ -2,4 +2,4 @@
set -euxo pipefail set -euxo pipefail
pip install -r requirements.txt pip install -r requirements.txt
torchrun --standalone --nproc_per_node 4 train.py --config config.py -s colossalai run --nproc_per_node 4 train.py --config config.py

View File

@ -1,7 +1,6 @@
import os import os
import torch import torch
from titans.dataloader.cifar10 import build_cifar
from titans.model.vit.vit import _create_vit_model from titans.model.vit.vit import _create_vit_model
from tqdm import tqdm from tqdm import tqdm
@ -12,7 +11,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.utils import get_dataloader, is_using_pp from colossalai.utils import is_using_pp
class DummyDataloader(): class DummyDataloader():
@ -42,12 +41,9 @@ class DummyDataloader():
def main(): def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
args = parser.parse_args()
# launch from torch # launch from torch
parser = colossalai.get_default_parser()
args = parser.parse_args()
colossalai.launch_from_torch(config=args.config) colossalai.launch_from_torch(config=args.config)
# get logger # get logger
@ -94,15 +90,10 @@ def main():
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
# create dataloaders # use synthetic dataset
root = os.environ.get('DATA', '../data') # we train for 10 steps and eval for 5 steps per epoch
if args.synthetic: train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
# if we use synthetic dataset test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
# we train for 10 steps and eval for 5 steps per epoch
train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
else:
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
# create loss function # create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1) criterion = CrossEntropyLoss(label_smoothing=0.1)
@ -139,6 +130,7 @@ def main():
engine.execute_schedule(data_iter, return_output_label=False) engine.execute_schedule(data_iter, return_output_label=False)
engine.step() engine.step()
lr_scheduler.step() lr_scheduler.step()
gpc.destroy()
if __name__ == '__main__': if __name__ == '__main__':