mirror of https://github.com/hpcaitech/ColossalAI
68 lines
2.4 KiB
Python
68 lines
2.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from utils.dummy_data_generator import DummyDataGenerator
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
from transformers import ViTConfig, ViTForImageClassification
|
|
|
|
|
|
class DummyDataLoader(DummyDataGenerator):
    """Dummy data generator that produces random image-classification batches.

    The batch geometry is fixed by the class attributes below; ``generate``
    always reads them from ``DummyDataLoader`` itself.
    """

    # Number of samples in every generated batch.
    batch_size = 4
    # Number of channels per generated image.
    channel = 3
    # Exclusive upper bound of the generated integer labels.
    category = 8
    # Height and width (in pixels) of every generated image.
    image_size = 224

    def generate(self):
        """Create one random batch on the current device.

        Returns:
            dict: ``'pixel_values'`` holds a float tensor of shape
            ``(batch_size, channel, image_size, image_size)`` and ``'label'``
            holds an ``int64`` tensor of shape ``(batch_size,)`` with values
            drawn from ``[0, category)``.
        """
        spec = DummyDataLoader
        image_shape = (spec.batch_size, spec.channel, spec.image_size, spec.image_size)

        # torch.rand draws from [0, 1); scaling by 2 and subtracting 1 maps
        # the samples onto [-1, 1).
        pixel_values = torch.rand(*image_shape, device=get_current_device()) * 2 - 1
        labels = torch.randint(spec.category, (spec.batch_size,),
                               dtype=torch.int64,
                               device=get_current_device())

        return {'pixel_values': pixel_values, 'label': labels}
|
|
|
|
|
|
class ViTCVModel(nn.Module):
    """Thin ``nn.Module`` wrapper around ``ViTForImageClassification``.

    The wrapped model is built from a ``ViTConfig`` assembled out of the
    constructor arguments and is stored on ``self.model``.
    """

    def __init__(self,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 image_size=224,
                 patch_size=16,
                 num_channels=3,
                 num_labels=8,
                 checkpoint=False):
        """Build the wrapped ViT classifier.

        Args:
            hidden_size: Forwarded to ``ViTConfig`` under the same name.
            num_hidden_layers: Forwarded to ``ViTConfig`` under the same name.
            num_attention_heads: Forwarded to ``ViTConfig`` under the same name.
            image_size: Forwarded to ``ViTConfig`` under the same name.
            patch_size: Forwarded to ``ViTConfig`` under the same name.
            num_channels: Forwarded to ``ViTConfig`` under the same name.
            num_labels: Forwarded to ``ViTConfig`` under the same name.
            checkpoint: When truthy, ``gradient_checkpointing_enable()`` is
                called on the wrapped model. The flag is also kept on
                ``self.checkpoint``.
        """
        super().__init__()
        self.checkpoint = checkpoint

        vit_config = ViTConfig(
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            image_size=image_size,
            patch_size=patch_size,
            num_channels=num_channels,
            num_labels=num_labels,
        )
        self.model = ViTForImageClassification(vit_config)

        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, pixel_values):
        """Run the wrapped model on ``pixel_values`` and return its output unchanged."""
        output = self.model(pixel_values=pixel_values)
        return output
|
|
|
|
|
|
def vit_base_s(checkpoint=True):
    """Build a ``ViTCVModel`` that uses the constructor's default sizes.

    Args:
        checkpoint: Passed through as the model's ``checkpoint`` argument.

    Returns:
        ViTCVModel: The newly constructed model.
    """
    model = ViTCVModel(checkpoint=checkpoint)
    return model
|
|
|
|
|
|
def vit_base_micro(checkpoint=True):
    """Build a reduced-size ``ViTCVModel``.

    The model is created with ``hidden_size=32``, ``num_hidden_layers=2`` and
    ``num_attention_heads=4``; every other setting keeps the constructor
    default.

    Args:
        checkpoint: Passed through as the model's ``checkpoint`` argument.

    Returns:
        ViTCVModel: The newly constructed model.
    """
    model = ViTCVModel(
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        checkpoint=checkpoint,
    )
    return model
|
|
|
|
|
|
def get_training_components():
    """Collect the pieces needed to train the micro ViT on dummy data.

    Returns:
        tuple: ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)`` where ``model_builder`` is the ``vit_base_micro``
        function itself (not a model instance), both loaders are separate
        ``DummyDataLoader`` instances, ``optimizer_class`` is
        ``torch.optim.Adam`` and ``criterion`` is
        ``torch.nn.functional.cross_entropy``.
    """
    train_data = DummyDataLoader()
    test_data = DummyDataLoader()

    model_builder = vit_base_micro
    optimizer_class = torch.optim.Adam
    criterion = torch.nn.functional.cross_entropy

    return model_builder, train_data, test_data, optimizer_class, criterion
|