mirror of https://github.com/hpcaitech/ColossalAI
[test] fixed hybrid parallel test case on 8 GPUs (#1106)
parent 85b58093d2
commit 53297330c0
@@ -20,11 +20,8 @@ from colossalai.utils import is_using_pp, get_dataloader
 from colossalai.pipeline.pipelinable import PipelinableContext
 from tqdm import tqdm
 from torchvision.datasets import CIFAR10
-from torchvision.transforms import transforms
-try:
-    from titans.model.vit import vit_tiny_patch4_32
-except:
-    pass
+from torchvision import transforms
+from titans.model.vit import vit_tiny_patch4_32
 
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
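The import hunk does two things: it switches to the package-level `from torchvision import transforms` (both spellings bind the same transforms namespace, but this is the documented form), and it drops the bare try/except around the titans import. The sketch below is illustration only, not part of the commit; it shows why removing the guard gives a clearer failure:

# Before: a bare except swallowed ImportError, so a missing `titans` package only
# surfaced later as a NameError when vit_tiny_patch4_32() was first called.
try:
    from titans.model.vit import vit_tiny_patch4_32
except:    # catches every exception, hiding the real problem
    pass

# After: import directly, so a missing dependency fails immediately and visibly.
from titans.model.vit import vit_tiny_patch4_32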
@@ -47,13 +44,13 @@ def run_trainer(rank, world_size, port):
     with pipelinable:
         model = vit_tiny_patch4_32()
     pipelinable.to_layer_list()
-    pipelinable.load_policy("uniform")
+    pipelinable.policy = "uniform"
     model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
 
     # craete dataloaders
     root = Path(os.environ['DATA'])
     transform_train = transforms.Compose([
-        transforms.RandomCrop(224, padding=4, pad_if_needed=True),
+        transforms.RandomCrop(32, padding=4, pad_if_needed=True),
         transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
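Two fixes land in this hunk: the uniform partition policy is now set through the `policy` attribute instead of the previous `load_policy("uniform")` call, and the training crop is reduced to CIFAR10's native 32x32 resolution, which also matches the 32-pixel, patch-4 geometry implied by the model name `vit_tiny_patch4_32`. Read end to end, the pipeline setup in the updated test looks roughly like this sketch (the `PipelinableContext()` construction is not shown in the hunk and is assumed; the distributed context and `gpc` are assumed to be initialized elsewhere in the test):

pipelinable = PipelinableContext()
with pipelinable:                  # trace model construction so it can be split into stages
    model = vit_tiny_patch4_32()
pipelinable.to_layer_list()        # flatten the traced modules into an ordered layer list
pipelinable.policy = "uniform"     # partition policy is now an attribute assignment
model = pipelinable.partition(1, gpc.pipeline_parallel_size,
                              gpc.get_local_rank(ParallelMode.PIPELINE))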
@@ -71,11 +68,10 @@ def run_trainer(rank, world_size, port):
     lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
 
     # intiailize
-    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
-                                                                         optimizer=optimizer,
-                                                                         criterion=criterion,
-                                                                         train_dataloader=train_dataloader,
-                                                                         test_dataloader=test_dataloader)
+    engine, train_dataloader, *_ = colossalai.initialize(model=model,
+                                                         optimizer=optimizer,
+                                                         criterion=criterion,
+                                                         train_dataloader=train_dataloader)
 
     logger = get_dist_logger()
 
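With no test dataloader passed in, the call site now uses extended unpacking so the remaining return values of `colossalai.initialize` no longer need placeholder names. A generic illustration of the idiom, independent of ColossalAI:

# Extended unpacking absorbs any trailing values, so the assignment stays valid even
# when the number of returned objects the caller cares about shrinks.
engine, train_dataloader, *rest = ("engine", "train_loader", None, None)
print(rest)    # [None, None] - stand-ins for values this test no longer uses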