diff --git a/examples/images/vit/configs/vit_1d_tp2_ci.py b/examples/images/vit/configs/vit_1d_tp2_ci.py new file mode 100644 index 000000000..e491e4ada --- /dev/null +++ b/examples/images/vit/configs/vit_1d_tp2_ci.py @@ -0,0 +1,32 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 8 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 3 +WARMUP_EPOCHS = 1 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 32 +DEPTH = 2 +NUM_HEADS = 4 +MLP_RATIO = 4 +NUM_CLASSES = 10 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +USE_DDP = True +TP_WORLD_SIZE = 2 +TP_TYPE = 'row' +parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 +gradient_accumulation = 2 + +LOG_PATH = "./log_ci" diff --git a/examples/images/vit/requirements.txt b/examples/images/vit/requirements.txt index 137a69e80..1f69794eb 100644 --- a/examples/images/vit/requirements.txt +++ b/examples/images/vit/requirements.txt @@ -1,2 +1,8 @@ colossalai >= 0.1.12 torch >= 1.8.1 +numpy>=1.24.1 +timm>=0.6.12 +titans>=0.0.7 +tqdm>=4.61.2 +transformers>=4.25.1 +nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh new file mode 100644 index 000000000..41d25ee23 --- /dev/null +++ b/examples/images/vit/test_ci.sh @@ -0,0 +1,9 @@ +export OMP_NUM_THREADS=4 + +pip install -r requirements.txt + +# train +colossalai run \ +--nproc_per_node 4 train.py \ +--config configs/vit_1d_tp2_ci.py \ +--dummy_data diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py index de39801c7..0b4489244 100644 --- a/examples/images/vit/train.py +++ b/examples/images/vit/train.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from timm.models.vision_transformer import
_create_vision_transformer from titans.dataloader.imagenet import build_dali_imagenet from tqdm import tqdm +from vit import DummyDataLoader import colossalai from colossalai.core import global_context as gpc @@ -56,8 +57,8 @@ def init_spec_func(model, tp_type): def train_imagenet(): parser = colossalai.get_default_parser() - parser.add_argument('--from_torch', default=True, action='store_true') - parser.add_argument('--resume_from', default=False) + parser.add_argument('--resume_from', default=False, action='store_true') + parser.add_argument('--dummy_data', default=False, action='store_true') args = parser.parse_args() colossalai.launch_from_torch(config=args.config) @@ -74,10 +75,22 @@ def train_imagenet(): logger.log_to_file(log_path) logger.info('Build data loader', ranks=[0]) - root = os.environ['DATA'] - train_dataloader, test_dataloader = build_dali_imagenet(root, - train_batch_size=gpc.config.BATCH_SIZE, - test_batch_size=gpc.config.BATCH_SIZE) + if not args.dummy_data: + root = os.environ['DATA'] + train_dataloader, test_dataloader = build_dali_imagenet(root, + train_batch_size=gpc.config.BATCH_SIZE, + test_batch_size=gpc.config.BATCH_SIZE) + else: + train_dataloader = DummyDataLoader(length=10, + batch_size=gpc.config.BATCH_SIZE, + category=gpc.config.NUM_CLASSES, + image_size=gpc.config.IMG_SIZE, + return_dict=False) + test_dataloader = DummyDataLoader(length=5, + batch_size=gpc.config.BATCH_SIZE, + category=gpc.config.NUM_CLASSES, + image_size=gpc.config.IMG_SIZE, + return_dict=False) logger.info('Build model', ranks=[0]) diff --git a/examples/images/vit/vit.py b/examples/images/vit/vit.py index 14c870b39..f22e8ea90 100644 --- a/examples/images/vit/vit.py +++ b/examples/images/vit/vit.py @@ -32,21 +32,24 @@ class DummyDataGenerator(ABC): class DummyDataLoader(DummyDataGenerator): - batch_size = 4 - channel = 3 - category = 8 - image_size = 224 + + def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True): + 
super().__init__(length) + self.batch_size = batch_size + self.channel = channel + self.category = category + self.image_size = image_size + self.return_dict = return_dict def generate(self): image_dict = {} - image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size, - DummyDataLoader.channel, - DummyDataLoader.image_size, - DummyDataLoader.image_size, - device=get_current_device()) * 2 - 1 - image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,), + image_dict['pixel_values'] = torch.rand( + self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1 + image_dict['label'] = torch.randint(self.category, (self.batch_size,), dtype=torch.int64, device=get_current_device()) + if not self.return_dict: + return image_dict['pixel_values'], image_dict['label'] return image_dict