[test] fixed hybrid parallel test case on 8 GPUs (#1106)

Frank Lee 2022-06-14 10:30:54 +08:00 committed by GitHub
parent 85b58093d2
commit 53297330c0
1 changed file with 8 additions and 12 deletions


@@ -20,11 +20,8 @@ from colossalai.utils import is_using_pp, get_dataloader
 from colossalai.pipeline.pipelinable import PipelinableContext
 from tqdm import tqdm
 from torchvision.datasets import CIFAR10
-from torchvision.transforms import transforms
-try:
-    from titans.model.vit import vit_tiny_patch4_32
-except:
-    pass
+from torchvision import transforms
+from titans.model.vit import vit_tiny_patch4_32
 
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
@@ -47,13 +44,13 @@ def run_trainer(rank, world_size, port):
     with pipelinable:
         model = vit_tiny_patch4_32()
     pipelinable.to_layer_list()
-    pipelinable.load_policy("uniform")
+    pipelinable.policy = "uniform"
     model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
 
     # craete dataloaders
     root = Path(os.environ['DATA'])
     transform_train = transforms.Compose([
-        transforms.RandomCrop(224, padding=4, pad_if_needed=True),
+        transforms.RandomCrop(32, padding=4, pad_if_needed=True),
         transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
@@ -71,11 +68,10 @@ def run_trainer(rank, world_size, port):
     lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
 
     # intiailize
-    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
+    engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                           optimizer=optimizer,
                                                           criterion=criterion,
-                                                          train_dataloader=train_dataloader,
-                                                          test_dataloader=test_dataloader)
+                                                          train_dataloader=train_dataloader)
 
     logger = get_dist_logger()
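For context, a minimal sketch of the pipeline setup this test exercises after the change. It reuses only the calls visible in the diff (PipelinableContext, to_layer_list, the policy attribute, partition, and colossalai.initialize); the parallel config dict, the synthetic dataset, and the Adam/CrossEntropyLoss choices are illustrative assumptions, not part of this commit.

# Sketch only: assumes colossalai and titans are installed and 8 GPUs are available.
import torch
import colossalai
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.utils import get_dataloader
from titans.model.vit import vit_tiny_patch4_32

# Assumed hybrid-parallel layout: 2 pipeline stages x 4-way 2D tensor parallel = 8 GPUs.
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=4, mode='2d')))


def run_worker(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')

    # Trace the model inside the context so it can be split into pipeline stages.
    pipelinable = PipelinableContext()
    with pipelinable:
        model = vit_tiny_patch4_32()
    pipelinable.to_layer_list()
    pipelinable.policy = "uniform"  # spread layers evenly over the stages
    model = pipelinable.partition(1, gpc.pipeline_parallel_size,
                                  gpc.get_local_rank(ParallelMode.PIPELINE))

    # Synthetic CIFAR10-shaped data so the sketch has no dataset dependency.
    images = torch.randn(64, 3, 32, 32)
    labels = torch.randint(0, 10, (64,))
    dataset = torch.utils.data.TensorDataset(images, labels)
    train_dataloader = get_dataloader(dataset, shuffle=True, batch_size=4)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    # initialize returns (engine, train_dataloader, test_dataloader, lr_scheduler);
    # since no test_dataloader is passed here, the trailing values are collected with *_.
    engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                         optimizer=optimizer,
                                                         criterion=criterion,
                                                         train_dataloader=train_dataloader)


if __name__ == '__main__':
    # Spawn one process per GPU; port 29500 is an arbitrary free port.
    torch.multiprocessing.spawn(run_worker, args=(8, 29500), nprocs=8)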