mirror of https://github.com/hpcaitech/ColossalAI
36 lines
962 B
Python
36 lines
962 B
Python
![]() |
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
from .Backbone import backbone
|
||
|
|
||
|
class projection_MLP(nn.Module):
|
||
|
def __init__(self, in_dim, out_dim=256):
|
||
|
super().__init__()
|
||
|
hidden_dim = in_dim
|
||
|
self.layer1 = nn.Sequential(
|
||
|
nn.Linear(in_dim, hidden_dim),
|
||
|
nn.ReLU(inplace=True)
|
||
|
)
|
||
|
self.layer2 = nn.Linear(hidden_dim, out_dim)
|
||
|
def forward(self, x):
|
||
|
x = self.layer1(x)
|
||
|
x = self.layer2(x)
|
||
|
return x
|
||
|
|
||
|
class SimCLR(nn.Module):
|
||
|
|
||
|
def __init__(self, model='resnet18', **kwargs):
|
||
|
super().__init__()
|
||
|
|
||
|
self.backbone = backbone(model, **kwargs)
|
||
|
self.projector = projection_MLP(self.backbone.output_dim)
|
||
|
self.encoder = nn.Sequential(
|
||
|
self.backbone,
|
||
|
self.projector
|
||
|
)
|
||
|
|
||
|
def forward(self, x1, x2):
|
||
|
|
||
|
z1 = self.encoder(x1)
|
||
|
z2 = self.encoder(x2)
|
||
|
return z1, z2
|