From f477a14f4aeb49f7a30ee0f46775040391a96e1c Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Tue, 31 Jan 2023 17:42:45 +0800 Subject: [PATCH] [hotfix] fix autoparallel demo (#2533) --- .../auto_parallel/auto_parallel_with_resnet.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py index 15429f19c..a6a9ad0a3 100644 --- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -3,8 +3,9 @@ from torchvision.models import resnet50 from tqdm import tqdm import colossalai -from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize +from colossalai.auto_parallel.tensor_shard.initialize import initialize_model from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingLR @@ -22,9 +23,14 @@ def main(): # trace the model with meta data model = resnet50(num_classes=10).cuda() + input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} + device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True) + model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True) - model = autoparallelize(model, input_sample) + if gpc.get_global_rank() == 0: + for node_strategy in solution: + print(node_strategy) # build criterion criterion = torch.nn.CrossEntropyLoss() @@ -52,6 +58,7 @@ def main(): output = model(img) train_loss = criterion(output, label) train_loss.backward(train_loss) + torch.cuda.synchronize() optimizer.step() lr_scheduler.step()