from contextlib import contextmanager

import lightning.pytorch as pl
import torch
from ldm.modules.diffusionmodules.model import Decoder, Encoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma
from torch.nn import Identity
from torch.nn import functional as F


class AutoencoderKL(pl.LightningModule):
    """Autoencoder built from an ``Encoder``/``Decoder`` pair with a Gaussian posterior.

    The encoder output is projected by ``quant_conv`` into the moments handed
    to ``DiagonalGaussianDistribution``; latents are projected back by
    ``post_quant_conv`` before decoding.

    NOTE(review): ``self.loss`` is set to ``torch.nn.Identity`` and
    ``lossconfig`` is never read, yet the training/validation steps call
    ``self.loss`` with several arguments and ``configure_optimizers`` reads
    ``self.loss.discriminator`` / ``self.loss.logvar``. This looks like an
    inference-only variant -- confirm before relying on the training path.
    """

    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=(),
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        ema_decay=None,
        learn_logvar=False,
    ):
        """Build the encoder/decoder and optionally restore a checkpoint.

        Args:
            ddconfig: Keyword arguments forwarded to both ``Encoder`` and
                ``Decoder``; must contain a truthy ``"double_z"`` and a
                ``"z_channels"`` entry.
            lossconfig: Accepted for interface compatibility; not used here.
            embed_dim: Channel count of the latent fed to ``post_quant_conv``.
            ckpt_path: If given, weights are loaded via ``init_from_ckpt``.
            ignore_keys: Iterable of state-dict key prefixes to drop when
                loading ``ckpt_path``.
            image_key: Key used to pull the input tensor out of a batch.
            colorize_nlabels: If given, registers a random ``colorize`` buffer
                of shape ``(3, colorize_nlabels, 1, 1)`` used by ``to_rgb``.
            monitor: If given, stored as ``self.monitor``.
            ema_decay: If given (strictly between 0 and 1), an EMA copy of the
                weights is tracked with ``LitEma``.
            learn_logvar: If true, ``configure_optimizers`` also optimizes
                ``self.loss.logvar``.
        """
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = Identity()
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0.0 < ema_decay < 1.0
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from the ``"state_dict"`` entry of a checkpoint file.

        Args:
            path: Checkpoint location passed to ``torch.load``.
            ignore_keys: Iterable of key prefixes; every state-dict entry whose
                key starts with one of them is removed before loading.
                ``None`` is treated as "no prefixes".
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        prefixes = tuple(ignore_keys or ())
        # Iterate over a snapshot of the keys because entries are deleted.
        # A single tuple-startswith test deletes each key at most once, even
        # when several prefixes match it.
        for k in list(sd):
            if k.startswith(prefixes):
                print(f"Deleting key {k} from state_dict.")
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in the EMA weights; no-op when EMA is disabled.

        Args:
            context: Optional label; when given, a message is printed on
                entering and leaving the scope.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            # Always put the stored weights back, even if the body raised.
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after each training batch when EMA is on."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode ``x`` and return the ``DiagonalGaussianDistribution`` posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Project latent ``z`` through ``post_quant_conv`` and decode it."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Encode then decode ``input``.

        Args:
            input: Tensor passed to the encoder.
            sample_posterior: Use ``posterior.sample()`` when true, otherwise
                ``posterior.mode()``.

        Returns:
            Tuple ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a contiguous float tensor with axis 3 moved to axis 1.

        A 3-dimensional tensor first gets a trailing singleton axis.
        Presumably the batch is stored channels-last (B, H, W, C) -- confirm
        against the dataset.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute and log the loss for the optimizer selected by ``optimizer_idx``.

        ``optimizer_idx == 0`` handles the autoencoder loss and ``1`` the
        discriminator loss; any other value returns ``None``.

        NOTE(review): confirm the installed Lightning version still passes
        ``optimizer_idx`` to ``training_step``.
        """
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Run validation with the current weights, then again under ``ema_scope``.

        Returns the metrics of the non-EMA pass; the EMA pass only logs
        (with the ``_ema`` postfix).
        """
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        """Compute and log both loss terms for one validation batch.

        Args:
            batch: Batch passed to ``get_input``.
            batch_idx: Unused.
            postfix: Appended to ``"val"`` to form the logging split name.

        Returns:
            Dict holding the autoencoder and discriminator log entries.
        """
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # Return the logged metrics themselves, not the bound ``log_dict`` method.
        return {**log_dict_ae, **log_dict_disc}

    def configure_optimizers(self):
        """Create Adam optimizers for the autoencoder and the discriminator.

        Reads ``self.learning_rate``, which is not set in this class and must
        be assigned by the caller beforehand.

        Returns:
            ``([opt_ae, opt_disc], [])`` -- two optimizers and no schedulers.
        """
        lr = self.learning_rate
        ae_params_list = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters())
        )
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Collect input, reconstruction and sample images for logging.

        Args:
            batch: Batch passed to ``get_input``.
            only_inputs: When true, only ``"inputs"`` is returned.
            log_ema: Also log EMA results (EMA results are logged whenever
                ``self.use_ema`` is set as well).
            **kwargs: Ignored.

        Returns:
            Dict with ``"inputs"`` and, unless ``only_inputs``,
            ``"samples"``/``"reconstructions"`` plus their ``"_ema"``
            counterparts when EMA logging is active.
        """
        log = {}
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            # Keep the raw tensor for model forward passes: ``x`` may be
            # replaced by its RGB projection below, which the encoder must
            # not receive as input.
            model_input = x
            needs_colorize = model_input.shape[1] > 3

            xrec, posterior = self(model_input)
            if needs_colorize:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(model_input)
                    if needs_colorize:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel tensor to 3 channels and rescale to [-1, 1].

        Only valid when ``image_key == "segmentation"``. The ``colorize``
        buffer is created with random weights on first use if absent.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        # Min-max normalize over the whole tensor.
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    """Pass-through module whose encode/decode/quantize/forward return the input."""

    def __init__(self, *args, vq_interface=False, **kwargs):
        """Store ``vq_interface``; extra positional/keyword arguments are ignored."""
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``; with ``vq_interface`` return ``(x, None, [None, None, None])``."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x