Mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
31 lines
1.1 KiB
31 lines
1.1 KiB
3 years ago
|
from torchvision.models import resnet18
|
||
|
from .registry import non_distributed_component_funcs
|
||
|
from pathlib import Path
|
||
|
import os
|
||
|
import torch
|
||
|
from torchvision.transforms import transforms
|
||
|
from torchvision.datasets import CIFAR10
|
||
|
from colossalai.utils import get_dataloader
|
||
|
|
||
|
|
||
|
def get_cifar10_dataloader(train, *, batch_size=16):
    """Build a CIFAR-10 dataloader for the resnet18 test component.

    The dataset root is the directory named by the ``DATA`` environment
    variable; ``download=True`` is passed so torchvision fetches the data there
    if it is not already present.

    Args:
        train (bool): ``True`` for the training split, ``False`` for the test split.
        batch_size (int, optional): Samples per batch. Defaults to 16, the value
            that was previously hard-coded.

    Returns:
        The dataloader produced by ``colossalai.utils.get_dataloader``, built with
        ``shuffle=True`` and ``drop_last=True`` for either split.

    Raises:
        KeyError: If the ``DATA`` environment variable is not set.
    """
    data_root = os.environ.get('DATA')
    if data_root is None:
        # Fail early with an actionable message instead of a bare KeyError('DATA').
        raise KeyError("environment variable 'DATA' is not set; point it at the directory "
                       "where the CIFAR-10 dataset is (or should be downloaded)")

    # Normalize with mean=std=0.5 maps ToTensor's [0, 1] output to [-1, 1] per channel.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset = CIFAR10(root=Path(data_root), download=True, train=train, transform=transform)

    # NOTE(review): shuffle/drop_last are applied to the test split as well; kept as-is
    # so existing callers observe the same loader configuration.
    return get_dataloader(dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True)
|
||
|
|
||
|
|
||
|
@non_distributed_component_funcs.register(name='resnet18')
def get_resnet_training_components():
    """Assemble everything needed to train/evaluate the 'resnet18' component.

    Returns:
        tuple: ``(model, trainloader, testloader, optimizer, criterion)``:
        a torchvision ResNet-18 with a 10-way head, CIFAR-10 loaders for the
        train and test splits (in that order), an Adam optimizer with
        ``lr=0.001`` over the model's parameters, and a cross-entropy loss.
    """
    net = resnet18(num_classes=10)

    # Train split first, then test split -- same construction order as before.
    train_loader, test_loader = [get_cifar10_dataloader(train=is_train) for is_train in (True, False)]

    optimizer = torch.optim.Adam(model_params := net.parameters(), lr=0.001) if False else \
        torch.optim.Adam(net.parameters(), lr=0.001)
    loss_fn = torch.nn.CrossEntropyLoss()

    components = (net, train_loader, test_loader, optimizer, loss_fn)
    return components
|