[test] fixed hybrid parallel test case on 8 GPUs (#1106)

Frank Lee 2022-06-14 10:30:54 +08:00 committed by GitHub
parent 85b58093d2
commit 53297330c0
1 changed file with 8 additions and 12 deletions


@@ -20,11 +20,8 @@ from colossalai.utils import is_using_pp, get_dataloader
 from colossalai.pipeline.pipelinable import PipelinableContext
 from tqdm import tqdm
 from torchvision.datasets import CIFAR10
-from torchvision.transforms import transforms
-try:
-    from titans.model.vit import vit_tiny_patch4_32
-except:
-    pass
+from torchvision import transforms
+from titans.model.vit import vit_tiny_patch4_32
 
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
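
Two things happen in this hunk. `from torchvision import transforms` is the documented import path; the old `from torchvision.transforms import transforms` binds the transforms.py submodule, which (in the torchvision releases of that period, an assumption here) does not re-export helpers such as AutoAugment that live in sibling modules. Dropping the bare try/except around the titans import also makes a missing dependency fail as an immediate ImportError instead of a later NameError. A small check of what each import form binds, assuming a standard torchvision layout:

# Hedged sketch: what each import form actually binds (standard torchvision layout assumed).
from torchvision import transforms                        # the package; re-exports RandomCrop, AutoAugment, ...
from torchvision.transforms import transforms as legacy   # the transforms.py submodule only

print(hasattr(transforms, "AutoAugment"))   # True: re-exported by the package __init__
print(hasattr(legacy, "AutoAugment"))       # likely False: AutoAugment is defined in autoaugment.py
print(hasattr(legacy, "RandomCrop"))        # True either way: defined directly in transforms.py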
@@ -47,13 +44,13 @@ def run_trainer(rank, world_size, port):
     with pipelinable:
         model = vit_tiny_patch4_32()
     pipelinable.to_layer_list()
-    pipelinable.load_policy("uniform")
+    pipelinable.policy = "uniform"
     model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
 
     # craete dataloaders
     root = Path(os.environ['DATA'])
     transform_train = transforms.Compose([
-        transforms.RandomCrop(224, padding=4, pad_if_needed=True),
+        transforms.RandomCrop(32, padding=4, pad_if_needed=True),
         transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
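
Read together, the new lines say: select the pipeline split policy by assigning `pipelinable.policy = "uniform"` (the attribute replaces the removed `load_policy` call), and crop to 32 because CIFAR10 images are 32x32 and vit_tiny_patch4_32 takes 32x32 inputs; with `pad_if_needed=True` the old value of 224 silently padded every image up to 224x224. A sketch of how the section reads after the change; the `PipelinableContext()` construction is assumed from the import above, the remaining names come from the hunk itself:

pipelinable = PipelinableContext()
with pipelinable:
    model = vit_tiny_patch4_32()          # built inside the context so it can be split into stages
pipelinable.to_layer_list()
pipelinable.policy = "uniform"            # attribute-style policy selection; replaces load_policy("uniform")
model = pipelinable.partition(1, gpc.pipeline_parallel_size,
                              gpc.get_local_rank(ParallelMode.PIPELINE))

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4, pad_if_needed=True),   # CIFAR10 images are 32x32
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])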
@@ -71,11 +68,10 @@ def run_trainer(rank, world_size, port):
     lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
 
     # intiailize
-    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
-                                                                         optimizer=optimizer,
-                                                                         criterion=criterion,
-                                                                         train_dataloader=train_dataloader,
-                                                                         test_dataloader=test_dataloader)
+    engine, train_dataloader, *_ = colossalai.initialize(model=model,
+                                                         optimizer=optimizer,
+                                                         criterion=criterion,
+                                                         train_dataloader=train_dataloader)
 
     logger = get_dist_logger()
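
The old left-hand side shows that colossalai.initialize returns four values (the engine, the two dataloaders, and a fourth value the test ignores). Since the test no longer passes test_dataloader, the call site now uses starred unpacking so it does not depend on what comes back after the first two. A minimal, self-contained illustration of that pattern; initialize_stub is a hypothetical stand-in, not a ColossalAI API:

def initialize_stub():
    # Stand-in for colossalai.initialize: four return values, with the trailing
    # ones uninteresting when no test_dataloader was handed in.
    return "engine", "train_dataloader", None, None

# old: engine, train_dataloader, test_dataloader, _ = initialize_stub()
engine, train_dataloader, *_ = initialize_stub()   # *_ absorbs however many values remain
print(_)                                           # [None, None]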