|
|
|
@ -16,12 +16,10 @@ from colossalai.core import global_context as gpc
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.nn import CrossEntropyLoss
|
|
|
|
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
|
|
|
|
from colossalai.utils import is_using_pp, get_dataloader
|
|
|
|
|
from colossalai.utils import get_dataloader
|
|
|
|
|
from colossalai.pipeline.pipelinable import PipelinableContext
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
from torchvision.datasets import CIFAR10
|
|
|
|
|
from torchvision import transforms
|
|
|
|
|
from titans.model.vit import vit_tiny_patch4_32
|
|
|
|
|
|
|
|
|
|
BATCH_SIZE = 4
|
|
|
|
|
NUM_EPOCHS = 60
|
|
|
|
@ -41,6 +39,12 @@ def run_trainer(rank, world_size, port):
|
|
|
|
|
logger = get_dist_logger()
|
|
|
|
|
|
|
|
|
|
pipelinable = PipelinableContext()
|
|
|
|
|
try:
|
|
|
|
|
from titans.model.vit import vit_tiny_patch4_32
|
|
|
|
|
except ImportError:
|
|
|
|
|
logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
|
|
|
|
|
logger.warning('please install titan from https://github.com/hpcaitech/Titans')
|
|
|
|
|
return
|
|
|
|
|
with pipelinable:
|
|
|
|
|
model = vit_tiny_patch4_32()
|
|
|
|
|
pipelinable.to_layer_list()
|
|
|
|
|