mirror of https://github.com/hpcaitech/ColossalAI
polish unit test with titans (#1152)
parent f1f51990b9
commit ff644ee5e4
@@ -16,12 +16,10 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn import CrossEntropyLoss
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.utils import is_using_pp, get_dataloader
+from colossalai.utils import get_dataloader
 from colossalai.pipeline.pipelinable import PipelinableContext
 from tqdm import tqdm
 from torchvision.datasets import CIFAR10
 from torchvision import transforms
-from titans.model.vit import vit_tiny_patch4_32
-
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
@@ -41,6 +39,12 @@ def run_trainer(rank, world_size, port):
     logger = get_dist_logger()

     pipelinable = PipelinableContext()
+    try:
+        from titans.model.vit import vit_tiny_patch4_32
+    except ImportError:
+        logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
+        logger.warning('please install titan from https://github.com/hpcaitech/Titans')
+        return
     with pipelinable:
         model = vit_tiny_patch4_32()
     pipelinable.to_layer_list()
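The added try/except is the usual pattern for an optional test dependency: import it lazily inside the worker function and bail out with a warning when it is missing, so the test degrades to a no-op instead of failing at import time. For a test that runs in the main pytest process, the same intent can be expressed with pytest.importorskip; the snippet below is only an illustrative sketch under that assumption, not part of the commit, and the test function body is hypothetical.

import pytest

def test_cifar_with_data_pipeline_tensor():
    # Skip, rather than fail, when the optional titans package is not installed.
    # Hypothetical variant; the commit itself logs a warning and returns early.
    vit = pytest.importorskip('titans.model.vit')
    model = vit.vit_tiny_patch4_32()
    # ... build dataloaders and run the pipeline schedule as in the real test ...

The commit keeps the warning-and-return form, presumably because run_trainer executes in a spawned worker process, where a skip exception raised by pytest.importorskip would not propagate back to the test runner as a skip.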