mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] added synthetic dataset for auto parallel demo (#1918)
parent
acd9abc5ca
commit
1b0dd05940
|
@ -1,3 +1,4 @@
|
|||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -29,43 +30,60 @@ BATCH_SIZE = 1024
|
|||
NUM_EPOCHS = 10
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def synthesize_data():
|
||||
img = torch.rand(BATCH_SIZE, 3, 32, 32)
|
||||
label = torch.randint(low=0, high=10, size=(BATCH_SIZE,))
|
||||
return img, label
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
with barrier_context():
|
||||
# build dataloaders
|
||||
train_dataset = CIFAR10(root=DATA_ROOT,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.RandomCrop(size=32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
|
||||
]))
|
||||
if not args.synthetic:
|
||||
with barrier_context():
|
||||
# build dataloaders
|
||||
train_dataset = CIFAR10(root=DATA_ROOT,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.RandomCrop(size=32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
|
||||
std=[0.2023, 0.1994, 0.2010]),
|
||||
]))
|
||||
|
||||
test_dataset = CIFAR10(root=DATA_ROOT,
|
||||
train=False,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
|
||||
]))
|
||||
test_dataset = CIFAR10(root=DATA_ROOT,
|
||||
train=False,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
|
||||
]))
|
||||
|
||||
train_dataloader = get_dataloader(
|
||||
dataset=train_dataset,
|
||||
add_sampler=False,
|
||||
shuffle=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
)
|
||||
train_dataloader = get_dataloader(
|
||||
dataset=train_dataset,
|
||||
add_sampler=False,
|
||||
shuffle=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
test_dataloader = get_dataloader(
|
||||
dataset=test_dataset,
|
||||
add_sampler=False,
|
||||
batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
)
|
||||
test_dataloader = get_dataloader(
|
||||
dataset=test_dataset,
|
||||
add_sampler=False,
|
||||
batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
)
|
||||
else:
|
||||
train_dataloader, test_dataloader = None, None
|
||||
|
||||
# initialize device mesh
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
@ -112,11 +130,26 @@ def main():
|
|||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
gm.train()
|
||||
if gpc.get_global_rank() == 0:
|
||||
train_dl = tqdm(train_dataloader)
|
||||
|
||||
if args.synthetic:
|
||||
# if we use synthetic data
|
||||
# we assume it only has 30 steps per epoch
|
||||
num_steps = range(30)
|
||||
|
||||
else:
|
||||
train_dl = train_dataloader
|
||||
for img, label in train_dl:
|
||||
# we use the actual number of steps for training
|
||||
num_steps = range(len(train_dataloader))
|
||||
data_iter = iter(train_dataloader)
|
||||
progress = tqdm(num_steps)
|
||||
|
||||
for _ in progress:
|
||||
if args.synthetic:
|
||||
# generate fake data
|
||||
img, label = synthesize_data()
|
||||
else:
|
||||
# get the real data
|
||||
img, label = next(data_iter)
|
||||
|
||||
img = img.cuda()
|
||||
label = label.cuda()
|
||||
optimizer.zero_grad()
|
||||
|
@ -126,10 +159,30 @@ def main():
|
|||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
# run evaluation
|
||||
gm.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
for img, label in test_dataloader:
|
||||
|
||||
if args.synthetic:
|
||||
# if we use synthetic data
|
||||
# we assume it only has 10 steps for evaluation
|
||||
num_steps = range(30)
|
||||
|
||||
else:
|
||||
# we use the actual number of steps for training
|
||||
num_steps = range(len(test_dataloader))
|
||||
data_iter = iter(test_dataloader)
|
||||
progress = tqdm(num_steps)
|
||||
|
||||
for _ in progress:
|
||||
if args.synthetic:
|
||||
# generate fake data
|
||||
img, label = synthesize_data()
|
||||
else:
|
||||
# get the real data
|
||||
img, label = next(data_iter)
|
||||
|
||||
img = img.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
|
|
Loading…
Reference in New Issue