Browse Source

[examples] update autoparallel tutorial demo (#2449)

* [examples] update autoparallel tutorial demo

* add test_ci.sh

* polish

* add conda yaml
pull/2464/head
YuliangLiu0306 2 years ago committed by GitHub
parent
commit
c20529fe78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 112
      examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
  2. 32
      examples/tutorial/auto_parallel/environment.yaml
  3. 13
      examples/tutorial/auto_parallel/setup.py
  4. 11
      examples/tutorial/auto_parallel/test_ci.sh

112
examples/tutorial/auto_parallel/auto_parallel_with_resnet.py

@ -4,23 +4,14 @@ from pathlib import Path
import torch import torch
from titans.utils import barrier_context from titans.utils import barrier_context
from torch.fx import GraphModule
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from torchvision.models import resnet50 from torchvision.models import resnet50
from tqdm import tqdm from tqdm import tqdm
import colossalai import colossalai
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.utils import get_dataloader from colossalai.utils import get_dataloader
@ -28,12 +19,6 @@ from colossalai.utils import get_dataloader
DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute() DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
return parser.parse_args()
def synthesize_data(): def synthesize_data():
img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32) img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,)) label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
@ -41,82 +26,15 @@ def synthesize_data():
def main(): def main():
args = parse_args()
colossalai.launch_from_torch(config='./config.py') colossalai.launch_from_torch(config='./config.py')
logger = get_dist_logger() logger = get_dist_logger()
if not args.synthetic:
with barrier_context():
# build dataloaders
train_dataset = CIFAR10(root=DATA_ROOT,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(size=32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]))
test_dataset = CIFAR10(root=DATA_ROOT,
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
]))
train_dataloader = get_dataloader(
dataset=train_dataset,
add_sampler=True,
shuffle=True,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
)
test_dataloader = get_dataloader(
dataset=test_dataset,
add_sampler=True,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
)
else:
train_dataloader, test_dataloader = None, None
# initialize device mesh
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# trace the model with meta data # trace the model with meta data
tracer = ColoTracer()
model = resnet50(num_classes=10).cuda() model = resnet50(num_classes=10).cuda()
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# prepare info for solver
solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
# solve the solution
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
if gpc.get_global_rank() == 0:
for index, node in enumerate(graph.nodes):
print(node.name, node.strategies_vector[solution[index]].name)
# process the graph for distributed training ability
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()
model = autoparallelize(model, input_sample)
# build criterion # build criterion
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
@ -127,65 +45,47 @@ def main():
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS) lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
for epoch in range(gpc.config.NUM_EPOCHS): for epoch in range(gpc.config.NUM_EPOCHS):
gm.train() model.train()
if args.synthetic:
# if we use synthetic data # if we use synthetic data
# we assume it only has 30 steps per epoch # we assume it only has 30 steps per epoch
num_steps = range(30) num_steps = range(30)
else:
# we use the actual number of steps for training
num_steps = range(len(train_dataloader))
data_iter = iter(train_dataloader)
progress = tqdm(num_steps) progress = tqdm(num_steps)
for _ in progress: for _ in progress:
if args.synthetic:
# generate fake data # generate fake data
img, label = synthesize_data() img, label = synthesize_data()
else:
# get the real data
img, label = next(data_iter)
img = img.cuda() img = img.cuda()
label = label.cuda() label = label.cuda()
optimizer.zero_grad() optimizer.zero_grad()
output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict) output = model(img)
train_loss = criterion(output, label) train_loss = criterion(output, label)
train_loss.backward(train_loss) train_loss.backward(train_loss)
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
# run evaluation # run evaluation
gm.eval() model.eval()
correct = 0 correct = 0
total = 0 total = 0
if args.synthetic:
# if we use synthetic data # if we use synthetic data
# we assume it only has 10 steps for evaluation # we assume it only has 10 steps for evaluation
num_steps = range(30) num_steps = range(30)
else:
# we use the actual number of steps for training
num_steps = range(len(test_dataloader))
data_iter = iter(test_dataloader)
progress = tqdm(num_steps) progress = tqdm(num_steps)
for _ in progress: for _ in progress:
if args.synthetic:
# generate fake data # generate fake data
img, label = synthesize_data() img, label = synthesize_data()
else:
# get the real data
img, label = next(data_iter)
img = img.cuda() img = img.cuda()
label = label.cuda() label = label.cuda()
with torch.no_grad(): with torch.no_grad():
output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict) output = model(img)
test_loss = criterion(output, label) test_loss = criterion(output, label)
pred = torch.argmax(output, dim=-1) pred = torch.argmax(output, dim=-1)
correct += torch.sum(pred == label) correct += torch.sum(pred == label)

32
examples/tutorial/auto_parallel/environment.yaml

@ -0,0 +1,32 @@
name: auto
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- blas=1.0=mkl
- brotlipy=0.7.0=py38h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.12.7=ha878542_0
- certifi=2022.12.7=pyhd8ed1ab_0
- cffi=1.15.1=py38h74dc2b5_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- coin-or-cbc=2.10.8=h3786ebc_0
- coin-or-cgl=0.60.6=h6f57e76_2
- coin-or-clp=1.17.7=hc56784d_2
- coin-or-osi=0.108.7=h2720bb7_2
- coin-or-utils=2.11.6=h202d8b1_2
- python=3.8.13
- pip=22.2.2
- cudatoolkit=11.3
- pytorch=1.12.1
- torchvision=0.13.1
- numpy=1.23.1
- pip:
- titans
- torch==1.12.1
- pulp==2.7.0
- datasets
- colossalai

13
examples/tutorial/auto_parallel/setup.py

@ -0,0 +1,13 @@
from setuptools import find_packages, setup
setup(
name='auto_parallel',
version='0.0.1',
description='',
packages=find_packages(),
install_requires=[
'torch',
'numpy',
'tqdm',
],
)

11
examples/tutorial/auto_parallel/test_ci.sh

@ -0,0 +1,11 @@
#!/bin/bash
set -euxo pipefail
conda init bash
conda env create -f environment.yaml
conda activate auto
cd ../../..
pip uninstall colossalai
pip install -v .
cd ./examples/tutorial/auto_parallel
colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s
Loading…
Cancel
Save