Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

37 lines
1.1 KiB

import os
from pathlib import Path
import torch
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torchvision.transforms import transforms
from colossalai.legacy.utils import get_dataloader
from .registry import non_distributed_component_funcs
def get_cifar10_dataloader(train):
    """Build a dataloader over the CIFAR-10 dataset.

    The dataset root is read from the ``DATA`` environment variable and the
    dataset is requested with ``download=True``. Each sample is converted
    with ``ToTensor`` and then normalized with a mean and std of 0.5 for
    each of the three channels.

    Args:
        train: Forwarded to ``CIFAR10(train=...)`` to select the split.

    Returns:
        Whatever ``get_dataloader`` returns for the dataset, requested with
        ``shuffle=True``, ``batch_size=16`` and ``drop_last=True``.

    Raises:
        KeyError: If the ``DATA`` environment variable is not set.
    """
    # Resolve the root first so a missing DATA variable fails before
    # anything else is constructed.
    data_root = Path(os.environ["DATA"])

    half = (0.5, 0.5, 0.5)
    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean=half, std=half)]
    )

    cifar10 = CIFAR10(
        root=data_root,
        download=True,
        train=train,
        transform=preprocessing,
    )

    # The same loader settings are used for both the train and test splits.
    return get_dataloader(dataset=cifar10, shuffle=True, batch_size=16, drop_last=True)
@non_distributed_component_funcs.register(name="resnet18")
def get_resnet_training_components():
    """Assemble the training components for a ResNet-18 on CIFAR-10.

    Registered in ``non_distributed_component_funcs`` under the name
    ``"resnet18"``.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)`` where:

        * ``model_builder`` is a callable that returns a fresh
          ``resnet18(num_classes=10)``;
        * ``trainloader`` / ``testloader`` come from
          ``get_cifar10_dataloader`` with ``train=True`` / ``train=False``;
        * ``optimizer_class`` is ``torch.optim.Adam`` itself (the class, not
          an instance);
        * ``criterion`` is a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        # `checkpoint` is accepted but not used by this builder.
        return resnet18(num_classes=10)

    # Train split is built before the test split, as in the original order.
    train_loader = get_cifar10_dataloader(train=True)
    test_loader = get_cifar10_dataloader(train=False)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer_cls = torch.optim.Adam
    return model_builder, train_loader, test_loader, optimizer_cls, loss_fn