mirror of https://github.com/hpcaitech/ColossalAI

[example] updated the hybrid parallel tutorial (#2444)

* [example] updated the hybrid parallel tutorial
* polish code

branch: pull/2451/head
parent: 5521af7877
commit: 39163417a1
@@ -1,13 +1,16 @@
-import click
-import sys
 import os
-import torch
-from colossalai.context import Config
-from .multinode_runner import MultiNodeRunner
-from .hostinfo import HostInfo, HostInfoList
+import sys
 from typing import List
+
+import click
+import torch
 from packaging import version
+
+from colossalai.context import Config
+
+from .hostinfo import HostInfo, HostInfoList
+from .multinode_runner import MultiNodeRunner

 # Constants that define our syntax
 NODE_SEP = ','

@@ -15,7 +18,7 @@ NODE_SEP = ','
 def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
     """
     Parse the hostfile to obtain a list of hosts.

     A hostfile should look like:
     worker-0
     worker-1
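The docstring above describes the hostfile format: one hostname per line. As an illustrative sketch only (not the ColossalAI implementation, which builds `HostInfo` objects and a `HostInfoList`), parsing such a file can be as simple as:

```python
# Illustrative sketch; the real fetch_hostfile returns a HostInfoList.
def read_hostfile(path: str, ssh_port: int = 22) -> list:
    """Read one hostname per line, skip blanks, and pair each host with an SSH port."""
    hosts = []
    with open(path) as f:
        for line in f:
            hostname = line.strip()
            if hostname:
                hosts.append((hostname, ssh_port))
    return hosts

# read_hostfile("hostfile") -> [("worker-0", 22), ("worker-1", 22)]
```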
@@ -63,7 +66,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
         device_pool (HostInfoList): a list of HostInfo objects
         include_str (str): --include option passed by user, default None
         exclude_str (str): --exclude option passed by user, default None

     Returns:
         filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
     '''
@@ -192,7 +195,7 @@ def launch_multi_processes(args: Config) -> None:
     Launch multiple processes on a single node or multiple nodes.

     The overall logic can be summarized as the pseudo code below:

     if hostfile given:
         hostinfo = parse_hostfile(hostfile)
         hostinfo = include_or_exclude_hosts(hostinfo)
@@ -202,7 +205,7 @@ def launch_multi_processes(args: Config) -> None:
         launch_on_multi_nodes(hostinfo)
     else:
         launch_on_current_node()

     Args:
         args (Config): the arguments taken from command line

@@ -276,6 +279,33 @@ def launch_multi_processes(args: Config) -> None:
                                  extra_launch_args=args.extra_launch_args)
     runner.send(hostinfo=hostinfo, cmd=cmd)

-    runner.recv_from_all()
+    # start training
+    msg_from_node = runner.recv_from_all()
+    has_error = False
+
+    # print node status
+    click.echo("\n====== Training on All Nodes =====")
+    for hostname, msg in msg_from_node.items():
+        click.echo(f"{hostname}: {msg}")
+
+        # check if a process failed
+        if msg == "failure":
+            has_error = True
+
+    # stop all nodes
     runner.stop_all()
-    runner.recv_from_all()
+
+    # receive the stop status
+    msg_from_node = runner.recv_from_all()
+
+    # print node status
+    click.echo("\n====== Stopping All Nodes =====")
+    for hostname, msg in msg_from_node.items():
+        click.echo(f"{hostname}: {msg}")
+
+    # give the process an exit code
+    # so that it behaves like a normal process
+    if has_error:
+        sys.exit(1)
+    else:
+        sys.exit(0)
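The block added above collects one status string per node and turns it into a process exit code, so that CI and shell scripts can detect a failed run. A minimal, self-contained sketch of that pattern (using a plain dict in place of the runner's received messages) looks like this:

```python
import sys

def report_and_exit(msg_from_node: dict) -> None:
    """Echo each node's status and exit non-zero if any node reported a failure."""
    has_error = False
    print("\n====== Training on All Nodes =====")
    for hostname, msg in msg_from_node.items():
        print(f"{hostname}: {msg}")
        if msg == "failure":
            has_error = True
    # give the process an exit code so that it behaves like a normal process
    sys.exit(1 if has_error else 0)

# report_and_exit({"worker-0": "success", "worker-1": "failure"})  # exits with code 1
```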
@@ -1,45 +1,40 @@
 # Multi-dimensional Parallelism with Colossal-AI

-## 🚀Quick Start
-1. Install our model zoo.
-```bash
-pip install titans
-```
-2. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag.
-```bash
-colossalai run --nproc_per_node 4 train.py --config config.py -s
-```
-
-3. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs.
-
-
-## Install Titans Model Zoo
-
-```bash
-pip install titans
-```
-
-
-## Prepare Dataset
-
-We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
-The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
-If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
-
-```bash
-export DATA=/path/to/data
-```
-
-
-## Run on 2*2 device mesh
-
-Current configuration setting on `config.py` is TP=2, PP=2.
-
-```bash
-# train with cifar10
-colossalai run --nproc_per_node 4 train.py --config config.py
-
-# train with synthetic data
-colossalai run --nproc_per_node 4 train.py --config config.py -s
-```
-
+## Table of contents
+
+- [Overview](#-overview)
+- [Quick Start](#-quick-start)
+
+## 📚 Overview
+
+This example lets you quickly try out the hybrid parallelism provided by Colossal-AI.
+You can change the parameters below to try out different settings in `config.py`.
+
+```python
+# parallel setting
+TENSOR_PARALLEL_SIZE = 2
+TENSOR_PARALLEL_MODE = '1d'
+
+parallel = dict(
+    pipeline=2,
+    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
+)
+```
+
+## 🚀 Quick Start
+
+1. Install PyTorch
+
+2. Install the dependencies.
+
+```bash
+pip install -r requirements.txt
+```
+
+3. Run the training scripts with synthetic data.
+
+```bash
+colossalai run --nproc_per_node 4 train.py --config config.py
+```
+
+4. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs.
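For step 4 of the Quick Start, a possible `config.py` change (a sketch, not part of this commit) that switches to 2D tensor parallelism across 8 GPUs could look like the following; pipeline (2) times tensor (4) accounts for all 8 processes.

```python
# parallel setting for 8 GPUs: pipeline (2) x tensor (4) = 8
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

parallel = dict(
    pipeline=2,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

# launch with: colossalai run --nproc_per_node 8 train.py --config config.py
```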
@@ -3,7 +3,7 @@ from colossalai.amp import AMP_TYPE
 # hyperparameters
 # BATCH_SIZE is as per GPU
 # global batch size = BATCH_SIZE x data parallel size
-BATCH_SIZE = 256
+BATCH_SIZE = 4
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3
 NUM_EPOCHS = 2
@@ -12,11 +12,11 @@ WARMUP_EPOCHS = 1
 # model config
 IMG_SIZE = 224
 PATCH_SIZE = 16
-HIDDEN_SIZE = 512
+HIDDEN_SIZE = 128
 DEPTH = 4
 NUM_HEADS = 4
 MLP_RATIO = 2
-NUM_CLASSES = 1000
+NUM_CLASSES = 10
 CHECKPOINT = False
 SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1    # add 1 for cls token

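The comment in the hunk above defines the global batch size as `BATCH_SIZE` times the data parallel size. A quick worked sketch (the helper names below are illustrative and not defined in `config.py`): with 4 GPUs and the TP=2, PP=2 setting from the README, the data parallel size is 1, so the global batch size equals the per-GPU `BATCH_SIZE`.

```python
# Illustrative arithmetic only; these names are not part of config.py.
BATCH_SIZE = 4              # per-GPU batch size from config.py
TENSOR_PARALLEL_SIZE = 2    # TP
PIPELINE_SIZE = 2           # PP
WORLD_SIZE = 4              # --nproc_per_node 4 in the Quick Start

data_parallel_size = WORLD_SIZE // (TENSOR_PARALLEL_SIZE * PIPELINE_SIZE)  # -> 1
global_batch_size = BATCH_SIZE * data_parallel_size                        # -> 4
print(data_parallel_size, global_batch_size)
```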
@@ -1,3 +1,3 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+torch
+colossalai
 titans
@@ -2,4 +2,4 @@
 set -euxo pipefail

 pip install -r requirements.txt
-torchrun --standalone --nproc_per_node 4 train.py --config config.py -s
+colossalai run --nproc_per_node 4 train.py --config config.py
@@ -1,7 +1,6 @@
 import os

 import torch
-from titans.dataloader.cifar10 import build_cifar
 from titans.model.vit.vit import _create_vit_model
 from tqdm import tqdm

@@ -12,7 +11,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn import CrossEntropyLoss
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.pipeline.pipelinable import PipelinableContext
-from colossalai.utils import get_dataloader, is_using_pp
+from colossalai.utils import is_using_pp


 class DummyDataloader():
@@ -42,12 +41,9 @@


 def main():
-    # initialize distributed setting
-    parser = colossalai.get_default_parser()
-    parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
-    args = parser.parse_args()
-
     # launch from torch
+    parser = colossalai.get_default_parser()
+    args = parser.parse_args()
     colossalai.launch_from_torch(config=args.config)

     # get logger
@@ -94,15 +90,10 @@
         pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
         logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")

-    # create dataloaders
-    root = os.environ.get('DATA', '../data')
-    if args.synthetic:
-        # if we use synthetic dataset
-        # we train for 10 steps and eval for 5 steps per epoch
-        train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
-        test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
-    else:
-        train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
+    # use synthetic dataset
+    # we train for 10 steps and eval for 5 steps per epoch
+    train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
+    test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)

     # create loss function
     criterion = CrossEntropyLoss(label_smoothing=0.1)
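The synthetic-data path above relies on the `DummyDataloader` class defined earlier in train.py, whose body is not part of this diff. A minimal stand-in consistent with how it is constructed here (a `length` in batches and a per-GPU `batch_size`, yielding random image/label batches shaped for the ViT config) might look like:

```python
import torch

class DummyDataloader:
    """Hypothetical sketch of a synthetic dataloader; the real class lives in train.py."""

    def __init__(self, length, batch_size, img_size=224, num_classes=10):
        self.length = length          # batches per epoch
        self.batch_size = batch_size  # per-GPU batch size
        self.img_size = img_size
        self.num_classes = num_classes
        self.step = 0

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step >= self.length:
            raise StopIteration
        self.step += 1
        data = torch.rand(self.batch_size, 3, self.img_size, self.img_size)
        label = torch.randint(0, self.num_classes, (self.batch_size,))
        return data, label

    def __len__(self):
        return self.length
```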
@@ -139,6 +130,7 @@ def main():
         engine.execute_schedule(data_iter, return_output_label=False)
         engine.step()
         lr_scheduler.step()
+    gpc.destroy()


 if __name__ == '__main__':