ColossalAI/examples/images/vit/vit.py

68 lines
2.4 KiB
Python
Raw Normal View History

import torch
import torch.nn as nn
from utils.dummy_data_generator import DummyDataGenerator
from colossalai.utils.cuda import get_current_device
from transformers import ViTConfig, ViTForImageClassification
class DummyDataLoader(DummyDataGenerator):
batch_size = 4
channel = 3
category = 8
image_size = 224
def generate(self):
image_dict = {}
image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
DummyDataLoader.channel,
DummyDataLoader.image_size,
DummyDataLoader.image_size,
device=get_current_device()) * 2 - 1
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
dtype=torch.int64,
device=get_current_device())
return image_dict
class ViTCVModel(nn.Module):
def __init__(self,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
image_size=224,
patch_size=16,
num_channels=3,
num_labels=8,
checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = ViTForImageClassification(
ViTConfig(hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
image_size=image_size,
patch_size=patch_size,
num_channels=num_channels,
num_labels=num_labels))
if checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, pixel_values):
return self.model(pixel_values=pixel_values)
def vit_base_s(checkpoint=True):
return ViTCVModel(checkpoint=checkpoint)
def vit_base_micro(checkpoint=True):
return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)
def get_training_components():
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy