[tutorial] added synthetic dataset for auto parallel demo (#1918)

pull/1919/head
Frank Lee 2022-11-12 17:14:32 +08:00 committed by GitHub
parent acd9abc5ca
commit 1b0dd05940
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 87 additions and 34 deletions

View File

@ -1,3 +1,4 @@
import argparse
import os
from pathlib import Path
@ -29,43 +30,60 @@ BATCH_SIZE = 1024
NUM_EPOCHS = 10
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
return parser.parse_args()
def synthesize_data():
img = torch.rand(BATCH_SIZE, 3, 32, 32)
label = torch.randint(low=0, high=10, size=(BATCH_SIZE,))
return img, label
def main():
args = parse_args()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()
with barrier_context():
# build dataloaders
train_dataset = CIFAR10(root=DATA_ROOT,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(size=32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
]))
if not args.synthetic:
with barrier_context():
# build dataloaders
train_dataset = CIFAR10(root=DATA_ROOT,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(size=32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]))
test_dataset = CIFAR10(root=DATA_ROOT,
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
]))
test_dataset = CIFAR10(root=DATA_ROOT,
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
]))
train_dataloader = get_dataloader(
dataset=train_dataset,
add_sampler=False,
shuffle=True,
batch_size=BATCH_SIZE,
pin_memory=True,
)
train_dataloader = get_dataloader(
dataset=train_dataset,
add_sampler=False,
shuffle=True,
batch_size=BATCH_SIZE,
pin_memory=True,
)
test_dataloader = get_dataloader(
dataset=test_dataset,
add_sampler=False,
batch_size=BATCH_SIZE,
pin_memory=True,
)
test_dataloader = get_dataloader(
dataset=test_dataset,
add_sampler=False,
batch_size=BATCH_SIZE,
pin_memory=True,
)
else:
train_dataloader, test_dataloader = None, None
# initialize device mesh
physical_mesh_id = torch.arange(0, 4)
@ -112,11 +130,26 @@ def main():
for epoch in range(NUM_EPOCHS):
gm.train()
if gpc.get_global_rank() == 0:
train_dl = tqdm(train_dataloader)
if args.synthetic:
# if we use synthetic data
# we assume it only has 30 steps per epoch
num_steps = range(30)
else:
train_dl = train_dataloader
for img, label in train_dl:
# we use the actual number of steps for training
num_steps = range(len(train_dataloader))
data_iter = iter(train_dataloader)
progress = tqdm(num_steps)
for _ in progress:
if args.synthetic:
# generate fake data
img, label = synthesize_data()
else:
# get the real data
img, label = next(data_iter)
img = img.cuda()
label = label.cuda()
optimizer.zero_grad()
@ -126,10 +159,30 @@ def main():
optimizer.step()
lr_scheduler.step()
# run evaluation
gm.eval()
correct = 0
total = 0
for img, label in test_dataloader:
if args.synthetic:
# if we use synthetic data
# we assume it only has 10 steps for evaluation
num_steps = range(30)
else:
# we use the actual number of steps for training
num_steps = range(len(test_dataloader))
data_iter = iter(test_dataloader)
progress = tqdm(num_steps)
for _ in progress:
if args.synthetic:
# generate fake data
img, label = synthesize_data()
else:
# get the real data
img, label = next(data_iter)
img = img.cuda()
label = label.cuda()