import torch
from timm.models.beit import Beit

from colossalai.utils.cuda import get_current_device

from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator


class DummyDataLoader(DummyDataGenerator):
    """Random image-classification data source for the BEiT test component.

    Each call to :meth:`generate` produces a batch of random images and random
    integer class labels, both allocated on the current device.
    """

    # Side length (pixels) of the square images; also used as ``Beit(img_size=...)``.
    img_size = 64
    # Number of image channels in each generated sample.
    num_channel = 3
    # Number of target classes; labels are drawn from ``[0, num_class)``.
    num_class = 10
    # Number of samples per generated batch.
    batch_size = 4

    def generate(self):
        """Return one random ``(data, label)`` batch.

        Returns:
            data: float tensor of shape
                ``(batch_size, num_channel, img_size, img_size)`` sampled from a
                standard normal distribution.
            label: integer tensor of shape ``(batch_size,)`` with values in
                ``[0, num_class)``.
        """
        # Resolve the target device once; both tensors are created on it.
        device = get_current_device()
        data = torch.randn(
            (
                DummyDataLoader.batch_size,
                DummyDataLoader.num_channel,
                DummyDataLoader.img_size,
                DummyDataLoader.img_size,
            ),
            device=device,
        )
        label = torch.randint(
            low=0,
            high=DummyDataLoader.num_class,
            size=(DummyDataLoader.batch_size,),
            device=device,
        )
        return data, label


@non_distributed_component_funcs.register(name='beit')
def get_training_components():
    """Build the training components for a small BEiT model on dummy data.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class, criterion)``:

        - ``model_builder``: callable that constructs a fresh ``Beit`` model.
        - ``trainloader`` / ``testloader``: independent ``DummyDataLoader`` instances.
        - ``optimizer_class``: the ``torch.optim.Adam`` class itself (not an
          instance); the caller is expected to instantiate it.
        - ``criterion``: a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        """Construct a tiny BEiT classifier sized to match ``DummyDataLoader``.

        Args:
            checkpoint: Accepted but currently ignored. NOTE(review): presumably
                kept so callers can pass it uniformly across registered
                components; gradient checkpointing is not enabled here — confirm
                whether that is intended.

        Returns:
            A ``Beit`` model with a small embedding (32), depth 2 and 4 heads,
            so tests stay fast.
        """
        model = Beit(
            img_size=DummyDataLoader.img_size,
            num_classes=DummyDataLoader.num_class,
            embed_dim=32,
            depth=2,
            num_heads=4,
        )
        return model

    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion