mirror of https://github.com/hpcaitech/ColossalAI
32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
from torchvision.transforms import transforms
|
|
|
|
class SimCLRTransform:
    """Two-view stochastic augmentation for contrastive (SimCLR-style) training.

    Every call runs the same random pipeline twice on one input and returns
    both augmented views. The pipeline is:

    1. random resized crop;
    2. horizontal flip;
    3. colour jitter (p=0.8);
    4. grayscale (p=0.2);
    5. Gaussian blur (p=0.5);
    6. tensor conversion and per-channel normalization.

    All arguments are keyword-only and default to the values this class
    previously hard-coded, so ``SimCLRTransform()`` builds the same pipeline
    as before.

    Args:
        size: Output side length for ``RandomResizedCrop``. The Gaussian blur
            kernel is derived from it as ``size // 20 * 2 + 1``, which is always
            odd. The default ``32`` yields a kernel of 3.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.

    Raises:
        ValueError: If ``size`` is not positive.
    """

    def __init__(
        self,
        *,
        size=32,
        mean=(0.4914, 0.4822, 0.4465),  # NOTE(review): presumably CIFAR-10 stats — confirm
        std=(0.2023, 0.1994, 0.2010),
    ):
        if size <= 0:
            raise ValueError(f"size must be positive, got {size}")
        # GaussianBlur needs an odd kernel size; the 2*k + 1 form guarantees it.
        blur_kernel = size // 20 * 2 + 1
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0))],
                p=0.5,
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        """Return two independently augmented views ``(x1, x2)`` of ``x``.

        ``x`` must be an input the pipeline accepts. It is presumably a PIL
        image, since ``ToTensor`` runs after the crop (confirm against callers).
        """
        x1 = self.transform(x)
        x2 = self.transform(x)
        return x1, x2
|
|
|
|
|
|
class LeTransform:
    """Single-view, light augmentation pipeline.

    The pipeline is: random resized crop, horizontal flip, tensor conversion,
    then per-channel normalization. It applies only these mild augmentations,
    with no colour jitter, grayscale or blur. The "Le" prefix presumably means
    linear evaluation (confirm).

    All arguments are keyword-only and default to the values this class
    previously hard-coded, so ``LeTransform()`` builds the same pipeline as
    before.

    Args:
        size: Output side length for ``RandomResizedCrop``.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.

    Raises:
        ValueError: If ``size`` is not positive.
    """

    def __init__(
        self,
        *,
        size=32,
        mean=(0.4914, 0.4822, 0.4465),  # NOTE(review): presumably CIFAR-10 stats — confirm
        std=(0.2023, 0.1994, 0.2010),
    ):
        if size <= 0:
            raise ValueError(f"size must be positive, got {size}")
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        """Return one augmented, normalized view of ``x``."""
        return self.transform(x)