import torch
from timm.models.beit import Beit

from colossalai.utils.cuda import get_current_device

from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator


class DummyDataLoader(DummyDataGenerator):
    """Produce random image batches and integer labels for the BEiT test model."""

    # Batch geometry and label range shared by the loader and the model builder.
    img_size = 64
    num_channel = 3
    num_class = 10
    batch_size = 4

    def generate(self):
        """Build one random batch on the current device.

        Returns:
            A ``(data, label)`` tuple. ``data`` is drawn with ``torch.randn``
            and has shape ``(batch_size, num_channel, img_size, img_size)``;
            ``label`` is drawn with ``torch.randint`` from ``[0, num_class)``
            and has shape ``(batch_size,)``.
        """
        image_shape = (
            DummyDataLoader.batch_size,
            DummyDataLoader.num_channel,
            DummyDataLoader.img_size,
            DummyDataLoader.img_size,
        )
        # Images are sampled before labels so the RNG is consumed in a fixed order.
        images = torch.randn(image_shape, device=get_current_device())

        label_shape = (DummyDataLoader.batch_size,)
        targets = torch.randint(
            low=0,
            high=DummyDataLoader.num_class,
            size=label_shape,
            device=get_current_device(),
        )
        return data_pair(images, targets)


def data_pair(images, targets):
    """Pack a data tensor and a label tensor into the ``(data, label)`` tuple."""
    return images, targets


@non_distributed_component_funcs.register(name="beit")
def get_training_components():
    """Assemble the pieces registered under the name ``"beit"``.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)`` where ``optimizer_class`` is ``torch.optim.Adam`` and
        ``criterion`` is a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        """Construct a small ``Beit`` model.

        ``checkpoint`` is accepted but not used by this builder.
        """
        beit_kwargs = {
            "img_size": DummyDataLoader.img_size,
            "num_classes": DummyDataLoader.num_class,
            "embed_dim": 32,
            "depth": 2,
            "num_heads": 4,
        }
        return Beit(**beit_kwargs)

    # Loaders are created train-first, then test, then the loss function.
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    criterion = torch.nn.CrossEntropyLoss()

    optimizer_class = torch.optim.Adam
    return model_builder, trainloader, testloader, optimizer_class, criterion