2022-03-08 02:19:18 +00:00
|
|
|
import os
|
2023-09-18 08:31:06 +00:00
|
|
|
from pathlib import Path
|
|
|
|
|
2022-03-08 02:19:18 +00:00
|
|
|
import torch
|
|
|
|
from torchvision.datasets import CIFAR10
|
2023-09-18 08:31:06 +00:00
|
|
|
from torchvision.models import resnet18
|
|
|
|
from torchvision.transforms import transforms
|
|
|
|
|
|
|
|
from colossalai.legacy.utils import get_dataloader
|
|
|
|
|
|
|
|
from .registry import non_distributed_component_funcs
|
2022-03-08 02:19:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_cifar10_dataloader(train, *, batch_size=16, shuffle=True, drop_last=True):
    """Build a dataloader over one split of CIFAR-10.

    The dataset root is taken from the ``DATA`` environment variable. The
    dataset is constructed with ``download=True``, so torchvision fetches the
    files into that root when they are missing. Images are converted to tensors
    and normalized per channel with mean 0.5 and std 0.5.

    Args:
        train: ``True`` selects the training split, ``False`` the test split.
        batch_size: Number of samples per batch (default 16).
        shuffle: Whether the loader shuffles samples (default ``True``).
        drop_last: Whether the final incomplete batch is dropped
            (default ``True``).

    Returns:
        The loader produced by ``colossalai.legacy.utils.get_dataloader``.

    Raises:
        KeyError: If the ``DATA`` environment variable is not set.
    """
    data_root = os.environ.get("DATA")
    if data_root is None:
        # Same exception type as the plain ``os.environ["DATA"]`` lookup used
        # previously, but with a message that says what to fix.
        raise KeyError("environment variable 'DATA' must be set to the CIFAR-10 download directory")

    # build dataloaders
    dataset = CIFAR10(
        root=Path(data_root),
        download=True,
        train=train,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
        ),
    )

    return get_dataloader(dataset=dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last)
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
@non_distributed_component_funcs.register(name="resnet18")
def get_resnet_training_components():
    """Assemble the ResNet-18 / CIFAR-10 training components.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)``:

        * ``model_builder`` -- callable returning a new ``resnet18`` with 10
          output classes;
        * ``trainloader`` / ``testloader`` -- CIFAR-10 loaders for the train
          and test splits;
        * ``optimizer_class`` -- the ``torch.optim.Adam`` class (not an
          instance);
        * ``criterion`` -- a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def build_model(checkpoint=False):
        # ``checkpoint`` is ignored here; presumably it exists so the builder
        # matches the signature of other registered components -- TODO confirm.
        return resnet18(num_classes=10)

    # The training split is built first, then the test split.
    train_loader = get_cifar10_dataloader(train=True)
    test_loader = get_cifar10_dataloader(train=False)

    loss_fn = torch.nn.CrossEntropyLoss()
    return build_model, train_loader, test_loader, torch.optim.Adam, loss_fn
|