From c20529fe78f52e36df209bd2ab4143609eec7535 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 12 Jan 2023 14:30:58 +0800
Subject: [PATCH] [examples] update autoparallel tutorial demo (#2449)

* [examples] update autoparallel tutorial demo

* add test_ci.sh

* polish

* add conda yaml
---
 .../auto_parallel_with_resnet.py                 | 132 +++---------
 .../tutorial/auto_parallel/environment.yaml      |  32 +++++
 examples/tutorial/auto_parallel/setup.py         |  13 ++
 examples/tutorial/auto_parallel/test_ci.sh       |  11 ++
 4 files changed, 72 insertions(+), 116 deletions(-)
 create mode 100644 examples/tutorial/auto_parallel/environment.yaml
 create mode 100644 examples/tutorial/auto_parallel/setup.py
 create mode 100644 examples/tutorial/auto_parallel/test_ci.sh

diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
index e4aff13e4..1f0d72044 100644
--- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
+++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
@@ -4,23 +4,14 @@ from pathlib import Path
 import torch
 from titans.utils import barrier_context
-from torch.fx import GraphModule
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet50
 from tqdm import tqdm
 
 import colossalai
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
-from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
-from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
 from colossalai.core import global_context as gpc
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingLR
 from colossalai.utils import get_dataloader
 
@@ -28,12 +19,6 @@ from colossalai.utils import get_dataloader
 DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
 
 
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
-    return parser.parse_args()
-
-
 def synthesize_data():
     img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
     label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
@@ -41,82 +26,15 @@ def main():
-    args = parse_args()
     colossalai.launch_from_torch(config='./config.py')
     logger = get_dist_logger()
 
-    if not args.synthetic:
-        with barrier_context():
-            # build dataloaders
-            train_dataset = CIFAR10(root=DATA_ROOT,
-                                    download=True,
-                                    transform=transforms.Compose([
-                                        transforms.RandomCrop(size=32, padding=4),
-                                        transforms.RandomHorizontalFlip(),
-                                        transforms.ToTensor(),
-                                        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
-                                                             std=[0.2023, 0.1994, 0.2010]),
-                                    ]))
-
-            test_dataset = CIFAR10(root=DATA_ROOT,
-                                   train=False,
-                                   transform=transforms.Compose([
-                                       transforms.ToTensor(),
-                                       transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
-                                                            std=[0.2023, 0.1994, 0.2010]),
-                                   ]))
-
-            train_dataloader = get_dataloader(
-                dataset=train_dataset,
-                add_sampler=True,
-                shuffle=True,
-                batch_size=gpc.config.BATCH_SIZE,
-                pin_memory=True,
-            )
-
-            test_dataloader = get_dataloader(
-                dataset=test_dataset,
-                add_sampler=True,
-                batch_size=gpc.config.BATCH_SIZE,
-                pin_memory=True,
-            )
-    else:
-        train_dataloader, test_dataloader = None, None
-
-    # initialize device mesh
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
 
-    # trace the model with meta data
-    tracer = ColoTracer()
     model = resnet50(num_classes=10).cuda()
 
     input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
-    graph = tracer.trace(root=model, meta_args=input_sample)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
-
-    # prepare info for solver
-    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    graph_analyser = GraphAnalyser(gm)
-
-    # solve the solution
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
-    ret = solver.call_solver_serialized_args()
-    solution = list(ret[0])
-    if gpc.get_global_rank() == 0:
-        for index, node in enumerate(graph.nodes):
-            print(node.name, node.strategies_vector[solution[index]].name)
-
-    # process the graph for distributed training ability
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
-    gm = runtime_apply_pass(gm)
-    gm.recompile()
+    model = autoparallelize(model, input_sample)
 
     # build criterion
     criterion = torch.nn.CrossEntropyLoss()
@@ -127,65 +45,47 @@
     lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
 
     for epoch in range(gpc.config.NUM_EPOCHS):
-        gm.train()
+        model.train()
 
-        if args.synthetic:
-            # if we use synthetic data
-            # we assume it only has 30 steps per epoch
-            num_steps = range(30)
+        # if we use synthetic data
+        # we assume it only has 30 steps per epoch
+        num_steps = range(30)
 
-        else:
-            # we use the actual number of steps for training
-            num_steps = range(len(train_dataloader))
-            data_iter = iter(train_dataloader)
         progress = tqdm(num_steps)
 
         for _ in progress:
-            if args.synthetic:
-                # generate fake data
-                img, label = synthesize_data()
-            else:
-                # get the real data
-                img, label = next(data_iter)
+            # generate fake data
+            img, label = synthesize_data()
 
             img = img.cuda()
             label = label.cuda()
 
             optimizer.zero_grad()
-            output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+            output = model(img)
             train_loss = criterion(output, label)
             train_loss.backward(train_loss)
             optimizer.step()
         lr_scheduler.step()
 
         # run evaluation
-        gm.eval()
+        model.eval()
         correct = 0
         total = 0
 
-        if args.synthetic:
-            # if we use synthetic data
-            # we assume it only has 10 steps for evaluation
-            num_steps = range(30)
+        # if we use synthetic data
+        # we assume it only has 30 steps for evaluation
+        num_steps = range(30)
 
-        else:
-            # we use the actual number of steps for training
-            num_steps = range(len(test_dataloader))
-            data_iter = iter(test_dataloader)
         progress = tqdm(num_steps)
 
         for _ in progress:
-            if args.synthetic:
-                # generate fake data
-                img, label = synthesize_data()
-            else:
-                # get the real data
-                img, label = next(data_iter)
+            # generate fake data
+            img, label = synthesize_data()
 
             img = img.cuda()
             label = label.cuda()
 
             with torch.no_grad():
-                output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+                output = model(img)
                 test_loss = criterion(output, label)
             pred = torch.argmax(output, dim=-1)
             correct += torch.sum(pred == label)
diff --git a/examples/tutorial/auto_parallel/environment.yaml b/examples/tutorial/auto_parallel/environment.yaml
new file mode 100644
index 000000000..5b811631a
--- /dev/null
+++ b/examples/tutorial/auto_parallel/environment.yaml
@@ -0,0 +1,32 @@
+name: auto
+channels:
+  - pytorch
+  - conda-forge
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=conda_forge
+  - _openmp_mutex=4.5=2_kmp_llvm
+  - blas=1.0=mkl
+  - brotlipy=0.7.0=py38h27cfd23_1003
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2022.12.7=ha878542_0
+  - certifi=2022.12.7=pyhd8ed1ab_0
+  - cffi=1.15.1=py38h74dc2b5_0
+  - charset-normalizer=2.0.4=pyhd3eb1b0_0
+  - coin-or-cbc=2.10.8=h3786ebc_0
+  - coin-or-cgl=0.60.6=h6f57e76_2
+  - coin-or-clp=1.17.7=hc56784d_2
+  - coin-or-osi=0.108.7=h2720bb7_2
+  - coin-or-utils=2.11.6=h202d8b1_2
+  - python=3.8.13
+  - pip=22.2.2
+  - cudatoolkit=11.3
+  - pytorch=1.12.1
+  - torchvision=0.13.1
+  - numpy=1.23.1
+  - pip:
+    - titans
+    - torch==1.12.1
+    - pulp==2.7.0
+    - datasets
+    - colossalai
diff --git a/examples/tutorial/auto_parallel/setup.py b/examples/tutorial/auto_parallel/setup.py
new file mode 100644
index 000000000..6e6cff32e
--- /dev/null
+++ b/examples/tutorial/auto_parallel/setup.py
@@ -0,0 +1,13 @@
+from setuptools import find_packages, setup
+
+setup(
+    name='auto_parallel',
+    version='0.0.1',
+    description='',
+    packages=find_packages(),
+    install_requires=[
+        'torch',
+        'numpy',
+        'tqdm',
+    ],
+)
diff --git a/examples/tutorial/auto_parallel/test_ci.sh b/examples/tutorial/auto_parallel/test_ci.sh
new file mode 100644
index 000000000..74332548f
--- /dev/null
+++ b/examples/tutorial/auto_parallel/test_ci.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -euxo pipefail
+
+source "$(conda info --base)/etc/profile.d/conda.sh"
+conda env create -f environment.yaml
+conda activate auto
+cd ../../..
+pip uninstall -y colossalai
+pip install -v .
+cd ./examples/tutorial/auto_parallel
+colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
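
---

Note: after this patch, the tutorial's manual trace/solve/transform pipeline (ColoTracer, StrategiesConstructor, Solver, runtime passes) collapses into a single `autoparallelize` call. A minimal sketch of the resulting setup, built only from calls visible in this patch and assuming, as the script's `gpc.config` references imply, that the tutorial's `config.py` defines `BATCH_SIZE` (and `NUM_EPOCHS` for the training loop):

    import torch
    from torchvision.models import resnet50

    import colossalai
    from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
    from colossalai.core import global_context as gpc

    # set up the distributed environment from the torchrun-style launcher;
    # config.py is assumed to define BATCH_SIZE and NUM_EPOCHS
    colossalai.launch_from_torch(config='./config.py')

    model = resnet50(num_classes=10).cuda()

    # a meta-device sample carries shapes only (no memory is allocated);
    # autoparallelize traces and shards the model with it, replacing the
    # manual solver pipeline removed above
    input_sample = {
        'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta'),
    }
    model = autoparallelize(model, input_sample)

The sharded model is then trained like any local module (`output = model(img)`), and the script is launched on 4 GPUs exactly as in `test_ci.sh`: `colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py`.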