# ColossalAI/examples/simclr_cifar10_data_parallel/augmentation.py
# (source-viewer metadata: 32 lines, 1.2 KiB, Python)

from torchvision.transforms import transforms
class SimCLRTransform:
    """SimCLR-style augmentation producing two augmented views of one input.

    The same random pipeline is applied twice per call. Each application
    samples its own random crop/flip/jitter/grayscale/blur, so the two
    returned views generally differ.

    Args:
        size: Output side length (int) for ``RandomResizedCrop``. The default
            of 32 reproduces the original CIFAR-10 behaviour. The blur kernel
            size is derived from it.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.
    """

    # Per-channel normalization constants that were hard-coded for CIFAR-10.
    # NOTE(review): these std values are the ones widely copied for CIFAR-10.
    # They appear to differ from the per-pixel dataset std (~0.247, 0.243, 0.262).
    # They are kept unchanged so existing results stay comparable; verify before changing.
    CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR10_STD = (0.2023, 0.1994, 0.2010)

    def __init__(self, size=32, mean=CIFAR10_MEAN, std=CIFAR10_STD):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=self._blur_kernel_size(size),
                                         sigma=(0.1, 2.0))],
                p=0.5,
            ),
            # Converts the input to a tensor.
            # Presumably x is a PIL image or HxWxC ndarray; confirm against the dataset.
            transforms.ToTensor(),
            transforms.Normalize(mean=list(mean), std=list(std)),
        ])

    @staticmethod
    def _blur_kernel_size(size):
        """Return an odd blur kernel size of roughly 10% of ``size``.

        ``size // 20 * 2 + 1`` is always odd and positive. For the original
        32 px crop it yields 3, the value previously hard-coded as
        ``32//20*2+1``.
        """
        return size // 20 * 2 + 1

    def __call__(self, x):
        """Return two separately augmented views ``(x1, x2)`` of ``x``."""
        x1 = self.transform(x)
        x2 = self.transform(x)
        return x1, x2
class LeTransform:
    """Light single-view augmentation: random resized crop, flip, normalize.

    Unlike ``SimCLRTransform``, this returns one view and applies no colour
    jitter, grayscale or blur.
    NOTE(review): presumably intended for linear evaluation / fine-tuning
    ("Le"); confirm against the training script.

    Args:
        size: Output side length for ``RandomResizedCrop``. The default of 32
            reproduces the original CIFAR-10 behaviour.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.
    """

    # Per-channel normalization constants that were hard-coded for CIFAR-10.
    # NOTE(review): the std values match the commonly copied CIFAR-10 constants
    # rather than the per-pixel dataset std; kept unchanged, so verify before changing.
    CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR10_STD = (0.2023, 0.1994, 0.2010)

    def __init__(self, size=32, mean=CIFAR10_MEAN, std=CIFAR10_STD):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(mean), std=list(std)),
        ])

    def __call__(self, x):
        """Return a single augmented view of ``x``."""
        return self.transform(x)