From 1b0dd059408108c3fadc46a5a9193d5061a23e5d Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Sat, 12 Nov 2022 17:14:32 +0800 Subject: [PATCH] [tutorial] added synthetic dataset for auto parallel demo (#1918) --- .../auto_parallel_with_resnet.py | 121 +++++++++++++----- 1 file changed, 87 insertions(+), 34 deletions(-) diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py index 534d2d0af..474c56a61 100644 --- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -1,3 +1,4 @@ +import argparse import os from pathlib import Path @@ -29,43 +30,60 @@ BATCH_SIZE = 1024 NUM_EPOCHS = 10 +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10") + return parser.parse_args() + + +def synthesize_data(): + img = torch.rand(BATCH_SIZE, 3, 32, 32) + label = torch.randint(low=0, high=10, size=(BATCH_SIZE,)) + return img, label + + def main(): + args = parse_args() colossalai.launch_from_torch(config={}) logger = get_dist_logger() - with barrier_context(): - # build dataloaders - train_dataset = CIFAR10(root=DATA_ROOT, - download=True, - transform=transforms.Compose([ - transforms.RandomCrop(size=32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), - ])) + if not args.synthetic: + with barrier_context(): + # build dataloaders + train_dataset = CIFAR10(root=DATA_ROOT, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(size=32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010]), + ])) - test_dataset = CIFAR10(root=DATA_ROOT, - train=False, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), - ])) + test_dataset = CIFAR10(root=DATA_ROOT, + train=False, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ])) - train_dataloader = get_dataloader( - dataset=train_dataset, - add_sampler=False, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - ) + train_dataloader = get_dataloader( + dataset=train_dataset, + add_sampler=False, + shuffle=True, + batch_size=BATCH_SIZE, + pin_memory=True, + ) - test_dataloader = get_dataloader( - dataset=test_dataset, - add_sampler=False, - batch_size=BATCH_SIZE, - pin_memory=True, - ) + test_dataloader = get_dataloader( + dataset=test_dataset, + add_sampler=False, + batch_size=BATCH_SIZE, + pin_memory=True, + ) + else: + train_dataloader, test_dataloader = None, None # initialize device mesh physical_mesh_id = torch.arange(0, 4) @@ -112,11 +130,26 @@ def main(): for epoch in range(NUM_EPOCHS): gm.train() - if gpc.get_global_rank() == 0: - train_dl = tqdm(train_dataloader) + + if args.synthetic: + # if we use synthetic data + # we assume it only has 30 steps per epoch + num_steps = range(30) + else: - train_dl = train_dataloader - for img, label in train_dl: + # we use the actual number of steps for training + num_steps = range(len(train_dataloader)) + data_iter = iter(train_dataloader) + progress = tqdm(num_steps) + + for _ in progress: + if args.synthetic: + # generate fake data + img, label = synthesize_data() + else: + # get the real data + img, label = next(data_iter) + img = img.cuda() label = label.cuda() optimizer.zero_grad() @@ -126,10 +159,30 @@ def main(): optimizer.step() lr_scheduler.step() + # run evaluation gm.eval() correct = 0 total = 0 - for img, label in test_dataloader: + + if args.synthetic: + # if we use synthetic data + # we assume it only has 10 steps for evaluation + num_steps = range(30) + + else: + # we use the actual number of steps for training + num_steps = range(len(test_dataloader)) + data_iter = iter(test_dataloader) + progress = tqdm(num_steps) + + for _ in progress: + if args.synthetic: + # generate fake data + img, label = synthesize_data() + else: + # get the real data + img, label = next(data_iter) + img = img.cuda() label = label.cuda()