From cad1f505125ab3d74f0390a3e3dc5796e5cd790f Mon Sep 17 00:00:00 2001 From: Fazzie <1240419984@qq.com> Date: Fri, 3 Feb 2023 15:34:54 +0800 Subject: [PATCH] fix ckpt --- examples/images/diffusion/README.md | 35 +- .../Teyvat/train_colossalai_teyvat.yaml | 3 +- .../diffusion/ldm/models/diffusion/ddpm.py | 820 +++++++++++------- .../ldm/modules/diffusionmodules/model.py | 545 ++++++------ examples/images/diffusion/main.py | 78 +- examples/images/diffusion/scripts/txt2img.sh | 6 +- examples/images/diffusion/train_colossalai.sh | 2 +- 7 files changed, 831 insertions(+), 658 deletions(-) diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index b68347c00..bec1c7503 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -53,27 +53,33 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la ``` conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch pip install transformers==4.19.2 diffusers invisible-watermark -pip install -e . ``` #### Step 2: install lightning Install Lightning version later than 2022.01.04. We suggest you install lightning from source. +##### From Source ``` git clone https://github.com/Lightning-AI/lightning.git pip install -r requirements.txt python setup.py install ``` +##### From pip + +``` +pip install pytorch-lightning +``` + #### Step 3:Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website ##### From pip -For example, you can install v0.1.12 from our official website. +For example, you can install v0.2.0 from our official website. ``` -pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org +pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org ``` ##### From source @@ -133,10 +139,9 @@ It is important for you to configure your volume mapping in order to get the bes 3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside container, please add `-v /dev/shm:/dev/shm` to your `docker run` command. - ## Download the model checkpoint from pretrained -### stable-diffusion-v2-base +### stable-diffusion-v2-base(Recommand) ``` wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt @@ -144,8 +149,6 @@ wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512 ### stable-diffusion-v1-4 -Our default model config use the weight from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style) - ``` git lfs install git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 @@ -153,8 +156,6 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 ### stable-diffusion-v1-5 from runway -If you want to useed the Last [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weight from runwayml - ``` git lfs install git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 @@ -171,11 +172,16 @@ We provide the script `train_colossalai.sh` to run the training task with coloss and can also use `train_ddp.sh` to run the training task with ddp to compare. 
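Before launching either script, it can help to confirm that the checkpoint downloaded above loads cleanly. The snippet below is a minimal, hypothetical sanity check (not part of this repository); it only mirrors the `torch.load(..., map_location="cpu")` and `state_dict` handling that `init_from_ckpt` in `ldm/models/diffusion/ddpm.py` already performs.

```
import torch

def inspect_checkpoint(path="512-base-ema.ckpt"):
    # Load on CPU and unwrap the Lightning-style "state_dict" entry if present,
    # mirroring DDPM.init_from_ckpt in ldm/models/diffusion/ddpm.py.
    sd = torch.load(path, map_location="cpu")
    if "state_dict" in list(sd.keys()):
        sd = sd["state_dict"]
    print(f"{len(sd)} tensors; first keys: {list(sd)[:3]}")

if __name__ == "__main__":
    inspect_checkpoint()
```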
In `train_colossalai.sh` the main command is: + ``` -python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml +python main.py --logdir /tmp/ --train --base configs/train_colossalai.yaml --ckpt 512-base-ema.ckpt ``` -- you can change the `--logdir` to decide where to save the log information and the last checkpoint. +- You can change the `--logdir` to decide where to save the log information and the last checkpoint. + - You will find your ckpt in `logdir/checkpoints` or `logdir/diff_tb/version_0/checkpoints` + - You will find your train config yaml in `logdir/configs` +- You can add the `--ckpt` if you want to load the pretrained model, for example `512-base-ema.ckpt` +- You can change the `--base` to specify the path of config yaml ### Training config @@ -186,7 +192,8 @@ You can change the trainging config in the yaml file - precision: the precision type used in training, default 16 (fp16), you must use fp16 if you want to apply colossalai - more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai) -## Finetune Example (Work In Progress) + +## Finetune Example ### Training on Teyvat Datasets We provide the finetuning example on [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is create by BLIP generated captions. @@ -201,8 +208,8 @@ you can get yout training last.ckpt and train config.yaml in your `--logdir`, an ``` python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms --outdir ./output \ - --config path/to/logdir/checkpoints/last.ckpt \ - --ckpt /path/to/logdir/configs/project.yaml \ + --ckpt path/to/logdir/checkpoints/last.ckpt \ + --config /path/to/logdir/configs/project.yaml \ ``` ```commandline diff --git a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml index 8a8250c5d..ff0f4c5a0 100644 --- a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml +++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml @@ -6,6 +6,7 @@ model: linear_start: 0.00085 linear_end: 0.0120 num_timesteps_cond: 1 + ckpt: None # use ckpt path log_every_t: 200 timesteps: 1000 first_stage_key: image @@ -16,7 +17,7 @@ model: conditioning_key: crossattn monitor: val/loss_simple_ema scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config + use_ema: False scheduler_config: # 10000 warmup steps target: ldm.lr_scheduler.LambdaLinearScheduler diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index f7ac0a735..b7315b048 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -6,56 +6,41 @@ https://github.com/CompVis/taming-transformers -- merci """ +import numpy as np import torch import torch.nn as nn -import numpy as np + try: import lightning.pytorch as pl - from lightning.pytorch.utilities import rank_zero_only, rank_zero_info + from lightning.pytorch.utilities import rank_zero_info, rank_zero_only except: import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only, rank_zero_info -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange, repeat + +import itertools from contextlib import contextmanager, nullcontext from functools import partial -import itertools -from tqdm 
import tqdm -from torchvision.utils import make_grid -from omegaconf import ListConfig - -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL - - -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.modules.diffusionmodules.openaimodel import * - -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d -from ldm.modules.encoders.modules import * - -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from einops import rearrange, repeat from ldm.models.autoencoder import * +from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage from ldm.models.diffusion.ddim import * -from ldm.modules.diffusionmodules.openaimodel import * +from ldm.models.diffusion.ddim import DDIMSampler from ldm.modules.diffusionmodules.model import * +from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model +from ldm.modules.diffusionmodules.openaimodel import * +from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl +from ldm.modules.ema import LitEma +from ldm.modules.encoders.modules import * +from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat +from omegaconf import ListConfig +from torch.optim.lr_scheduler import LambdaLR +from torchvision.utils import make_grid +from tqdm import tqdm - -from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder - -from ldm.util import instantiate_from_config - - -__conditioning_keys__ = {'concat': 'c_concat', - 'crossattn': 'c_crossattn', - 'adm': 'y'} +__conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} def disabled_train(self, mode=True): @@ -70,40 +55,41 @@ def uniform_on_device(r1, r2, shape, device): class DDPM(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space - def __init__(self, - unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0., - use_fp16 = True, - make_it_fit=False, - ucg_training=None, - reset_ema=False, - reset_num_ema_updates=False, - ): + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt=None, + ignore_keys=[], + 
load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + use_fp16=True, + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): super().__init__() assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' self.parameterization = parameterization @@ -112,18 +98,18 @@ class DDPM(pl.LightningModule): self.clip_denoised = clip_denoised self.log_every_t = log_every_t self.first_stage_key = first_stage_key - self.image_size = image_size # try conv? + self.image_size = image_size self.channels = channels self.use_positional_encodings = use_positional_encodings self.unet_config = unet_config self.conditioning_key = conditioning_key self.model = DiffusionWrapper(unet_config, conditioning_key) - count_params(self.model, verbose=True) + # count_params(self.model, verbose=True) self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + rank_zero_info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") self.use_scheduler = scheduler_config is not None if self.use_scheduler: @@ -136,21 +122,26 @@ class DDPM(pl.LightningModule): if monitor is not None: self.monitor = monitor self.make_it_fit = make_it_fit - self.ckpt_path = ckpt_path + self.ckpt = ckpt self.ignore_keys = ignore_keys self.load_only_unet = load_only_unet self.reset_ema = reset_ema self.reset_num_ema_updates = reset_num_ema_updates - if reset_ema: assert exists(ckpt_path) - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) - if reset_ema: - assert self.use_ema - print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") - self.model_ema = LitEma(self.model) + if reset_ema: + assert exists(ckpt) + ''' + Uncomment if you Use DDP Strategy + ''' + # if ckpt is not None: + # self.init_from_ckpt(ckpt, ignore_keys=ignore_keys, only_model=load_only_unet) + # if reset_ema: + # assert self.use_ema + # rank_zero_info(f"Resetting ema to pure model weights. 
This is useful when restoring from an ema-only checkpoint.") + # self.model_ema = LitEma(self.model) + if reset_num_ema_updates: - print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + rank_zero_info(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") assert self.use_ema self.model_ema.reset_num_updates() @@ -160,9 +151,13 @@ class DDPM(pl.LightningModule): self.linear_start = linear_start self.linear_end = linear_end self.cosine_s = cosine_s - - self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, - linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.register_schedule(given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s) self.loss_type = loss_type @@ -176,12 +171,20 @@ class DDPM(pl.LightningModule): if self.ucg_training: self.ucg_prng = np.random.RandomState() - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + def register_schedule(self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): if exists(given_betas): betas = given_betas else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + betas = make_beta_schedule(beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, cosine_s=cosine_s) alphas = 1. - betas alphas_cumprod = np.cumprod(alphas, axis=0) @@ -208,24 +211,23 @@ class DDPM(pl.LightningModule): # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas + 1. - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.register_buffer('posterior_variance', to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', to_torch( - (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef1', + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', + to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) if self.parameterization == "eps": - lvlb_weights = self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) elif self.parameterization == "x0": lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. 
* 1 - torch.Tensor(alphas_cumprod)) elif self.parameterization == "v": - lvlb_weights = torch.ones_like(self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + lvlb_weights = torch.ones_like(self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * + (1 - self.alphas_cumprod))) else: raise NotImplementedError("mu not supported") lvlb_weights[0] = lvlb_weights[1] @@ -238,14 +240,14 @@ class DDPM(pl.LightningModule): self.model_ema.store(self.model.parameters()) self.model_ema.copy_to(self.model) if context is not None: - print(f"{context}: Switched to EMA weights") + rank_zero_info(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.model.parameters()) if context is not None: - print(f"{context}: Restored training weights") + rank_zero_info(f"{context}: Restored training weights") @torch.no_grad() def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): @@ -256,18 +258,13 @@ class DDPM(pl.LightningModule): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + rank_zero_info("Deleting key {} from state_dict.".format(k)) del sd[k] if self.make_it_fit: - n_params = len([name for name, _ in - itertools.chain(self.named_parameters(), - self.named_buffers())]) - for name, param in tqdm( - itertools.chain(self.named_parameters(), - self.named_buffers()), - desc="Fitting old weights to new weights", - total=n_params - ): + n_params = len([name for name, _ in itertools.chain(self.named_parameters(), self.named_buffers())]) + for name, param in tqdm(itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params): if not name in sd: continue old_shape = sd[name].shape @@ -304,11 +301,11 @@ class DDPM(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: - print(f"Missing Keys:\n {missing}") + rank_zero_info(f"Missing Keys:\n {missing}") if len(unexpected) > 0: - print(f"\nUnexpected Keys:\n {unexpected}") + rank_zero_info(f"\nUnexpected Keys:\n {unexpected}") def q_mean_variance(self, x_start, t): """ @@ -323,30 +320,22 @@ class DDPM(pl.LightningModule): return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise) def predict_start_from_z_and_v(self, x_t, t, v): # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod))) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v - ) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v) def predict_eps_from_z_and_v(self, x_t, t, v): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t - ) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t) def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) + posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped @@ -379,7 +368,8 @@ class DDPM(pl.LightningModule): img = torch.randn(shape, device=device) intermediates = [img] for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + img = self.p_sample(img, + torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) @@ -400,10 +390,8 @@ class DDPM(pl.LightningModule): extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) def get_v(self, x, noise, t): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x - ) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x) def get_loss(self, pred, target, mean=True): if self.loss_type == 'l1': @@ -485,11 +473,9 @@ class DDPM(pl.LightningModule): loss, loss_dict = self.shared_step(batch) - self.log_dict(loss_dict, prog_bar=True, - logger=True, on_step=True, on_epoch=True) + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log("global_step", self.global_step, - prog_bar=True, logger=True, on_step=True, on_epoch=False) + self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) if self.use_scheduler: lr = self.optimizers().param_groups[0]['lr'] @@ -580,7 +566,8 @@ class LatentDiffusion(DDPM): scale_by_std=False, use_fp16=True, force_null_conditioning=False, - *args, **kwargs): + *args, + **kwargs): self.force_null_conditioning = force_null_conditioning self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std @@ -590,7 +577,7 @@ class LatentDiffusion(DDPM): conditioning_key = 'concat' if concat_mode else 'crossattn' if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: conditioning_key = None - + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.concat_mode = concat_mode 
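        # The conditioning_key selected above determines how DiffusionWrapper feeds the
        # conditioning into the UNet: per the __conditioning_keys__ mapping at the top of
        # this file, 'concat' maps to 'c_concat', 'crossattn' to 'c_crossattn', and 'adm' to 'y'.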
self.cond_stage_trainable = cond_stage_trainable @@ -599,7 +586,7 @@ class LatentDiffusion(DDPM): self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: self.num_downs = 0 - + if not scale_by_std: self.scale_factor = scale_factor else: @@ -611,40 +598,44 @@ class LatentDiffusion(DDPM): self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None - - self.restarted_from_ckpt = False - if self.ckpt_path is not None: - self.init_from_ckpt(self.ckpt_path, self.ignore_keys) - self.restarted_from_ckpt = True - if self.reset_ema: - assert self.use_ema - print( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") - self.model_ema = LitEma(self.model) + ''' + Uncomment if you Use DDP Strategy + ''' + # self.restarted_from_ckpt = False + # if self.ckpt is not None: + # self.init_from_ckpt(self.ckpt, self.ignore_keys) + # self.restarted_from_ckpt = True + # if self.reset_ema: + # assert self.use_ema + # rank_zero_info( + # f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + # self.model_ema = LitEma(self.model) if self.reset_num_ema_updates: - print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + rank_zero_info(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") assert self.use_ema self.model_ema.reset_num_updates() def configure_sharded_model(self) -> None: rank_zero_info("Configure sharded model for LatentDiffusion") self.model = DiffusionWrapper(self.unet_config, self.conditioning_key) + count_params(self.model, verbose=True) if self.use_ema: self.model_ema = LitEma(self.model) - if self.ckpt_path is not None: - self.init_from_ckpt(self.ckpt_path, ignore_keys=self.ignore_keys, only_model=self.load_only_unet) + if self.ckpt is not None: + self.init_from_ckpt(self.ckpt, ignore_keys=self.ignore_keys, only_model=self.load_only_unet) if self.reset_ema: assert self.use_ema - print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + rank_zero_info( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") self.model_ema = LitEma(self.model) - if self.reset_num_ema_updates: - print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") - assert self.use_ema - self.model_ema.reset_num_updates() - self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps, - linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s) + self.register_schedule(given_betas=self.given_betas, + beta_schedule=self.beta_schedule, + timesteps=self.timesteps, + linear_start=self.linear_start, + linear_end=self.linear_end, + cosine_s=self.cosine_s) self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: @@ -654,20 +645,16 @@ class LatentDiffusion(DDPM): self.instantiate_first_stage(self.first_stage_config) self.instantiate_cond_stage(self.cond_stage_config) - if self.ckpt_path is not None: - self.init_from_ckpt(self.ckpt_path, self.ignore_keys) + if self.ckpt is not None: + self.init_from_ckpt(self.ckpt, self.ignore_keys) self.restarted_from_ckpt = True if self.reset_ema: assert self.use_ema - print( + rank_zero_info( f"Resetting ema to pure model weights. 
This is useful when restoring from an ema-only checkpoint.") self.model_ema = LitEma(self.model) - if self.reset_num_ema_updates: - print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") - assert self.use_ema - self.model_ema.reset_num_updates() - def make_cond_schedule(self, ): + def make_cond_schedule(self,): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() self.cond_ids[:self.num_timesteps_cond] = ids @@ -679,19 +666,23 @@ class LatentDiffusion(DDPM): if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' # set rescale weight to 1./std of encodings - print("### USING STD-RESCALING ###") + rank_zero_info("### USING STD-RESCALING ###") x = super().get_input(batch, self.first_stage_key) x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor self.register_buffer('scale_factor', 1. / z.flatten().std()) - print(f"setting self.scale_factor to {self.scale_factor}") - print("### USING STD-RESCALING ###") + rank_zero_info(f"setting self.scale_factor to {self.scale_factor}") + rank_zero_info("### USING STD-RESCALING ###") def register_schedule(self, - given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) self.shorten_cond_schedule = self.num_timesteps_cond > 1 @@ -708,10 +699,10 @@ class LatentDiffusion(DDPM): def instantiate_cond_stage(self, config): if not self.cond_stage_trainable: if config == "__is_first_stage__": - print("Using first stage also as cond stage.") + rank_zero_info("Using first stage also as cond stage.") self.cond_stage_model = self.first_stage_model elif config == "__is_unconditional__": - print(f"Training {self.__class__.__name__} as an unconditional model.") + rank_zero_info(f"Training {self.__class__.__name__} as an unconditional model.") self.cond_stage_model = None # self.be_unconditional = True else: @@ -729,10 +720,10 @@ class LatentDiffusion(DDPM): def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): denoise_row = [] for zd in tqdm(samples, desc=desc): - denoise_row.append(self.decode_first_stage(zd.to(self.device), - force_not_quantize=force_no_decoder_quantization)) + denoise_row.append( + self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization)) n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) @@ -783,21 +774,23 @@ class LatentDiffusion(DDPM): def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) - weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], - self.split_input_params["clip_max_weight"], ) + weighting = 
torch.clip( + weighting, + self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], + ) weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) if self.split_input_params["tie_braker"]: L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, - self.split_input_params["clip_min_tie_weight"], + L_weighting = torch.clip(L_weighting, self.split_input_params["clip_min_tie_weight"], self.split_input_params["clip_max_tie_weight"]) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting return weighting - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) @@ -815,7 +808,7 @@ class LatentDiffusion(DDPM): fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) elif uf > 1 and df == 1: @@ -823,12 +816,13 @@ class LatentDiffusion(DDPM): unfold = torch.nn.Unfold(**fold_params) fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, padding=0, + dilation=1, + padding=0, stride=(stride[0] * uf, stride[1] * uf)) fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) elif df > 1 and uf == 1: @@ -836,12 +830,13 @@ class LatentDiffusion(DDPM): unfold = torch.nn.Unfold(**fold_params) fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, padding=0, + dilation=1, + padding=0, stride=(stride[0] // df, stride[1] // df)) fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) else: @@ -850,8 +845,15 @@ class LatentDiffusion(DDPM): return fold, unfold, normalization, weighting @torch.no_grad() - def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, - cond_key=None, return_original_cond=False, bs=None, return_x=False): + def get_input(self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + return_x=False): x = super().get_input(batch, k) if bs is not None: x = x[:bs] @@ -900,7 +902,7 @@ class LatentDiffusion(DDPM): out.extend([x]) if return_original_cond: out.append(xc) - + return out @torch.no_grad() @@ -929,7 +931,7 @@ class LatentDiffusion(DDPM): 
assert c is not None if self.cond_stage_trainable: c = self.get_learned_conditioning(c) - if self.shorten_cond_schedule: # TODO: drop this option + if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) @@ -1007,8 +1009,16 @@ class LatentDiffusion(DDPM): return loss, loss_dict - def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, - return_x0=False, score_corrector=None, corrector_kwargs=None): + def p_mean_variance(self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) @@ -1039,15 +1049,29 @@ class LatentDiffusion(DDPM): return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, - return_codebook_ids=False, quantize_denoised=False, return_x0=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + def p_sample(self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None): b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + outputs = self.p_mean_variance(x=x, + c=c, + t=t, + clip_denoised=clip_denoised, return_codebook_ids=return_codebook_ids, quantize_denoised=quantize_denoised, return_x0=return_x0, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs @@ -1070,9 +1094,22 @@ class LatentDiffusion(DDPM): return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, - img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., - score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + def progressive_denoising(self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, log_every_t=None): if not log_every_t: log_every_t = self.log_every_t @@ -1089,16 +1126,17 @@ class LatentDiffusion(DDPM): intermediates = [] if cond is not None: if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + cond = { + key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( + map(lambda x: x[:batch_size], cond[key])) for key in cond + } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] if start_T is not None: timesteps = min(timesteps, start_T) iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - total=timesteps) if verbose else reversed( - range(0, timesteps)) + total=timesteps) if verbose 
else reversed(range(0, timesteps)) if type(temperature) == float: temperature = [temperature] * timesteps @@ -1109,11 +1147,16 @@ class LatentDiffusion(DDPM): tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - img, x0_partial = self.p_sample(img, cond, ts, + img, x0_partial = self.p_sample(img, + cond, + ts, clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, return_x0=True, - temperature=temperature[i], noise_dropout=noise_dropout, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) @@ -1121,14 +1164,26 @@ class LatentDiffusion(DDPM): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() - def p_sample_loop(self, cond, shape, return_intermediates=False, - x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, start_T=None, + def p_sample_loop(self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, log_every_t=None): if not log_every_t: @@ -1151,7 +1206,7 @@ class LatentDiffusion(DDPM): if mask is not None: assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) @@ -1160,51 +1215,64 @@ class LatentDiffusion(DDPM): tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - img = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised) + img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) if mask is not None: img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1. 
- mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates return img @torch.no_grad() - def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, - verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None, **kwargs): + def sample(self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + cond = { + key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( + map(lambda x: x[:batch_size], cond[key])) for key in cond + } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, shape, - return_intermediates=return_intermediates, x_T=x_T, - verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, - mask=mask, x0=x0) + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0) @torch.no_grad() def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): if ddim: ddim_sampler = DDIMSampler(self) shape = (self.channels, self.image_size, self.image_size) - samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, - shape, cond, verbose=False, **kwargs) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs) else: - samples, intermediates = self.sample(cond=cond, batch_size=batch_size, - return_intermediates=True, **kwargs) + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs) return samples, intermediates @@ -1226,7 +1294,7 @@ class LatentDiffusion(DDPM): return self.get_learned_conditioning(xc) else: raise NotImplementedError("todo") - if isinstance(c, list): # in case the encoder gives us a list + if isinstance(c, list): # in case the encoder gives us a list for i in range(len(c)): c[i] = repeat(c[i], '1 ... 
-> b ...', b=batch_size).to(self.device) else: @@ -1234,16 +1302,29 @@ class LatentDiffusion(DDPM): return c @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + def log_images(self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=50, + ddim_eta=0., + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1., + unconditional_guidance_label=None, use_ema_scope=True, **kwargs): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + z, c, x, xrec, xc = self.get_input(batch, + self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, @@ -1283,7 +1364,7 @@ class LatentDiffusion(DDPM): z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) @@ -1292,8 +1373,11 @@ class LatentDiffusion(DDPM): if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1305,8 +1389,11 @@ class LatentDiffusion(DDPM): self.first_stage_model, IdentityFirstStage): # also display when quantizing x0 while sampling with ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, quantize_denoised=True) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) @@ -1318,11 +1405,15 @@ class LatentDiffusion(DDPM): if self.model.conditioning_key == "crossattn-adm": uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - ) + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg @@ -1334,8 +1425,13 @@ class LatentDiffusion(DDPM): mask[:, h // 4:3 * h // 4, w // 4:3 
* w // 4] = 0. mask = mask[:, None, ...] with ema_scope("Plotting Inpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) + samples, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_inpainting"] = x_samples log["mask"] = mask @@ -1343,8 +1439,13 @@ class LatentDiffusion(DDPM): # outpaint mask = 1. - mask with ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) + samples, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + eta=ddim_eta, + ddim_steps=ddim_steps, + x0=z[:N], + mask=mask) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples @@ -1367,10 +1468,10 @@ class LatentDiffusion(DDPM): lr = self.learning_rate params = list(self.model.parameters()) if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + rank_zero_info(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: - print('Diffusion model optimizing logvar') + rank_zero_info('Diffusion model optimizing logvar') params.append(self.logvar) from colossalai.nn.optimizer import HybridAdam @@ -1381,13 +1482,8 @@ class LatentDiffusion(DDPM): assert 'target' in self.scheduler_config scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] + rank_zero_info("Setting up LambdaLR scheduler...") + scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] return [opt], scheduler return opt @@ -1402,6 +1498,7 @@ class LatentDiffusion(DDPM): class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): super().__init__() self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) @@ -1444,6 +1541,7 @@ class DiffusionWrapper(pl.LightningModule): class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): super().__init__(*args, **kwargs) # assumes that neither the cond_stage nor the low_scale_model contain trainable params @@ -1464,8 +1562,12 @@ class LatentUpscaleDiffusion(LatentDiffusion): if not log_mode: z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) else: - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) + z, c, x, xrec, xc = super().get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs) x_low = batch[self.low_scale_key][:bs] x_low = rearrange(x_low, 'b h w c -> b c h w') if self.use_fp16: @@ -1485,15 +1587,28 @@ class LatentUpscaleDiffusion(LatentDiffusion): return z, all_conds @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, - unconditional_guidance_scale=1., 
unconditional_guidance_label=None, use_ema_scope=True, + def log_images(self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1., + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1., + unconditional_guidance_label=None, + use_ema_scope=True, **kwargs): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, + self.first_stage_key, + bs=N, log_mode=True) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -1528,7 +1643,7 @@ class LatentUpscaleDiffusion(LatentDiffusion): z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) @@ -1537,8 +1652,11 @@ class LatentUpscaleDiffusion(LatentDiffusion): if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) + samples, z_denoise_row = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1555,7 +1673,7 @@ class LatentUpscaleDiffusion(LatentDiffusion): if k == "c_crossattn": assert isinstance(c[k], list) and len(c[k]) == 1 uc[k] = [uc_tmp] - elif k == "c_adm": # todo: only run with text-based guidance? + elif k == "c_adm": # todo: only run with text-based guidance? 
assert isinstance(c[k], torch.Tensor) #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level uc[k] = c[k] @@ -1565,11 +1683,15 @@ class LatentUpscaleDiffusion(LatentDiffusion): uc[k] = c[k] with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - ) + samples_cfg, _ = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg @@ -1590,18 +1712,18 @@ class LatentFinetuneDiffusion(LatentDiffusion): To disable finetuning mode, set finetune_keys to None """ - def __init__(self, - concat_keys: tuple, - finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", - "model_ema.diffusion_modelinput_blocks00weight" - ), - keep_finetune_dims=4, - # if model was trained without concat mode before and we would like to keep these channels - c_concat_log_start=None, # to log reconstruction of c_concat codes - c_concat_log_end=None, - *args, **kwargs - ): - ckpt_path = kwargs.pop("ckpt_path", None) + def __init__( + self, + concat_keys: tuple, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight"), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, + **kwargs): + ckpt = kwargs.pop("ckpt", None) ignore_keys = kwargs.pop("ignore_keys", list()) super().__init__(*args, **kwargs) self.finetune_keys = finetune_keys @@ -1609,9 +1731,10 @@ class LatentFinetuneDiffusion(LatentDiffusion): self.keep_dims = keep_finetune_dims self.c_concat_log_start = c_concat_log_start self.c_concat_log_end = c_concat_log_end - if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' - if exists(ckpt_path): - self.init_from_ckpt(ckpt_path, ignore_keys) + if exists(self.finetune_keys): + assert exists(ckpt), 'can only finetune from a given checkpoint' + if exists(ckpt): + self.init_from_ckpt(ckpt, ignore_keys) def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd = torch.load(path, map_location="cpu") @@ -1621,7 +1744,7 @@ class LatentFinetuneDiffusion(LatentDiffusion): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + rank_zero_info("Deleting key {} from state_dict.".format(k)) del sd[k] # make it explicit, finetune by including extra input channels @@ -1629,25 +1752,38 @@ class LatentFinetuneDiffusion(LatentDiffusion): new_entry = None for name, param in self.named_parameters(): if name in self.finetune_keys: - print( - f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") - new_entry = torch.zeros_like(param) # zero init + rank_zero_info( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only" + ) + new_entry = torch.zeros_like(param) # zero init assert exists(new_entry), 'did not find matching parameter to modify' new_entry[:, :self.keep_dims, ...] 
= sd[k] sd[k] = new_entry missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: - print(f"Missing Keys: {missing}") + rank_zero_info(f"Missing Keys: {missing}") if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") + rank_zero_info(f"Unexpected Keys: {unexpected}") @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + def log_images(self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1., + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1., + unconditional_guidance_label=None, use_ema_scope=True, **kwargs): ema_scope = self.ema_scope if use_ema_scope else nullcontext @@ -1690,7 +1826,7 @@ class LatentFinetuneDiffusion(LatentDiffusion): z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) @@ -1699,9 +1835,14 @@ class LatentFinetuneDiffusion(LatentDiffusion): if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, - batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta) + samples, z_denoise_row = self.sample_log(cond={ + "c_concat": [c_cat], + "c_crossattn": [c] + }, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1714,12 +1855,18 @@ class LatentFinetuneDiffusion(LatentDiffusion): uc_cat = c_cat uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} with ema_scope("Sampling with classifier-free guidance"): - samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, - batch_size=N, ddim=use_ddim, - ddim_steps=ddim_steps, eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc_full, - ) + samples_cfg, _ = self.sample_log( + cond={ + "c_concat": [c_cat], + "c_crossattn": [c] + }, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg @@ -1733,11 +1880,7 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion): To disable finetuning mode, set finetune_keys to None """ - def __init__(self, - concat_keys=("mask", "masked_image"), - 
masked_image_key="masked_image", - *args, **kwargs - ): + def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", *args, **kwargs): super().__init__(concat_keys, *args, **kwargs) self.masked_image_key = masked_image_key assert self.masked_image_key in concat_keys @@ -1746,8 +1889,12 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion): def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) + z, c, x, xrec, xc = super().get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs) assert exists(self.concat_keys) c_cat = list() @@ -1793,8 +1940,12 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) + z, c, x, xrec, xc = super().get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -1812,7 +1963,8 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): align_corners=False, ) - depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, + dim=[1, 2, 3], keepdim=True) cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. 
c_cat.append(cc) @@ -1836,13 +1988,19 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): """ condition on low-res image (and optionally on some spatial noise augmentation) """ - def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, - low_scale_config=None, low_scale_key=None, *args, **kwargs): + + def __init__(self, + concat_keys=("lr",), + reshuffle_patch_size=None, + low_scale_config=None, + low_scale_key=None, + *args, + **kwargs): super().__init__(concat_keys=concat_keys, *args, **kwargs) self.reshuffle_patch_size = reshuffle_patch_size self.low_scale_model = None if low_scale_config is not None: - print("Initializing a low-scale model") + rank_zero_info("Initializing a low-scale model") assert exists(low_scale_key) self.instantiate_low_stage(low_scale_config) self.low_scale_key = low_scale_key @@ -1858,8 +2016,12 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' - z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, - force_c_encode=True, return_original_cond=True, bs=bs) + z, c, x, xrec, xc = super().get_input(batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -1871,8 +2033,10 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): cc = rearrange(cc, 'b h w c -> b c h w') if exists(self.reshuffle_patch_size): assert isinstance(self.reshuffle_patch_size, int) - cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', - p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) + cc = rearrange(cc, + 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', + p1=self.reshuffle_patch_size, + p2=self.reshuffle_patch_size) if bs is not None: cc = cc[:bs] cc = cc.to(self.device) diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py index 57b9a4b80..fb088db58 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py @@ -1,10 +1,11 @@ # pytorch_diffusion + derived encoder decoder import math +from typing import Any, Optional + +import numpy as np import torch import torch.nn as nn -import numpy as np from einops import rearrange -from typing import Optional, Any try: from lightning.pytorch.utilities import rank_zero_info @@ -38,14 +39,14 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): # swish - return x*torch.sigmoid(x) + return x * torch.sigmoid(x) def Normalize(in_channels, num_groups=32): @@ -53,15 +54,12 @@ def Normalize(in_channels, num_groups=32): class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv 
= torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") @@ -71,20 +69,17 @@ class Upsample(nn.Module): class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): if self.with_conv: - pad = (0,1,0,1) + pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: @@ -93,8 +88,8 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): + + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -102,34 +97,17 @@ class ResnetBlock(nn.Module): self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): h = x @@ -138,7 +116,7 @@ class ResnetBlock(nn.Module): h = self.conv1(h) if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] h = self.norm2(h) h = nonlinearity(h) @@ -151,35 +129,20 @@ class ResnetBlock(nn.Module): else: x = self.nin_shortcut(x) - return x+h + return x + h class AttnBlock(nn.Module): + def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -189,23 +152,24 @@ class AttnBlock(nn.Module): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) - return x+h_ + return x + h_ + class MemoryEfficientAttnBlock(nn.Module): """ @@ -213,32 +177,17 @@ class MemoryEfficientAttnBlock(nn.Module): see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 Note: this is a single-head self-attention operation """ + # def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.attention_op: Optional[Any] = None def forward(self, x): @@ -253,27 +202,20 @@ class MemoryEfficientAttnBlock(nn.Module): q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), + lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C). 
+ contiguous(), (q, k, v), ) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) + out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)) out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) out = self.proj_out(out) - return x+out + return x + out class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): b, c, h, w = x.shape x = rearrange(x, 'b c h w -> b (h w) c') @@ -283,10 +225,10 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", + "none"], f'attn_type {attn_type} unknown' if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": attn_type = "vanilla-xformers" - rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": assert attn_kwargs is None return AttnBlock(in_channels) @@ -303,13 +245,26 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla"): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch - self.temb_ch = self.ch*4 + self.temb_ch = self.ch * 4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution @@ -320,39 +275,34 @@ class Model(nn.Module): # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), ]) # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level 
!= self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) @@ -374,15 +324,16 @@ class Model(nn.Module): for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock(in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -392,15 +343,11 @@ class Model(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, x, t=None, context=None): #assert x.shape[2] == x.shape[3] == self.resolution @@ -425,7 +372,7 @@ class Model(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -436,9 +383,8 @@ class Model(nn.Module): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: @@ -455,12 +401,26 @@ class Model(nn.Module): class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", **ignore_kwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -469,33 +429,30 @@ class Encoder(nn.Module): self.in_channels = in_channels # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) + in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in 
range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) @@ -515,7 +472,7 @@ class Encoder(nn.Module): # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, + 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1) @@ -532,7 +489,7 @@ class Encoder(nn.Module): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) - if i_level != self.num_resolutions-1: + if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle @@ -549,12 +506,27 @@ class Encoder(nn.Module): class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -565,19 +537,14 @@ class Decoder(nn.Module): self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - rank_zero_info("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() @@ -596,12 +563,13 @@ class Decoder(nn.Module): for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block_out = ch * 
ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -611,15 +579,11 @@ class Decoder(nn.Module): if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): #assert z.shape[1:] == self.z_shape[1:] @@ -638,7 +602,7 @@ class Decoder(nn.Module): # upsampling for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): + for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) @@ -658,31 +622,24 @@ class Decoder(nn.Module): class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) + self.model = nn.ModuleList([ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True) + ]) # end self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): for i, layer in enumerate(self.model): - if i in [1,2,3]: + if i in [1, 2, 3]: x = layer(x, None) else: x = layer(x) @@ -694,25 +651,26 @@ class SimpleDecoder(nn.Module): class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): + + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): super().__init__() # upsampling self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) + curr_res = resolution // 2**(self.num_resolutions - 1) self.res_blocks = nn.ModuleList() self.upsample_blocks = nn.ModuleList() for i_level in range(self.num_resolutions): res_block = [] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - 
temb_channels=self.temb_ch, - dropout=dropout)) + res_block.append( + ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -721,11 +679,7 @@ class UpsampleDecoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): # upsampling @@ -742,35 +696,35 @@ class UpsampleDecoder(nn.Module): class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) + self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) + self.res_block1 = nn.ModuleList([ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ]) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) + self.res_block2 = nn.ModuleList([ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ]) - self.conv_out = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - ) + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = torch.nn.functional.interpolate(x, + size=(int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)))) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -779,17 +733,37 @@ class LatentRescaler(nn.Module): class MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + + def __init__(self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + self.encoder = Encoder(in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, 
out_channels=out_ch, depth=rescale_module_depth) + self.rescaler = LatentRescaler(factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth) def forward(self, x): x = self.encoder(x) @@ -798,15 +772,36 @@ class MergedRescaleEncoder(nn.Module): class MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + + def __init__(self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1): super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth) def forward(self, x): x = self.rescaler(x) @@ -815,16 +810,26 @@ class MergedRescaleDecoder(nn.Module): class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size - num_blocks = int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1. 
+ (out_size % in_size) + rank_zero_info( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler(factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, + self.decoder = Decoder(out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, ch_mult=[ch_mult for _ in range(num_blocks)]) def forward(self, x): @@ -834,23 +839,21 @@ class Upsampler(nn.Module): class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: - rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + rank_zero_info( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1) def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: + if scale_factor == 1.0: return x else: x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py index 87d495123..5f166aa1f 100644 --- a/examples/images/diffusion/main.py +++ b/examples/images/diffusion/main.py @@ -106,7 +106,20 @@ def get_parser(**parser_kwargs): nargs="?", help="disable test", ) - parser.add_argument("-p", "--project", help="name of new or path to existing project") + parser.add_argument( + "-p", + "--project", + help="name of new or path to existing project", + ) + parser.add_argument( + "-c", + "--ckpt", + type=str, + const=True, + default="", + nargs="?", + help="load pretrained checkpoint from stable AI", + ) parser.add_argument( "-d", "--debug", @@ -145,22 +158,7 @@ def get_parser(**parser_kwargs): default=True, help="scale base-lr by ngpu * batch_size * n_accumulate", ) - parser.add_argument( - "--use_fp16", - type=str2bool, - nargs="?", - const=True, - default=True, - help="whether to use fp16", - ) - parser.add_argument( - "--flash", - type=str2bool, - const=True, - default=False, - nargs="?", - help="whether to use flash attention", - ) + return parser @@ -341,6 +339,12 @@ class SetupCallback(Callback): except FileNotFoundError: pass + # def on_fit_end(self, trainer, pl_module): + # if trainer.global_rank == 0: + # ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + # rank_zero_info(f"Saving final checkpoint in {ckpt_path}.") + # trainer.save_checkpoint(ckpt_path) + class ImageLogger(Callback): @@ -536,6 +540,7 @@ if __name__ == "__main__": "If you want to resume training in a new log folder, " "use -n/--name in combination with --resume_from_checkpoint") if opt.resume: + rank_zero_info("Resuming from {}".format(opt.resume)) if not os.path.exists(opt.resume): raise ValueError("Cannot find {}".format(opt.resume)) if os.path.isfile(opt.resume): @@ -543,13 +548,13 @@ if __name__ == "__main__": # idx = 
len(paths)-paths[::-1].index("logs")+1 # logdir = "/".join(paths[:idx]) logdir = "/".join(paths[:-2]) + rank_zero_info("logdir: {}".format(logdir)) ckpt = opt.resume else: assert os.path.isdir(opt.resume), opt.resume logdir = opt.resume.rstrip("/") ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") - opt.resume_from_checkpoint = ckpt base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) opt.base = base_configs + opt.base _tmp = logdir.split("/") @@ -558,6 +563,7 @@ if __name__ == "__main__": if opt.name: name = "_" + opt.name elif opt.base: + rank_zero_info("Using base config {}".format(opt.base)) cfg_fname = os.path.split(opt.base[0])[-1] cfg_name = os.path.splitext(cfg_fname)[0] name = "_" + cfg_name @@ -566,6 +572,9 @@ if __name__ == "__main__": nowname = now + name + opt.postfix logdir = os.path.join(opt.logdir, nowname) + if opt.ckpt: + ckpt = opt.ckpt + ckptdir = os.path.join(logdir, "checkpoints") cfgdir = os.path.join(logdir, "configs") seed_everything(opt.seed) @@ -582,14 +591,11 @@ if __name__ == "__main__": for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) - print(trainer_config) if not trainer_config["accelerator"] == "gpu": del trainer_config["accelerator"] cpu = True - print("Running on CPU") else: cpu = False - print("Running on GPU") trainer_opt = argparse.Namespace(**trainer_config) lightning_config.trainer = trainer_config @@ -597,10 +603,12 @@ if __name__ == "__main__": use_fp16 = trainer_config.get("precision", 32) == 16 if use_fp16: config.model["params"].update({"use_fp16": True}) - print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) else: config.model["params"].update({"use_fp16": False}) - print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) + + if ckpt is not None: + config.model["params"].update({"ckpt": ckpt}) + rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"])) model = instantiate_from_config(config.model) # trainer and callbacks @@ -639,7 +647,6 @@ if __name__ == "__main__": # config the strategy, defualt is ddp if "strategy" in trainer_config: strategy_cfg = trainer_config["strategy"] - print("Using strategy: {}".format(strategy_cfg["target"])) strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"] else: strategy_cfg = { @@ -648,7 +655,6 @@ if __name__ == "__main__": "find_unused_parameters": False } } - print("Using strategy: DDPStrategy") trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) @@ -664,7 +670,6 @@ if __name__ == "__main__": } } if hasattr(model, "monitor"): - print(f"Monitoring {model.monitor} as checkpoint metric.") default_modelckpt_cfg["params"]["monitor"] = model.monitor default_modelckpt_cfg["params"]["save_top_k"] = 3 @@ -673,7 +678,6 @@ if __name__ == "__main__": else: modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) - print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") if version.parse(pl.__version__) < version.parse('1.4.0'): trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) @@ -710,8 +714,6 @@ if __name__ == "__main__": "target": "main.CUDACallback" }, } - if version.parse(pl.__version__) >= version.parse('1.4.0'): - default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) if "callbacks" in lightning_config: callbacks_cfg = lightning_config.callbacks @@ -737,15 +739,11 @@ if __name__ == "__main__": default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) callbacks_cfg = 
OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) - if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): - callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint - elif 'ignore_keys_callback' in callbacks_cfg: - del callbacks_cfg['ignore_keys_callback'] trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) - trainer.logdir = logdir ### + trainer.logdir = logdir # data data = instantiate_from_config(config.data) @@ -754,9 +752,9 @@ if __name__ == "__main__": # lightning still takes care of proper multiprocessing though data.prepare_data() data.setup() - print("#### Data #####") + for k in data.datasets: - print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") + rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") # configure learning rate bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate @@ -768,17 +766,17 @@ if __name__ == "__main__": accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches else: accumulate_grad_batches = 1 - print(f"accumulate_grad_batches = {accumulate_grad_batches}") + rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}") lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches if opt.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr - print( + rank_zero_info( "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) else: model.learning_rate = base_lr - print("++++ NOT USING LR SCALING ++++") - print(f"Setting learning rate to {model.learning_rate:.2e}") + rank_zero_info("++++ NOT USING LR SCALING ++++") + rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}") # allow checkpointing via USR1 def melk(*args, **kwargs): diff --git a/examples/images/diffusion/scripts/txt2img.sh b/examples/images/diffusion/scripts/txt2img.sh index 53041cb8d..bc6480b6b 100755 --- a/examples/images/diffusion/scripts/txt2img.sh +++ b/examples/images/diffusion/scripts/txt2img.sh @@ -1,5 +1,5 @@ -python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \ +python scripts/txt2img.py --prompt "Teyvat, Medium Female, a woman in a blue outfit holding a sword" --plms \ --outdir ./output \ - --ckpt /tmp/2022-11-18T16-38-46_train_colossalai/checkpoints/last.ckpt \ - --config /tmp/2022-11-18T16-38-46_train_colossalai/configs/2022-11-18T16-38-46-project.yaml \ + --ckpt checkpoints/last.ckpt \ + --config configs/2023-02-02T18-06-14-project.yaml \ --n_samples 4 diff --git a/examples/images/diffusion/train_colossalai.sh b/examples/images/diffusion/train_colossalai.sh index dcaeeb0c6..c56ed7876 100755 --- a/examples/images/diffusion/train_colossalai.sh +++ b/examples/images/diffusion/train_colossalai.sh @@ -2,4 +2,4 @@ HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 DIFFUSERS_OFFLINE=1 -python main.py --logdir /tmp -t -b configs/train_colossalai.yaml +python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt
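
For reference, a minimal sketch of how the renamed `ckpt` option reaches the model after this patch, pieced together from the `main.py` and `ddpm.py` hunks above. The config and checkpoint paths are placeholders taken from the scripts in this example, and the `ldm.util.instantiate_from_config` import is assumed from the example's existing code rather than introduced by the patch:

```
# Sketch only; paths are placeholders and the helper import is assumed.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # existing helper in this example

config = OmegaConf.load("configs/train_colossalai.yaml")    # --base
ckpt = "512-base-ema.ckpt"                                   # --ckpt

# main.py forwards the CLI checkpoint into the model config...
if ckpt is not None:
    config.model["params"].update({"ckpt": ckpt})

# ...and LatentFinetuneDiffusion pops `ckpt` from its kwargs, calling
# self.init_from_ckpt(ckpt, ignore_keys) when it is set.
model = instantiate_from_config(config.model)
```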