from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from transformers import ViTConfig, ViTForImageClassification

from colossalai.utils.cuda import get_current_device


class DummyDataGenerator(ABC):
    """Finite iterator that yields ``length`` synthetic items produced by :meth:`generate`.

    Calling ``iter()`` on the object rewinds it, so the same instance can be
    looped over repeatedly (e.g. once per epoch).
    """

    def __init__(self, length=10):
        self.length = length
        # Initialise the cursor here as well as in ``__iter__`` so that calling
        # ``next()`` directly, without a prior ``iter()``, does not raise
        # AttributeError.
        self.step = 0

    @abstractmethod
    def generate(self):
        """Build and return one synthetic item; implemented by subclasses."""

    def __iter__(self):
        # Rewind so each new loop yields a full ``length`` items.
        self.step = 0
        return self

    def __next__(self):
        if self.step >= self.length:
            raise StopIteration
        self.step += 1
        return self.generate()

    def __len__(self):
        return self.length


class DummyDataLoader(DummyDataGenerator):
    """Yields random image-classification batches on ``get_current_device()``.

    Each batch contains:

    * ``pixel_values``: float tensor of shape
      ``(batch_size, channel, image_size, image_size)``, uniform in ``[-1, 1)``.
    * ``label``: int64 tensor of shape ``(batch_size,)`` with values in
      ``[0, category)``.
    """

    def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
        super().__init__(length)
        self.batch_size = batch_size
        self.channel = channel
        self.category = category
        self.image_size = image_size
        self.return_dict = return_dict

    def generate(self):
        """Return one random batch.

        Returns:
            ``{'pixel_values': ..., 'label': ...}`` when ``return_dict`` is true,
            otherwise the tuple ``(pixel_values, label)``.
        """
        # Resolve the device once per batch rather than once per tensor.
        device = get_current_device()
        # torch.rand samples [0, 1); the affine map rescales it to [-1, 1).
        pixel_values = torch.rand(
            self.batch_size, self.channel, self.image_size, self.image_size, device=device) * 2 - 1
        label = torch.randint(self.category, (self.batch_size,), dtype=torch.int64, device=device)
        if not self.return_dict:
            return pixel_values, label
        return {'pixel_values': pixel_values, 'label': label}


class ViTCVModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a freshly initialised ``ViTForImageClassification``.

    Args:
        hidden_size, num_hidden_layers, num_attention_heads, image_size,
        patch_size, num_channels, num_labels: forwarded to ``ViTConfig``.
        checkpoint: if true, enable gradient checkpointing on the HF model.
    """

    def __init__(self,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 image_size=224,
                 patch_size=16,
                 num_channels=3,
                 num_labels=8,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = ViTForImageClassification(
            ViTConfig(hidden_size=hidden_size,
                      num_hidden_layers=num_hidden_layers,
                      num_attention_heads=num_attention_heads,
                      image_size=image_size,
                      patch_size=patch_size,
                      num_channels=num_channels,
                      num_labels=num_labels))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, pixel_values):
        """Run the wrapped model on ``pixel_values``.

        Returns the Hugging Face model output object as-is (with the default
        config this is an ``ImageClassifierOutput``), not a bare logits tensor.
        """
        # NOTE(review): get_training_components pairs this model with
        # F.cross_entropy, which needs a tensor; callers presumably take
        # ``.logits`` from this output first — confirm against the training loop.
        return self.model(pixel_values=pixel_values)


def vit_base_s(checkpoint=True):
    """Build a ViT with the ``ViTCVModel`` default (base-size) hyperparameters."""
    return ViTCVModel(checkpoint=checkpoint)


def vit_base_micro(checkpoint=True):
    """Build a tiny ViT (hidden size 32, 2 layers, 4 heads) for fast tests."""
    return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)


def get_training_components():
    """Return the pieces needed to run a toy training loop.

    Returns:
        A tuple ``(model_builder, trainloader, testloader, optimizer_class, criterion)``
        where ``model_builder`` is the ``vit_base_micro`` function (not an
        instance), both loaders are default ``DummyDataLoader`` objects,
        ``optimizer_class`` is ``torch.optim.Adam`` and ``criterion`` is
        ``torch.nn.functional.cross_entropy``.
    """
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy