"""Image augmentation pipelines for SimCLR-style self-supervised training.

``SimCLRTransform`` turns one image into two independently augmented views
(a positive pair for contrastive training). ``LeTransform`` applies only a
light crop/flip augmentation and returns a single view.
"""

from torchvision.transforms import transforms

# Per-channel (R, G, B) normalization statistics shared by both pipelines.
# The mean is the standard CIFAR-10 per-channel mean.
# NOTE(review): this std triple is the widely copied one; the per-pixel std of
# the CIFAR-10 training set is reportedly closer to (0.2470, 0.2435, 0.2616).
# It is kept unchanged so existing checkpoints see identically normalized
# inputs. Pass ``std=`` explicitly to use different statistics.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


class SimCLRTransform:
    """Produce two independently augmented views of the same image.

    Every call draws fresh random parameters for each augmentation, so the
    two views generally differ. The pipeline is: random resized crop,
    horizontal flip, colour jitter, random grayscale, Gaussian blur,
    tensor conversion, normalization.

    Args:
        size: Side length in pixels of the square output crop. It also
            determines the Gaussian-blur kernel size. Defaults to 32.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.
    """

    def __init__(self, size=32, mean=CIFAR10_MEAN, std=CIFAR10_STD):
        # Odd kernel size that scales with the image: 3 for size=32, 23 for
        # size=224. It was previously computed from a hard-coded 32.
        blur_kernel = size // 20 * 2 + 1
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            # Colour distortion (brightness, contrast, saturation, hue),
            # applied with probability 0.8.
            transforms.RandomApply(
                [transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0))],
                p=0.5,
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        """Return ``(view1, view2)``: two independent augmentations of ``x``.

        ``x`` is presumably a PIL image. ``ToTensor`` at the end of the
        pipeline requires a PIL image or ndarray. Each view is a tensor of
        shape (C, size, size), and the 3-element ``Normalize`` statistics
        require C == 3.
        """
        # Two separate pipeline invocations, so the random draws differ.
        return self.transform(x), self.transform(x)


class LeTransform:
    """Apply a light single-view augmentation: random crop, flip, normalize.

    There is no colour distortion, grayscale or blur, and only one view is
    returned.
    NOTE(review): presumably used for linear evaluation / downstream
    training (``Le`` = "linear evaluation"?). Confirm against the callers.

    Args:
        size: Side length in pixels of the square output crop. Defaults to 32.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.
    """

    def __init__(self, size=32, mean=CIFAR10_MEAN, std=CIFAR10_STD):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        """Return one augmented, normalized tensor view of ``x``.

        ``x`` is presumably a PIL image (``ToTensor`` requires a PIL image or
        ndarray).
        """
        return self.transform(x)