ColossalAI/examples/simclr_cifar10_data_parallel/models/simclr.py

36 lines
962 B
Python
Raw Normal View History

import torch
import torch.nn as nn
import torch.nn.functional as F
from .Backbone import backbone
class projection_MLP(nn.Module):
def __init__(self, in_dim, out_dim=256):
super().__init__()
hidden_dim = in_dim
self.layer1 = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.ReLU(inplace=True)
)
self.layer2 = nn.Linear(hidden_dim, out_dim)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
return x
class SimCLR(nn.Module):
def __init__(self, model='resnet18', **kwargs):
super().__init__()
self.backbone = backbone(model, **kwargs)
self.projector = projection_MLP(self.backbone.output_dim)
self.encoder = nn.Sequential(
self.backbone,
self.projector
)
def forward(self, x1, x2):
z1 = self.encoder(x1)
z2 = self.encoder(x2)
return z1, z2