mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] added synthetic data for hybrid parallel (#1921)
* [tutorial] added synthetic data for hybrid parallel * polish codepull/1922/head
parent
3c42fdbedc
commit
ff16773ded
|
@ -1,16 +1,17 @@
|
||||||
# Handson 1: Multi-dimensional Parallelism with Colossal-AI
|
# Handson 1: Multi-dimensional Parallelism with Colossal-AI
|
||||||
|
|
||||||
|
|
||||||
## Install Colossal-AI and other dependencies
|
## Install Titans Model Zoo
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sh install.sh
|
pip install titans
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Prepare Dataset
|
## Prepare Dataset
|
||||||
|
|
||||||
We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
|
We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
|
||||||
|
The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
|
||||||
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
|
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -23,5 +24,9 @@ export DATA=/path/to/data
|
||||||
Current configuration setting on `config.py` is TP=2, PP=2.
|
Current configuration setting on `config.py` is TP=2, PP=2.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# train with cifar10
|
||||||
|
colossalai run --nproc_per_node 4 train.py --config config.py
|
||||||
|
|
||||||
|
# train with synthetic data
|
||||||
colossalai run --nproc_per_node 4 train.py --config config.py
|
colossalai run --nproc_per_node 4 train.py --config config.py
|
||||||
```
|
```
|
|
@ -1,4 +0,0 @@
|
||||||
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
|
|
||||||
pip install colossalai==0.1.10+torch1.12cu11.3 -f https://release.colossalai.org
|
|
||||||
pip install titans
|
|
||||||
colossalai check -i
|
|
|
@ -1,22 +1,50 @@
|
||||||
import os
|
import os
|
||||||
import colossalai
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from titans.dataloader.cifar10 import build_cifar
|
||||||
|
from titans.model.vit.vit import _create_vit_model
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import colossalai
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn import CrossEntropyLoss
|
from colossalai.nn import CrossEntropyLoss
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
from colossalai.utils import is_using_pp, get_dataloader
|
|
||||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||||
from titans.model.vit.vit import _create_vit_model
|
from colossalai.utils import get_dataloader, is_using_pp
|
||||||
from titans.dataloader.cifar10 import build_cifar
|
|
||||||
|
|
||||||
|
class DummyDataloader():
|
||||||
|
|
||||||
|
def __init__(self, length, batch_size):
|
||||||
|
self.length = length
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
data = torch.rand(self.batch_size, 3, 224, 224)
|
||||||
|
label = torch.randint(low=0, high=10, size=(self.batch_size,))
|
||||||
|
return data, label
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.step = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.step < self.length:
|
||||||
|
self.step += 1
|
||||||
|
return self.generate()
|
||||||
|
else:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# initialize distributed setting
|
# initialize distributed setting
|
||||||
parser = colossalai.get_default_parser()
|
parser = colossalai.get_default_parser()
|
||||||
|
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# launch from torch
|
# launch from torch
|
||||||
|
@ -52,8 +80,7 @@ def main():
|
||||||
model = _create_vit_model(**model_kwargs)
|
model = _create_vit_model(**model_kwargs)
|
||||||
pipelinable.to_layer_list()
|
pipelinable.to_layer_list()
|
||||||
pipelinable.policy = "uniform"
|
pipelinable.policy = "uniform"
|
||||||
model = pipelinable.partition(
|
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||||
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
|
||||||
else:
|
else:
|
||||||
model = _create_vit_model(**model_kwargs)
|
model = _create_vit_model(**model_kwargs)
|
||||||
|
|
||||||
|
@ -65,20 +92,23 @@ def main():
|
||||||
pipeline_stage = 0
|
pipeline_stage = 0
|
||||||
else:
|
else:
|
||||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
logger.info(
|
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
||||||
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
|
||||||
|
|
||||||
# create dataloaders
|
# create dataloaders
|
||||||
root = os.environ.get('DATA', '../data/cifar10')
|
root = os.environ.get('DATA', '../data')
|
||||||
train_dataloader, test_dataloader = build_cifar(
|
if args.synthetic:
|
||||||
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
# if we use synthetic dataset
|
||||||
|
# we train for 30 steps and eval for 10 steps per epoch
|
||||||
|
train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE)
|
||||||
|
test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
|
||||||
|
else:
|
||||||
|
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
||||||
|
|
||||||
# create loss function
|
# create loss function
|
||||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||||
|
|
||||||
# create optimizer
|
# create optimizer
|
||||||
optimizer = torch.optim.AdamW(model.parameters(
|
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||||
), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
|
||||||
|
|
||||||
# create lr scheduler
|
# create lr scheduler
|
||||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||||
|
@ -94,11 +124,10 @@ def main():
|
||||||
|
|
||||||
logger.info("Engine is built", ranks=[0])
|
logger.info("Engine is built", ranks=[0])
|
||||||
|
|
||||||
data_iter = iter(train_dataloader)
|
|
||||||
|
|
||||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||||
# training
|
# training
|
||||||
engine.train()
|
engine.train()
|
||||||
|
data_iter = iter(train_dataloader)
|
||||||
|
|
||||||
if gpc.get_global_rank() == 0:
|
if gpc.get_global_rank() == 0:
|
||||||
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
||||||
|
|
Loading…
Reference in New Issue