From ff644ee5e416f64b43cd8a70fd32377c92281270 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Wed, 22 Jun 2022 09:58:02 +0800 Subject: [PATCH] polish unitest test with titans (#1152) --- .../test_cifar_with_data_pipeline_tensor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py index 1994108bf..3c2390c92 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py @@ -16,12 +16,10 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import CrossEntropyLoss from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import is_using_pp, get_dataloader +from colossalai.utils import get_dataloader from colossalai.pipeline.pipelinable import PipelinableContext -from tqdm import tqdm from torchvision.datasets import CIFAR10 from torchvision import transforms -from titans.model.vit import vit_tiny_patch4_32 BATCH_SIZE = 4 NUM_EPOCHS = 60 @@ -41,6 +39,12 @@ def run_trainer(rank, world_size, port): logger = get_dist_logger() pipelinable = PipelinableContext() + try: + from titans.model.vit import vit_tiny_patch4_32 + except ImportError: + logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') + logger.warning('please install titan from https://github.com/hpcaitech/Titans') + return with pipelinable: model = vit_tiny_patch4_32() pipelinable.to_layer_list()