mirror of https://github.com/hpcaitech/ColossalAI
545 lines
22 KiB
Python
545 lines
22 KiB
Python
|
import torch
|
||
|
import pytorch_lightning as pl
|
||
|
import torch.nn.functional as F
|
||
|
from contextlib import contextmanager
|
||
|
|
||
|
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||
|
|
||
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||
|
|
||
|
from ldm.util import instantiate_from_config
|
||
|
|
||
|
|
||
|
class VQModel(pl.LightningModule):
    """Vector-quantized autoencoder trained with a two-optimizer (autoencoder/discriminator) loss.

    Images are encoded, projected to ``embed_dim`` channels by ``quant_conv``, quantized by
    ``VectorQuantizer`` against a codebook of ``n_embed`` entries, projected back by
    ``post_quant_conv`` and decoded.  The loss module built from ``lossconfig`` is called with
    ``optimizer_idx`` 0 (autoencoder) or 1 (discriminator) and must expose a ``discriminator``
    submodule (its parameters get their own optimizer).

    ``self.learning_rate`` is not set in this class; it has to be assigned on the instance
    before ``configure_optimizers`` runs.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        """Build encoder, quantizer, decoder and loss.

        Args:
            ddconfig: kwargs for ``Encoder``/``Decoder``; must contain ``"z_channels"``.
            lossconfig: config passed to ``instantiate_from_config`` to build ``self.loss``.
            n_embed: number of codebook entries.
            embed_dim: channel width of the quantized latent.
            ckpt_path: optional Lightning checkpoint loaded via ``init_from_ckpt``.
            ignore_keys: state-dict key prefixes dropped when loading ``ckpt_path``.
            image_key: key of the image tensor inside a batch.
            colorize_nlabels: if set, registers a random ``(3, colorize_nlabels, 1, 1)``
                projection used by ``to_rgb``.
            monitor: optional metric name stored as ``self.monitor``.
            batch_resize_range: optional ``(lower, upper)`` sizes for random per-batch
                resizing in ``get_input``.
            scheduler_config: optional config of an object whose ``schedule`` is used as
                ``LambdaLR`` lambda for both optimizers.
            lr_g_factor: multiplier of ``learning_rate`` for the autoencoder optimizer.
            remap, sane_index_shape: forwarded to ``VectorQuantizer``.
            use_ema: keep an exponential moving average of the weights (``LitEma``).
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            # LitEma is not imported at module level; import it where it is needed.
            from ldm.modules.ema import LitEma
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in the EMA weights (no-op when ``use_ema`` is False).

        The training weights are restored in ``finally``, even if the body raises.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load ``checkpoint["state_dict"]`` from ``path`` non-strictly.

        Keys starting with any prefix in ``ignore_keys`` are dropped first (each key at
        most once, so overlapping prefixes no longer raise ``KeyError``).
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after every training batch."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode ``x`` and quantize; returns the quantizer's ``(quant, emb_loss, info)``."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode ``x`` up to (but not including) the quantizer."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode a quantized latent back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode codebook indices ``code_b``.

        NOTE(review): relies on ``self.quantize.embed_code``; confirm the configured
        VectorQuantizer actually provides that method.
        """
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Reconstruct ``input``; returns ``(dec, diff)`` or ``(dec, diff, indices)``."""
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a contiguous float tensor with the last axis moved to dim 1.

        A 3-D tensor first gets a trailing singleton axis.  (Presumably the batch holds
        channels-last images — TODO confirm.)  With ``batch_resize_range`` set, the
        tensor is bicubically resized to ``upper`` for the first steps (global_step <= 4)
        and then to a random multiple-of-16 size in ``[lower, upper]``.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            # numpy is not imported at module level; import it where it is needed.
            import numpy as np
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = int(np.random.choice(np.arange(lower_size, upper_size + 16, 16)))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Autoencoder step for ``optimizer_idx == 0``, discriminator step for ``1``."""
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Validate with the current weights, then again inside ``ema_scope`` (suffix ``_ema``).

        Returns the metrics of the non-EMA pass.
        """
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Evaluate both loss branches, log them under ``val{suffix}/...`` and return them.

        Returns a dict merging the autoencoder and discriminator loss log dicts.
        """
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        split = "val" + suffix
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split=split,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split=split,
                                            predicted_indices=ind
                                            )
        rec_key = f"val{suffix}/rec_loss"
        rec_loss = log_dict_ae[rec_key]
        self.log(rec_key, rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        metrics = {**log_dict_ae, **log_dict_disc}
        # rec_loss was just logged with different options; logging the same key again
        # with other options is rejected by Lightning >= 1.4, so leave it out here.
        self.log_dict({key: val for key, val in log_dict_ae.items() if key != rec_key})
        self.log_dict(log_dict_disc)
        return metrics

    def configure_optimizers(self):
        """Adam for the autoencoder (lr * lr_g_factor) and the discriminator (lr).

        With ``scheduler_config`` set, both optimizers also get a per-step ``LambdaLR``
        driven by the configured object's ``schedule``.
        """
        # LambdaLR is not imported at module level; import it where it is needed.
        from torch.optim.lr_scheduler import LambdaLR

        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is None:
            return [opt_ae, opt_disc], []

        scheduler_obj = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        schedulers = [
            {
                'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler_obj.schedule),
                'interval': 'step',
                'frequency': 1
            },
            {
                'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler_obj.schedule),
                'interval': 'step',
                'frequency': 1
            },
        ]
        return [opt_ae, opt_disc], schedulers

    def get_last_layer(self):
        """Weight of the decoder's output conv (passed to the loss as ``last_layer``)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return inputs, reconstructions and (optionally) EMA reconstructions for logging.

        Tensors with more than 3 channels are mapped to RGB by ``to_rgb`` for display.
        The EMA model is fed the raw input, not its RGB projection.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
            if plot_ema:
                xrec_ema = self.to_rgb(xrec_ema)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels scaled to [-1, 1].

        The random projection is created (and registered as buffer ``colorize``) on
        first use if it does not exist yet.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
