polish unitest test with titans (#1152)

pull/1136/head
Jiarui Fang 2 years ago committed by GitHub
parent f1f51990b9
commit ff644ee5e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,12 +16,10 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader from colossalai.utils import get_dataloader
from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.pipeline.pipelinable import PipelinableContext
from tqdm import tqdm
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from torchvision import transforms from torchvision import transforms
from titans.model.vit import vit_tiny_patch4_32
BATCH_SIZE = 4 BATCH_SIZE = 4
NUM_EPOCHS = 60 NUM_EPOCHS = 60
@ -41,6 +39,12 @@ def run_trainer(rank, world_size, port):
logger = get_dist_logger() logger = get_dist_logger()
pipelinable = PipelinableContext() pipelinable = PipelinableContext()
try:
from titans.model.vit import vit_tiny_patch4_32
except ImportError:
logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
logger.warning('please install titan from https://github.com/hpcaitech/Titans')
return
with pipelinable: with pipelinable:
model = vit_tiny_patch4_32() model = vit_tiny_patch4_32()
pipelinable.to_layer_list() pipelinable.to_layer_list()

Loading…
Cancel
Save