|
||
|
|
||
|
|
||
|
class VQModelInterface(VQModel):
    """VQModel whose ``encode`` stops before quantization.

    ``encode`` returns the pre-quantization features; ``decode`` runs the vector
    quantizer itself unless ``force_not_quantize`` is set.
    """

    def __init__(self, embed_dim, *args, **kwargs):
        # Positional arguments are always bound before keywords, so this is the same call
        # as passing embed_dim first.
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Return ``quant_conv(encoder(x))`` without quantizing."""
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        """Quantize ``h`` (unless ``force_not_quantize``) and decode it to image space."""
        if force_not_quantize:
            quant = h
        else:
            # also go through quantization layer; only the quantized tensor is needed
            quant, _, _ = self.quantize(h)
        return self.decoder(self.post_quant_conv(quant))
|
||
|
|
||
|
|
||
|
class AutoencoderKL(pl.LightningModule):
    """Autoencoder with a diagonal-Gaussian latent, trained with a two-optimizer loss.

    ``encode`` wraps the ``2 * embed_dim``-channel output of ``quant_conv`` in a
    ``DiagonalGaussianDistribution``; ``decode`` maps an ``embed_dim``-channel latent back
    through ``post_quant_conv`` and the decoder.  Weights may come from a Lightning
    checkpoint (``ckpt_path``) and/or a diffusers-format VAE state dict
    (``from_pretrained``), whose keys are renamed by ``_state_key_mapping``.

    ``self.learning_rate`` is not set in this class; it has to be assigned on the instance
    before ``configure_optimizers`` runs.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 from_pretrained: str = None
                 ):
        """Build encoder/decoder/loss and optionally load weights.

        Args:
            ddconfig: kwargs for ``Encoder``/``Decoder``; must have ``double_z`` truthy and
                contain ``"z_channels"``.
            lossconfig: config passed to ``instantiate_from_config`` to build ``self.loss``.
            embed_dim: channel width of the latent.
            ckpt_path: optional Lightning checkpoint loaded via ``init_from_ckpt``.
            ignore_keys: state-dict key prefixes dropped when loading ``ckpt_path``.
            image_key: key of the image tensor inside a batch.
            colorize_nlabels: if set, registers a random projection used by ``to_rgb``.
            monitor: optional metric name stored as ``self.monitor``.
            from_pretrained: optional path of a diffusers-format VAE state dict.
        """
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # quant_conv expects 2 * z_channels input channels, which double_z is required to provide.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        if from_pretrained is not None:
            self._init_from_diffusers(from_pretrained)

    def _init_from_diffusers(self, path):
        """Load a diffusers-format VAE state dict from ``path`` and print any load problems."""
        # Imported lazily so diffusers is only required when this option is used. Newer
        # releases expose the helper under diffusers.models.modeling_utils.
        try:
            from diffusers.models.modeling_utils import load_state_dict
        except ImportError:
            from diffusers.modeling_utils import load_state_dict
        state_dict = load_state_dict(path)
        missing, unexpected, mismatched, error_msgs = self._load_pretrained_model(state_dict)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if missing:
            print(f"Missing Keys: {missing}")
        if unexpected:
            print(f"Unexpected Keys: {unexpected}")
        if mismatched:
            print(f"Mismatched Keys: {mismatched}")
        if error_msgs:
            print("Errors while loading state_dict:\n" + "\n".join(error_msgs))

    def _state_key_mapping(self, state_dict: dict):
        """Rename diffusers-style VAE keys to the names used by this module.

        Performs fixed substring renames (``up_blocks`` -> ``up``, ``resnets`` -> ``block``,
        mid-block attention names, ...), shifts mid-block resnet indices by +1 and reverses
        decoder up-block indices.  The number of up blocks is inferred as
        (occurrences of ``upsamplers`` in all keys) // 2 + 1, i.e. it assumes two tensors
        per upsampler and one up block without an upsampler.  Indices matched by the
        regexes are single digits.
        """
        import re

        res_dict = {}
        key_str = " ".join(state_dict.keys())
        up_block_pattern = re.compile('upsamplers')
        # Dots are escaped so only literal "mid.block_N" / "decoder.up.N" segments match.
        p1 = re.compile(r'mid\.block_[0-9]')
        p2 = re.compile(r'decoder\.up\.[0-9]')
        up_blocks_count = len(re.findall(up_block_pattern, key_str)) // 2 + 1
        for key_, val_ in state_dict.items():
            key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\
                .replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\
                .replace('mid.attentions.0.key', 'mid.attn_1.k')\
                .replace('mid.attentions.0.query', 'mid.attn_1.q') \
                .replace('mid.attentions.0.value', 'mid.attn_1.v') \
                .replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \
                .replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\
                .replace('upsamplers.0', 'upsample')\
                .replace('downsamplers.0', 'downsample')\
                .replace('conv_shortcut', 'nin_shortcut')\
                .replace('conv_norm_out', 'norm_out')

            mid_list = re.findall(p1, key_)
            if len(mid_list) != 0:
                mid_str = mid_list[0]
                mid_id = int(mid_str[-1]) + 1
                key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id))

            up_list = re.findall(p2, key_)
            if len(up_list) != 0:
                up_str = up_list[0]
                up_id = up_blocks_count - 1 - int(up_str[-1])
                key_ = key_.replace(up_str, up_str[:-1] + str(up_id))
            res_dict[key_] = val_
        return res_dict

    def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
        """Rename diffusers keys, optionally drop shape-mismatched tensors, then load them.

        Returns:
            ``(missing_keys, unexpected_keys, mismatched_keys, error_msgs)``; mismatched
            keys are only collected (and removed before loading) when
            ``ignore_mismatched_sizes`` is True.
        """
        state_dict = self._state_key_mapping(state_dict)
        model_state_dict = self.state_dict()
        loaded_keys = list(state_dict.keys())
        loaded_set = set(loaded_keys)
        expected_set = set(model_state_dict.keys())
        missing_keys = list(expected_set - loaded_set)
        unexpected_keys = list(loaded_set - expected_set)

        mismatched_keys = []
        if ignore_mismatched_sizes:
            for key in loaded_keys:
                if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
                    mismatched_keys.append((key, state_dict[key].shape, model_state_dict[key].shape))
                    del state_dict[key]

        error_msgs = self._load_state_dict_into_model(state_dict)
        return missing_keys, unexpected_keys, mismatched_keys, error_msgs

    def _load_state_dict_into_model(self, state_dict):
        """Copy matching tensors into every submodule; returns the collected error messages."""
        # copy state_dict so _load_from_state_dict can modify it
        state_dict = state_dict.copy()
        error_msgs = []

        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
        # so we need to apply the function recursively.
        def load(module: torch.nn.Module, prefix=""):
            args = (state_dict, prefix, {}, True, [], [], error_msgs)
            module._load_from_state_dict(*args)

            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(self)

        return error_msgs

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load ``checkpoint["state_dict"]`` from ``path`` non-strictly.

        Keys starting with any prefix in ``ignore_keys`` are dropped first (each key at
        most once, so overlapping prefixes no longer raise ``KeyError``).
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latent ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct ``input`` from a posterior sample (or its mode); returns ``(dec, posterior)``."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a contiguous float tensor with the last axis moved to dim 1.

        A 3-D tensor first gets a trailing singleton axis (presumably channels-last
        images — TODO confirm).
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Autoencoder step for ``optimizer_idx == 0``, discriminator step for ``1``."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Evaluate both loss branches on ``batch``, log them and return the merged metrics."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # Previously returned the bound method self.log_dict; return the metrics instead.
        return {**log_dict_ae, **log_dict_disc}

    def configure_optimizers(self):
        """Adam (betas 0.5/0.9) for the autoencoder and for the loss's discriminator."""
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Weight of the decoder's output conv (passed to the loss as ``last_layer``)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Return inputs and, unless ``only_inputs``, random samples and reconstructions."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels scaled to [-1, 1].

        The random projection is created (and registered as buffer ``colorize``) on
        first use if it does not exist yet.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
|
||
|
|
||
|
|
||
|
class IdentityFirstStage(torch.nn.Module):
    """Pass-through stand-in for a first-stage autoencoder.

    ``encode``, ``decode`` and ``forward`` return their input unchanged and ignore any
    extra positional/keyword arguments; the constructor likewise ignores everything
    except ``vq_interface``.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # torch.nn.Module requires the parent __init__ to run before attributes are
        # assigned on the subclass, so call it first.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``, or ``(x, None, [None, None, None])`` when ``vq_interface`` is set.

        The tuple form mirrors the three-element result that VQModel.encode unpacks
        from its quantizer.
        """
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x
|