From cea4292ae5ce7563f94bb7deef6f1b9bd155ecc5 Mon Sep 17 00:00:00 2001 From: Fazzie <1240419984@qq.com> Date: Mon, 12 Dec 2022 17:35:23 +0800 Subject: [PATCH] support stable diffusion v2 --- examples/images/diffusion/README.md | 45 +- .../configs/Inference/v2-inference-v.yaml | 68 + .../configs/Inference/v2-inference.yaml | 67 + .../Inference/v2-inpainting-inference.yaml | 158 +++ .../configs/Inference/v2-midas-inference.yaml | 72 + .../configs/Inference/x4-upscaling.yaml | 75 + .../images/diffusion/configs/Teyvat/README.md | 25 + .../{ => Teyvat}/train_colossalai_teyvat.yaml | 52 +- .../diffusion/configs/train_colossalai.yaml | 46 +- .../configs/train_colossalai_cifar10.yaml | 44 +- .../images/diffusion/configs/train_ddp.yaml | 60 +- .../diffusion/configs/train_pokemon.yaml | 57 +- examples/images/diffusion/environment.yaml | 25 +- .../diffusion/ldm/models/autoencoder.py | 431 +----- .../diffusion/ldm/models/diffusion/ddim.py | 128 +- .../diffusion/ldm/models/diffusion/ddpm.py | 1263 +++++++++++------ .../models/diffusion/dpm_solver/__init__.py | 1 + .../models/diffusion/dpm_solver/dpm_solver.py | 1154 +++++++++++++++ .../models/diffusion/dpm_solver/sampler.py | 87 ++ .../diffusion/ldm/models/diffusion/plms.py | 14 +- .../ldm/models/diffusion/sampling_util.py | 22 + .../images/diffusion/ldm/modules/attention.py | 233 +-- .../ldm/modules/diffusionmodules/model.py | 231 ++- .../modules/diffusionmodules/openaimodel.py | 527 ++----- .../ldm/modules/diffusionmodules/upscaling.py | 81 ++ .../ldm/modules/diffusionmodules/util.py | 25 +- examples/images/diffusion/ldm/modules/ema.py | 20 +- .../diffusion/ldm/modules/encoders/modules.py | 349 ++--- .../diffusion/ldm/modules/flash_attention.py | 50 - .../modules/image_degradation/bsrgan_light.py | 17 +- .../diffusion/ldm/modules/losses/__init__.py | 1 - .../ldm/modules/losses/contperceptual.py | 111 -- .../ldm/modules/losses/vqperceptual.py | 167 --- .../diffusion/ldm/modules/midas/__init__.py | 0 .../images/diffusion/ldm/modules/midas/api.py | 170 +++ .../ldm/modules/midas/midas/__init__.py | 0 .../ldm/modules/midas/midas/base_model.py | 16 + .../ldm/modules/midas/midas/blocks.py | 342 +++++ .../ldm/modules/midas/midas/dpt_depth.py | 109 ++ .../ldm/modules/midas/midas/midas_net.py | 76 + .../modules/midas/midas/midas_net_custom.py | 128 ++ .../ldm/modules/midas/midas/transforms.py | 234 +++ .../diffusion/ldm/modules/midas/midas/vit.py | 491 +++++++ .../diffusion/ldm/modules/midas/utils.py | 189 +++ .../diffusion/ldm/modules/x_transformer.py | 641 --------- examples/images/diffusion/ldm/util.py | 206 ++- examples/images/diffusion/main.py | 240 ++-- examples/images/diffusion/requirements.txt | 21 +- examples/images/diffusion/scripts/img2img.py | 97 +- examples/images/diffusion/scripts/txt2img.py | 226 ++- examples/images/diffusion/train.sh | 7 +- 51 files changed, 5597 insertions(+), 3302 deletions(-) create mode 100644 examples/images/diffusion/configs/Inference/v2-inference-v.yaml create mode 100644 examples/images/diffusion/configs/Inference/v2-inference.yaml create mode 100644 examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml create mode 100644 examples/images/diffusion/configs/Inference/v2-midas-inference.yaml create mode 100644 examples/images/diffusion/configs/Inference/x4-upscaling.yaml create mode 100644 examples/images/diffusion/configs/Teyvat/README.md rename examples/images/diffusion/configs/{ => Teyvat}/train_colossalai_teyvat.yaml (70%) create mode 100644 examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py create mode 100644 examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py create mode 100644 examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py create mode 100644 examples/images/diffusion/ldm/models/diffusion/sampling_util.py create mode 100644 examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py delete mode 100644 examples/images/diffusion/ldm/modules/flash_attention.py delete mode 100644 examples/images/diffusion/ldm/modules/losses/__init__.py delete mode 100644 examples/images/diffusion/ldm/modules/losses/contperceptual.py delete mode 100644 examples/images/diffusion/ldm/modules/losses/vqperceptual.py create mode 100644 examples/images/diffusion/ldm/modules/midas/__init__.py create mode 100644 examples/images/diffusion/ldm/modules/midas/api.py create mode 100644 examples/images/diffusion/ldm/modules/midas/midas/__init__.py create mode 100644 examples/images/diffusion/ldm/modules/midas/midas/base_model.py create mode 100644 examples/images/diffusion/ldm/modules/midas/midas/blocks.py create mode 100644 examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py create mode 100644 examples/images/diffusion/ldm/modules/midas/midas/midas_net.py create mode 100644 examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py create mode 100644 examples/images/diffusion/ldm/modules/midas/midas/transforms.py create mode 100644 examples/images/diffusion/ldm/modules/midas/midas/vit.py create mode 100644 examples/images/diffusion/ldm/modules/midas/utils.py delete mode 100644 examples/images/diffusion/ldm/modules/x_transformer.py diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index 77860211d..fa8cd28c2 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -1,4 +1,5 @@ -# Stable Diffusion with Colossal-AI +# ColoDiffusion: Stable Diffusion with Colossal-AI + *[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).* @@ -6,6 +7,7 @@ We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to , e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs. ## Stable Diffusion + [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion model. Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. @@ -23,6 +25,7 @@ this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on te

## Requirements + A suitable [conda](https://conda.io/) environment named `ldm` can be created and activated with: @@ -34,14 +37,24 @@ conda activate ldm You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running ``` -conda install pytorch torchvision -c pytorch +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch pip install transformers==4.19.2 diffusers invisible-watermark pip install -e . ``` -### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website +### install lightning + ``` -pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org +git clone https://github.com/1SAA/lightning.git +git checkout strategy/colossalai +export PACKAGE_NAME=pytorch +pip install . +``` + +### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website + +``` +pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org ``` > The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future. @@ -49,6 +62,7 @@ pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org ## Download the model checkpoint from pretrained ### stable-diffusion-v1-4 + Our default model config use the weight from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style) ``` @@ -57,6 +71,7 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 ``` ### stable-diffusion-v1-5 from runway + If you want to useed the Last [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) wiegh from runwayml ``` @@ -64,23 +79,24 @@ git lfs install git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 ``` - ## Dataset + The dataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/), you should the change the `data.file_path` in the `config/train_colossalai.yaml` ## Training -We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml` +We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml` and `train_ddp.yaml` For example, you can run the training from colossalai by ``` -python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml +python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml ``` - you can change the `--logdir` the save the log information and the last checkpoint ### Training config + You can change the trainging config in the yaml file - accelerator: acceleratortype, default 'gpu' @@ -88,27 +104,25 @@ You can change the trainging config in the yaml file - max_epochs: max training epochs - precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai -## Example +## Finetone Example +### Training on Teyvat Datasets -### Training on cifar10 +We provide the finetuning example on [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is create by BLIP generated captions. -We provide the finetuning example on CIFAR10 dataset - -You can run by config `train_colossalai_cifar10.yaml` +You can run by config `configs/Teyvat/train_colossalai_teyvat.yaml` ``` -python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml +python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml ``` ## Inference you can get yout training last.ckpt and train config.yaml in your `--logdir`, and run by ``` -python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms +python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms --outdir ./output \ --config path/to/logdir/checkpoints/last.ckpt \ --ckpt /path/to/logdir/configs/project.yaml \ ``` - ```commandline usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA] [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT] @@ -144,7 +158,6 @@ optional arguments: evaluate at this precision ``` - ## Comments - Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) diff --git a/examples/images/diffusion/configs/Inference/v2-inference-v.yaml b/examples/images/diffusion/configs/Inference/v2-inference-v.yaml new file mode 100644 index 000000000..8ec8dfbfe --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inference-v.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/v2-inference.yaml b/examples/images/diffusion/configs/Inference/v2-inference.yaml new file mode 100644 index 000000000..152c4f3c2 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inference.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml b/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml new file mode 100644 index 000000000..32a9471d7 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml b/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml new file mode 100644 index 000000000..531199de4 --- /dev/null +++ b/examples/images/diffusion/configs/Inference/v2-midas-inference.yaml @@ -0,0 +1,72 @@ +model: + base_learning_rate: 5.0e-07 + target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + depth_stage_config: + target: ldm.modules.midas.api.MiDaSInference + params: + model_type: "dpt_hybrid" + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 5 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Inference/x4-upscaling.yaml b/examples/images/diffusion/configs/Inference/x4-upscaling.yaml new file mode 100644 index 000000000..45ecbf9ad --- /dev/null +++ b/examples/images/diffusion/configs/Inference/x4-upscaling.yaml @@ -0,0 +1,75 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion + params: + parameterization: "v" + low_scale_key: "lr" + linear_start: 0.0001 + linear_end: 0.02 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 128 + channels: 4 + cond_stage_trainable: false + conditioning_key: "hybrid-adm" + monitor: val/loss_simple_ema + scale_factor: 0.08333 + use_ema: False + + low_scale_config: + target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation + params: + noise_schedule_config: # image space + linear_start: 0.0001 + linear_end: 0.02 + max_noise_level: 350 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) + image_size: 128 + in_channels: 7 + out_channels: 4 + model_channels: 256 + attention_resolutions: [ 2,4,8] + num_res_blocks: 2 + channel_mult: [ 1, 2, 2, 4] + disable_self_attentions: [True, True, True, False] + disable_middle_self_attn: False + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + use_linear_in_transformer: True + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/examples/images/diffusion/configs/Teyvat/README.md b/examples/images/diffusion/configs/Teyvat/README.md new file mode 100644 index 000000000..6a7ee88e5 --- /dev/null +++ b/examples/images/diffusion/configs/Teyvat/README.md @@ -0,0 +1,25 @@ +# Dataset Card for Teyvat BLIP captions +Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion). + +BLIP generated captions for characters images from [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters)and [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2). + +For each row the dataset contains `image` and `text` keys. `image` is a varying size PIL png, and `text` is the accompanying text caption. Only a train split is provided. + +The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Model type`, and `Description`, the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP). +## Examples + + + +> Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, Region:Liyue, Model type:Medium Female, Description:an anime character with blue hair and blue eyes + + + +> Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, Region:Liyue, Model type:Medium Female, Description:an anime character with blue hair and blue eyes + + + +> Teyvat, Name:Keqing, Element:Electro, Weapon:Sword, Region:Liyue, Model type:Medium Female, Description:a anime girl with long white hair and blue eyes + + + +> Teyvat, Name:Keqing, Element:Electro, Weapon:Sword, Region:Liyue, Model type:Medium Female, Description:an anime character wearing a purple dress and cat ears diff --git a/examples/images/diffusion/configs/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml similarity index 70% rename from examples/images/diffusion/configs/train_colossalai_teyvat.yaml rename to examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml index e25473004..9048b3f80 100644 --- a/examples/images/diffusion/configs/train_colossalai_teyvat.yaml +++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml @@ -1,7 +1,8 @@ model: - base_learning_rate: 1.0e-04 + base_learning_rate: 1.0e-4 target: ldm.models.diffusion.ddpm.LatentDiffusion params: + parameterization: "v" linear_start: 0.00085 linear_end: 0.0120 num_timesteps_cond: 1 @@ -11,11 +12,11 @@ model: cond_stage_key: txt image_size: 64 channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before + cond_stage_trainable: false conditioning_key: crossattn monitor: val/loss_simple_ema scale_factor: 0.18215 - use_ema: False + use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps target: ldm.lr_scheduler.LambdaLinearScheduler @@ -26,31 +27,33 @@ model: f_max: [ 1.e-4 ] f_min: [ 1.e-10 ] + unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: + use_checkpoint: True + use_fp16: True image_size: 32 # unused - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [ 4, 2, 1 ] num_res_blocks: 2 channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 + num_head_channels: 64 # need to fix for flash-attn use_spatial_transformer: True + use_linear_in_transformer: True transformer_depth: 1 - context_dim: 768 - use_checkpoint: False + context_dim: 1024 legacy: False first_stage_config: target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' monitor: val/rec_loss ddconfig: + #attn_type: "vanilla-xformers" double_z: true z_channels: 4 resolution: 256 @@ -69,9 +72,10 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: - use_fp16: True + freeze: True + layer: "penultimate" data: target: main.DataModuleFromConfig @@ -86,37 +90,37 @@ data: - target: torchvision.transforms.Resize params: size: 512 - # - target: torchvision.transforms.RandomCrop - # params: - # size: 256 - # - target: torchvision.transforms.RandomHorizontalFlip + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip lightning: trainer: - accelerator: 'gpu' + accelerator: 'gpu' devices: 2 log_gpu_memory: all - max_epochs: 10 + max_epochs: 2 precision: 16 auto_select_gpus: False strategy: - target: lightning.pytorch.strategies.ColossalAIStrategy + target: strategies.ColossalAIStrategy params: - use_chunk: False - enable_distributed_storage: True, - placement_policy: cuda - force_outputs_fp32: False + use_chunk: True + enable_distributed_storage: True + placement_policy: auto + force_outputs_fp32: true log_every_n_steps: 2 logger: True default_root_dir: "/tmp/diff_log/" - profiler: pytorch + # profiler: pytorch logger_config: wandb: - target: lightning.pytorch.loggers.WandbLogger + target: loggers.WandbLogger params: name: nowname save_dir: "/tmp/diff_log/" offline: opt.debug - id: nowname \ No newline at end of file + id: nowname diff --git a/examples/images/diffusion/configs/train_colossalai.yaml b/examples/images/diffusion/configs/train_colossalai.yaml index 41c998460..155b26dd4 100644 --- a/examples/images/diffusion/configs/train_colossalai.yaml +++ b/examples/images/diffusion/configs/train_colossalai.yaml @@ -1,21 +1,22 @@ model: - base_learning_rate: 1.0e-04 + base_learning_rate: 1.0e-4 target: ldm.models.diffusion.ddpm.LatentDiffusion params: + parameterization: "v" linear_start: 0.00085 linear_end: 0.0120 num_timesteps_cond: 1 log_every_t: 200 timesteps: 1000 first_stage_key: image - cond_stage_key: caption + cond_stage_key: txt image_size: 64 channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before + cond_stage_trainable: false conditioning_key: crossattn monitor: val/loss_simple_ema scale_factor: 0.18215 - use_ema: False + use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps target: ldm.lr_scheduler.LambdaLinearScheduler @@ -26,31 +27,33 @@ model: f_max: [ 1.e-4 ] f_min: [ 1.e-10 ] + unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: + use_checkpoint: True + use_fp16: True image_size: 32 # unused - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [ 4, 2, 1 ] num_res_blocks: 2 channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 + num_head_channels: 64 # need to fix for flash-attn use_spatial_transformer: True + use_linear_in_transformer: True transformer_depth: 1 - context_dim: 768 - use_checkpoint: False + context_dim: 1024 legacy: False first_stage_config: target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' monitor: val/rec_loss ddconfig: + #attn_type: "vanilla-xformers" double_z: true z_channels: 4 resolution: 256 @@ -69,9 +72,10 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: - use_fp16: True + freeze: True + layer: "penultimate" data: target: main.DataModuleFromConfig @@ -87,30 +91,30 @@ data: lightning: trainer: - accelerator: 'gpu' - devices: 4 + accelerator: 'gpu' + devices: 1 log_gpu_memory: all max_epochs: 2 precision: 16 auto_select_gpus: False strategy: - target: lightning.pytorch.strategies.ColossalAIStrategy + target: strategies.ColossalAIStrategy params: - use_chunk: False - enable_distributed_storage: True, - placement_policy: cuda - force_outputs_fp32: False + use_chunk: True + enable_distributed_storage: True + placement_policy: auto + force_outputs_fp32: true log_every_n_steps: 2 logger: True default_root_dir: "/tmp/diff_log/" - profiler: pytorch + # profiler: pytorch logger_config: wandb: - target: lightning.pytorch.loggers.WandbLogger + target: loggers.WandbLogger params: name: nowname save_dir: "/tmp/diff_log/" offline: opt.debug - id: nowname \ No newline at end of file + id: nowname diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml index 4348870f5..5335bacbe 100644 --- a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml +++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml @@ -1,7 +1,8 @@ model: - base_learning_rate: 1.0e-04 + base_learning_rate: 1.0e-4 target: ldm.models.diffusion.ddpm.LatentDiffusion params: + parameterization: "v" linear_start: 0.00085 linear_end: 0.0120 num_timesteps_cond: 1 @@ -11,11 +12,11 @@ model: cond_stage_key: txt image_size: 64 channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before + cond_stage_trainable: false conditioning_key: crossattn monitor: val/loss_simple_ema scale_factor: 0.18215 - use_ema: False + use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps target: ldm.lr_scheduler.LambdaLinearScheduler @@ -26,31 +27,33 @@ model: f_max: [ 1.e-4 ] f_min: [ 1.e-10 ] + unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: + use_checkpoint: True + use_fp16: True image_size: 32 # unused - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [ 4, 2, 1 ] num_res_blocks: 2 channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 + num_head_channels: 64 # need to fix for flash-attn use_spatial_transformer: True + use_linear_in_transformer: True transformer_depth: 1 - context_dim: 768 - use_checkpoint: False + context_dim: 1024 legacy: False first_stage_config: target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' monitor: val/rec_loss ddconfig: + #attn_type: "vanilla-xformers" double_z: true z_channels: 4 resolution: 256 @@ -69,9 +72,10 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: - use_fp16: True + freeze: True + layer: "penultimate" data: target: main.DataModuleFromConfig @@ -94,30 +98,30 @@ data: lightning: trainer: - accelerator: 'gpu' - devices: 2 + accelerator: 'gpu' + devices: 1 log_gpu_memory: all max_epochs: 2 precision: 16 auto_select_gpus: False strategy: - target: lightning.pytorch.strategies.ColossalAIStrategy + target: strategies.ColossalAIStrategy params: - use_chunk: False - enable_distributed_storage: True, - placement_policy: cuda - force_outputs_fp32: False + use_chunk: True + enable_distributed_storage: True + placement_policy: auto + force_outputs_fp32: true log_every_n_steps: 2 logger: True default_root_dir: "/tmp/diff_log/" - profiler: pytorch + # profiler: pytorch logger_config: wandb: - target: lightning.pytorch.loggers.WandbLogger + target: loggers.WandbLogger params: name: nowname save_dir: "/tmp/diff_log/" offline: opt.debug - id: nowname \ No newline at end of file + id: nowname diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml index a2e5982b3..4308998f4 100644 --- a/examples/images/diffusion/configs/train_ddp.yaml +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -1,56 +1,59 @@ model: - base_learning_rate: 1.0e-04 + base_learning_rate: 1.0e-4 target: ldm.models.diffusion.ddpm.LatentDiffusion params: + parameterization: "v" linear_start: 0.00085 linear_end: 0.0120 num_timesteps_cond: 1 log_every_t: 200 timesteps: 1000 first_stage_key: image - cond_stage_key: caption - image_size: 32 + cond_stage_key: txt + image_size: 64 channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before + cond_stage_trainable: false conditioning_key: crossattn monitor: val/loss_simple_ema scale_factor: 0.18215 - use_ema: False + use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps target: ldm.lr_scheduler.LambdaLinearScheduler params: - warm_up_steps: [ 100 ] + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases f_start: [ 1.e-6 ] f_max: [ 1.e-4 ] - f_min: [ 1.e-10 ] + f_min: [ 1.e-10 ] + unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: + use_checkpoint: True + use_fp16: True image_size: 32 # unused - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [ 4, 2, 1 ] num_res_blocks: 2 channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 + num_head_channels: 64 # need to fix for flash-attn use_spatial_transformer: True + use_linear_in_transformer: True transformer_depth: 1 - context_dim: 768 - use_checkpoint: False + context_dim: 1024 legacy: False first_stage_config: target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' monitor: val/rec_loss ddconfig: + #attn_type: "vanilla-xformers" double_z: true z_channels: 4 resolution: 256 @@ -69,32 +72,39 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: - use_fp16: True + freeze: True + layer: "penultimate" data: target: main.DataModuleFromConfig params: - batch_size: 64 - wrap: False + batch_size: 16 + num_workers: 4 train: - target: ldm.data.base.Txt2ImgIterableBaseDataset + target: ldm.data.teyvat.hf_dataset params: - file_path: "/data/scratch/diffuser/laion_part0/" - world_size: 1 - rank: 0 + path: Fazzie/Teyvat + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + - target: torchvision.transforms.RandomHorizontalFlip lightning: trainer: accelerator: 'gpu' - devices: 4 + devices: 2 log_gpu_memory: all max_epochs: 2 precision: 16 auto_select_gpus: False strategy: - target: lightning.pytorch.strategies.DDPStrategy + target: strategies.DDPStrategy params: find_unused_parameters: False log_every_n_steps: 2 @@ -105,9 +115,9 @@ lightning: logger_config: wandb: - target: lightning.pytorch.loggers.WandbLogger + target: loggers.WandbLogger params: name: nowname - save_dir: "/tmp/diff_log/" + save_dir: "/data2/tmp/diff_log/" offline: opt.debug - id: nowname \ No newline at end of file + id: nowname diff --git a/examples/images/diffusion/configs/train_pokemon.yaml b/examples/images/diffusion/configs/train_pokemon.yaml index 246cf002a..38e8485a3 100644 --- a/examples/images/diffusion/configs/train_pokemon.yaml +++ b/examples/images/diffusion/configs/train_pokemon.yaml @@ -1,57 +1,59 @@ model: - base_learning_rate: 1.0e-04 + base_learning_rate: 1.0e-4 target: ldm.models.diffusion.ddpm.LatentDiffusion params: + parameterization: "v" linear_start: 0.00085 linear_end: 0.0120 num_timesteps_cond: 1 log_every_t: 200 timesteps: 1000 first_stage_key: image - cond_stage_key: caption - image_size: 32 + cond_stage_key: txt + image_size: 64 channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before + cond_stage_trainable: false conditioning_key: crossattn monitor: val/loss_simple_ema scale_factor: 0.18215 - use_ema: False - check_nan_inf: False + use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps target: ldm.lr_scheduler.LambdaLinearScheduler params: - warm_up_steps: [ 10000 ] + warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases f_start: [ 1.e-6 ] f_max: [ 1.e-4 ] - f_min: [ 1.e-10 ] + f_min: [ 1.e-10 ] + unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: + use_checkpoint: True + use_fp16: True image_size: 32 # unused - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin' in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [ 4, 2, 1 ] num_res_blocks: 2 channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 + num_head_channels: 64 # need to fix for flash-attn use_spatial_transformer: True + use_linear_in_transformer: True transformer_depth: 1 - context_dim: 768 - use_checkpoint: False + context_dim: 1024 legacy: False first_stage_config: target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 - from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin' monitor: val/rec_loss ddconfig: + #attn_type: "vanilla-xformers" double_z: true z_channels: 4 resolution: 256 @@ -70,9 +72,10 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: - use_fp16: True + freeze: True + layer: "penultimate" data: target: main.DataModuleFromConfig @@ -88,34 +91,30 @@ data: lightning: trainer: - accelerator: 'gpu' - devices: 4 + accelerator: 'gpu' + devices: 1 log_gpu_memory: all max_epochs: 2 precision: 16 auto_select_gpus: False strategy: - target: lightning.pytorch.strategies.ColossalAIStrategy + target: strategies.ColossalAIStrategy params: - use_chunk: False - enable_distributed_storage: True, - placement_policy: cuda - force_outputs_fp32: False - initial_scale: 65536 - min_scale: 1 - max_scale: 65536 - # max_scale: 4294967296 + use_chunk: True + enable_distributed_storage: True + placement_policy: auto + force_outputs_fp32: true log_every_n_steps: 2 logger: True default_root_dir: "/tmp/diff_log/" - profiler: pytorch + # profiler: pytorch logger_config: wandb: - target: lightning.pytorch.loggers.WandbLogger + target: loggers.WandbLogger params: name: nowname save_dir: "/tmp/diff_log/" offline: opt.debug - id: nowname \ No newline at end of file + id: nowname diff --git a/examples/images/diffusion/environment.yaml b/examples/images/diffusion/environment.yaml index 6fcb10870..5b5579211 100644 --- a/examples/images/diffusion/environment.yaml +++ b/examples/images/diffusion/environment.yaml @@ -6,28 +6,25 @@ dependencies: - python=3.9.12 - pip=20.3 - cudatoolkit=11.3 - - pytorch=1.11.0 - - torchvision=0.12.0 - - numpy=1.19.2 + - pytorch=1.12.1 + - torchvision=0.13.1 + - numpy=1.23.1 - pip: - - albumentations==0.4.3 - - datasets - - diffusers + - albumentations==1.3.0 - opencv-python==4.6.0.66 - - pudb==2019.2 - - invisible-watermark - imageio==2.9.0 - imageio-ffmpeg==0.4.2 - - lightning==1.8.1 - omegaconf==2.1.1 - test-tube>=0.7.5 - - streamlit>=0.73.1 + - streamlit==1.12.1 - einops==0.3.0 - - torch-fidelity==0.3.0 - transformers==4.19.2 - - torchmetrics==0.7.0 + - webdataset==0.2.5 - kornia==0.6 + - open_clip_torch==2.0.2 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.7.0 - prefetch_generator - - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - - -e git+https://github.com/openai/CLIP.git@main#egg=clip + - datasets - -e . diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py index c69920ce5..b1bd83778 100644 --- a/examples/images/diffusion/ldm/models/autoencoder.py +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -1,64 +1,68 @@ import torch -import lightning.pytorch as pl +try: + import lightning.pytorch as pl +except: + import pytorch_lightning as pl + import torch.nn.functional as F from contextlib import contextmanager -from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer - from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.modules.distributions.distributions import DiagonalGaussianDistribution from ldm.util import instantiate_from_config +from ldm.modules.ema import LitEma -class VQModel(pl.LightningModule): +class AutoencoderKL(pl.LightningModule): def __init__(self, ddconfig, lossconfig, - n_embed, embed_dim, ckpt_path=None, ignore_keys=[], image_key="image", colorize_nlabels=None, monitor=None, - batch_resize_range=None, - scheduler_config=None, - lr_g_factor=1.0, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - use_ema=False + ema_decay=None, + learn_logvar=False ): super().__init__() - self.embed_dim = embed_dim - self.n_embed = n_embed + self.learn_logvar = learn_logvar self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, - remap=remap, - sane_index_shape=sane_index_shape) - self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim if colorize_nlabels is not None: assert type(colorize_nlabels)==int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor - self.batch_resize_range = batch_resize_range - if self.batch_resize_range is not None: - print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") - self.use_ema = use_ema + self.use_ema = ema_decay is not None if self.use_ema: - self.model_ema = LitEma(self) + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. + self.model_ema = LitEma(self, decay=ema_decay) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - self.scheduler_config = scheduler_config - self.lr_g_factor = lr_g_factor + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") @contextmanager def ema_scope(self, context=None): @@ -75,353 +79,10 @@ class VQModel(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: - print(f"Missing Keys: {missing}") - print(f"Unexpected Keys: {unexpected}") - def on_train_batch_end(self, *args, **kwargs): if self.use_ema: self.model_ema(self) - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - quant, emb_loss, info = self.quantize(h) - return quant, emb_loss, info - - def encode_to_prequant(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, quant): - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - def decode_code(self, code_b): - quant_b = self.quantize.embed_code(code_b) - dec = self.decode(quant_b) - return dec - - def forward(self, input, return_pred_indices=False): - quant, diff, (_,_,ind) = self.encode(input) - dec = self.decode(quant) - if return_pred_indices: - return dec, diff, ind - return dec, diff - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - if self.batch_resize_range is not None: - lower_size = self.batch_resize_range[0] - upper_size = self.batch_resize_range[1] - if self.global_step <= 4: - # do the first few batches with max size to avoid later oom - new_resize = upper_size - else: - new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) - if new_resize != x.shape[2]: - x = F.interpolate(x, size=new_resize, mode="bicubic") - x = x.detach() - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - # https://github.com/pytorch/pytorch/issues/37142 - # try not to fool the heuristics - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - - if optimizer_idx == 0: - # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train", - predicted_indices=ind) - - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return aeloss - - if optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, suffix=""): - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] - self.log(f"val{suffix}/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - self.log(f"val{suffix}/aeloss", aeloss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - if version.parse(pl.__version__) >= version.parse('1.4.0'): - del log_dict_ae[f"val{suffix}/rec_loss"] - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr_d = self.learning_rate - lr_g = self.lr_g_factor*self.learning_rate - print("lr_d", lr_d) - print("lr_g", lr_g) - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr_g, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr_d, betas=(0.5, 0.9)) - - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }, - { - 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }, - ] - return [opt_ae, opt_disc], scheduler - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if only_inputs: - log["inputs"] = x - return log - xrec, _ = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["inputs"] = x - log["reconstructions"] = xrec - if plot_ema: - with self.ema_scope(): - xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) - log["reconstructions_ema"] = xrec_ema - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. - return x - - -class VQModelInterface(VQModel): - def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) - self.embed_dim = embed_dim - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - -class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - from_pretrained: str=None - ): - super().__init__() - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - from diffusers.modeling_utils import load_state_dict - if from_pretrained is not None: - state_dict = load_state_dict(from_pretrained) - self._load_pretrained_model(state_dict) - - def _state_key_mapping(self, state_dict: dict): - import re - res_dict = {} - key_list = state_dict.keys() - key_str = " ".join(key_list) - up_block_pattern = re.compile('upsamplers') - p1 = re.compile('mid.block_[0-9]') - p2 = re.compile('decoder.up.[0-9]') - up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1) - for key_, val_ in state_dict.items(): - key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\ - .replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\ - .replace('mid.attentions.0.key', 'mid.attn_1.k')\ - .replace('mid.attentions.0.query', 'mid.attn_1.q') \ - .replace('mid.attentions.0.value', 'mid.attn_1.v') \ - .replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \ - .replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\ - .replace('upsamplers.0', 'upsample')\ - .replace('downsamplers.0', 'downsample')\ - .replace('conv_shortcut', 'nin_shortcut')\ - .replace('conv_norm_out', 'norm_out') - - mid_list = re.findall(p1, key_) - if len(mid_list) != 0: - mid_str = mid_list[0] - mid_id = int(mid_str[-1]) + 1 - key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id)) - - up_list = re.findall(p2, key_) - if len(up_list) != 0: - up_str = up_list[0] - up_id = up_blocks_count - 1 -int(up_str[-1]) - key_ = key_.replace(up_str, up_str[:-1] + str(up_id)) - res_dict[key_] = val_ - return res_dict - - def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): - state_dict = self._state_key_mapping(state_dict) - model_state_dict = self.state_dict() - loaded_keys = [k for k in state_dict.keys()] - expected_keys = list(model_state_dict.keys()) - original_loaded_keys = loaded_keys - missing_keys = list(set(expected_keys) - set(loaded_keys)) - unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( - state_dict, - model_state_dict, - original_loaded_keys, - ignore_mismatched_sizes, - ) - error_msgs = self._load_state_dict_into_model(state_dict) - return missing_keys, unexpected_keys, mismatched_keys, error_msgs - - def _load_state_dict_into_model(self, state_dict): - # Convert old format to new format if needed from a PyTorch state_dict - # copy state_dict so _load_from_state_dict can modify it - state_dict = state_dict.copy() - error_msgs = [] - - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. - def load(module: torch.nn.Module, prefix=""): - args = (state_dict, prefix, {}, True, [], [], error_msgs) - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(self) - - return error_msgs - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - def encode(self, x): h = self.encoder(x) moments = self.quant_conv(h) @@ -471,25 +132,33 @@ class AutoencoderKL(pl.LightningModule): return discloss def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") + last_layer=self.get_last_layer(), split="val"+postfix) discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") + last_layer=self.get_last_layer(), split="val"+postfix) - self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), + ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) @@ -499,7 +168,7 @@ class AutoencoderKL(pl.LightningModule): return self.decoder.conv_out.weight @torch.no_grad() - def log_images(self, batch, only_inputs=False, **kwargs): + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): log = dict() x = self.get_input(batch, self.image_key) x = x.to(self.device) @@ -512,6 +181,15 @@ class AutoencoderKL(pl.LightningModule): xrec = self.to_rgb(xrec) log["samples"] = self.decode(torch.randn_like(posterior.sample())) log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema log["inputs"] = x return log @@ -526,7 +204,7 @@ class AutoencoderKL(pl.LightningModule): class IdentityFirstStage(torch.nn.Module): def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + self.vq_interface = vq_interface super().__init__() def encode(self, x, *args, **kwargs): @@ -542,3 +220,4 @@ class IdentityFirstStage(torch.nn.Module): def forward(self, x, *args, **kwargs): return x + diff --git a/examples/images/diffusion/ldm/models/diffusion/ddim.py b/examples/images/diffusion/ldm/models/diffusion/ddim.py index 91335d637..27ead0ea9 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddim.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddim.py @@ -3,10 +3,8 @@ import torch import numpy as np from tqdm import tqdm -from functools import partial -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ - extract_into_tensor +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor class DDIMSampler(object): @@ -74,15 +72,24 @@ class DDIMSampler(object): x_T=None, log_every_t=100, unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, **kwargs ): if conditioning is not None: if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: if conditioning.shape[0] != batch_size: print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") @@ -107,6 +114,8 @@ class DDIMSampler(object): log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule ) return samples, intermediates @@ -116,7 +125,8 @@ class DDIMSampler(object): callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): device = self.model.betas.device b = shape[0] if x_T is None: @@ -145,12 +155,18 @@ class DDIMSampler(object): assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img = img_orig * mask + (1. - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) @@ -164,20 +180,44 @@ class DDIMSampler(object): @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None): + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): b, *_, device = *x.shape, x.device if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) + model_output = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output if score_corrector is not None: - assert self.model.parameterization == "eps" + assert self.model.parameterization == "eps", 'not implemented' e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas @@ -191,9 +231,17 @@ class DDIMSampler(object): sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature @@ -202,6 +250,53 @@ class DDIMSampler(object): x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + @torch.no_grad() def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): # fast, but does not allow for exact reconstruction @@ -220,7 +315,7 @@ class DDIMSampler(object): @torch.no_grad() def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False): + use_original_steps=False, callback=None): timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = timesteps[:t_start] @@ -237,4 +332,5 @@ class DDIMSampler(object): x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) return x_dec \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index bd12a7510..f7ac0a735 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -1,25 +1,43 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + import torch import torch.nn as nn import numpy as np -import lightning.pytorch as pl +try: + import lightning.pytorch as pl + from lightning.pytorch.utilities import rank_zero_only, rank_zero_info +except: + import pytorch_lightning as pl + from pytorch_lightning.utilities import rank_zero_only, rank_zero_info from torch.optim.lr_scheduler import LambdaLR from einops import rearrange, repeat -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from functools import partial +import itertools from tqdm import tqdm from torchvision.utils import make_grid -from lightning.pytorch.utilities.rank_zero import rank_zero_only -from lightning.pytorch.utilities import rank_zero_info +from omegaconf import ListConfig from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.modules.ema import LitEma from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL + + +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.diffusionmodules.openaimodel import * + from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like from ldm.models.diffusion.ddim import DDIMSampler from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d -from ldm.modules.x_transformer import * from ldm.modules.encoders.modules import * from ldm.modules.ema import LitEma @@ -34,10 +52,6 @@ from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder from ldm.util import instantiate_from_config -from einops import rearrange, repeat - - - __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', @@ -85,9 +99,13 @@ class DDPM(pl.LightningModule): learn_logvar=False, logvar_init=0., use_fp16 = True, + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, ): super().__init__() - assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' self.parameterization = parameterization rank_zero_info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") self.cond_stage_model = None @@ -97,6 +115,7 @@ class DDPM(pl.LightningModule): self.image_size = image_size # try conv? self.channels = channels self.use_positional_encodings = use_positional_encodings + self.unet_config = unet_config self.conditioning_key = conditioning_key self.model = DiffusionWrapper(unet_config, conditioning_key) @@ -116,37 +135,46 @@ class DDPM(pl.LightningModule): if monitor is not None: self.monitor = monitor + self.make_it_fit = make_it_fit self.ckpt_path = ckpt_path self.ignore_keys = ignore_keys self.load_only_unet = load_only_unet - self.given_betas = given_betas - self.beta_schedule = beta_schedule + self.reset_ema = reset_ema + self.reset_num_ema_updates = reset_num_ema_updates + + if reset_ema: assert exists(ckpt_path) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + if reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + self.timesteps = timesteps + self.beta_schedule = beta_schedule + self.given_betas = given_betas self.linear_start = linear_start self.linear_end = linear_end self.cosine_s = cosine_s - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) self.loss_type = loss_type - self.learn_logvar = learn_logvar self.logvar_init = logvar_init + self.learn_logvar = learn_logvar self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - self.use_fp16 = use_fp16 - if use_fp16: - self.unet_config["params"].update({"use_fp16": True}) - rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"])) - else: - self.unet_config["params"].update({"use_fp16": False}) - rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"])) + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): @@ -180,7 +208,7 @@ class DDPM(pl.LightningModule): # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas + 1. - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.register_buffer('posterior_variance', to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain @@ -192,12 +220,14 @@ class DDPM(pl.LightningModule): if self.parameterization == "eps": lvlb_weights = self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) elif self.parameterization == "x0": lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) else: raise NotImplementedError("mu not supported") - # TODO how to choose this term lvlb_weights[0] = lvlb_weights[1] self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @@ -217,6 +247,7 @@ class DDPM(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") + @torch.no_grad() def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): @@ -227,13 +258,57 @@ class DDPM(pl.LightningModule): if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: - print(f"Missing Keys: {missing}") + print(f"Missing Keys:\n {missing}") if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") + print(f"\nUnexpected Keys:\n {unexpected}") def q_mean_variance(self, x_start, t): """ @@ -253,6 +328,20 @@ class DDPM(pl.LightningModule): extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + @@ -310,11 +399,13 @@ class DDPM(pl.LightningModule): return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + def get_loss(self, pred, target, mean=True): - - if self.use_fp16: - target = target.half() - if self.loss_type == 'l1': loss = (target - pred).abs() if mean: @@ -339,6 +430,8 @@ class DDPM(pl.LightningModule): target = noise elif self.parameterization == "x0": target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) else: raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") @@ -365,20 +458,14 @@ class DDPM(pl.LightningModule): return self.p_losses(x, t, *args, **kwargs) def get_input(self, batch, k): - - if not isinstance(batch, torch.Tensor): - x = batch[k] - else: - x = batch + x = batch[k] if len(x.shape) == 3: x = x[..., None] x = rearrange(x, 'b h w c -> b c h w') - if self.use_fp16: - x = x.to(memory_format=torch.contiguous_format).float().half() + x = x.to(memory_format=torch.contiguous_format).half() else: x = x.to(memory_format=torch.contiguous_format).float() - return x def shared_step(self, batch): @@ -387,6 +474,15 @@ class DDPM(pl.LightningModule): return loss, loss_dict def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + loss, loss_dict = self.shared_step(batch) self.log_dict(loss_dict, prog_bar=True, @@ -470,6 +566,7 @@ class DDPM(pl.LightningModule): class LatentDiffusion(DDPM): """main class""" + def __init__(self, first_stage_config, cond_stage_config, @@ -482,18 +579,19 @@ class LatentDiffusion(DDPM): scale_factor=1.0, scale_by_std=False, use_fp16=True, + force_null_conditioning=False, *args, **kwargs): + self.force_null_conditioning = force_null_conditioning self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std assert self.num_timesteps_cond <= kwargs['timesteps'] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__': + if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: conditioning_key = None - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, use_fp16=use_fp16, *args, **kwargs) + + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key @@ -501,38 +599,49 @@ class LatentDiffusion(DDPM): self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: self.num_downs = 0 + if not scale_by_std: self.scale_factor = scale_factor else: self.register_buffer('scale_factor', torch.tensor(scale_factor)) self.first_stage_config = first_stage_config self.cond_stage_config = cond_stage_config - if self.use_fp16: - self.cond_stage_config["params"].update({"use_fp16": True}) - rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"])) - else: - self.cond_stage_config["params"].update({"use_fp16": False}) - rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"])) self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None + self.bbox_tokenizer = None self.restarted_from_ckpt = False - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys) + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, self.ignore_keys) self.restarted_from_ckpt = True - - + if self.reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() def configure_sharded_model(self) -> None: + rank_zero_info("Configure sharded model for LatentDiffusion") self.model = DiffusionWrapper(self.unet_config, self.conditioning_key) - count_params(self.model, verbose=True) if self.use_ema: self.model_ema = LitEma(self.model) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, ignore_keys=self.ignore_keys, only_model=self.load_only_unet) + if self.reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps, linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s) @@ -540,24 +649,31 @@ class LatentDiffusion(DDPM): self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - if self.ckpt_path is not None: - self.init_from_ckpt(self.ckpt_path, self.ignore_keys) - self.restarted_from_ckpt = True + if self.ucg_training: + self.ucg_prng = np.random.RandomState() self.instantiate_first_stage(self.first_stage_config) self.instantiate_cond_stage(self.cond_stage_config) + if self.ckpt_path is not None: + self.init_from_ckpt(self.ckpt_path, self.ignore_keys) + self.restarted_from_ckpt = True + if self.reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if self.reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() def make_cond_schedule(self, ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() self.cond_ids[:self.num_timesteps_cond] = ids - - @rank_zero_only @torch.no_grad() - # def on_train_batch_start(self, batch, batch_idx, dataloader_idx): def on_train_batch_start(self, batch, batch_idx): # only for very first batch if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: @@ -614,7 +730,7 @@ class LatentDiffusion(DDPM): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append(self.decode_first_stage(zd.to(self.device), - force_not_quantize=force_no_decoder_quantization)) + force_not_quantize=force_no_decoder_quantization)) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') @@ -629,7 +745,7 @@ class LatentDiffusion(DDPM): z = encoder_posterior else: raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") - return self.scale_factor * z + return self.scale_factor * z.half() if self.use_fp16 else self.scale_factor * z def get_learned_conditioning(self, c): if self.cond_stage_forward is None: @@ -735,7 +851,7 @@ class LatentDiffusion(DDPM): @torch.no_grad() def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, - cond_key=None, return_original_cond=False, bs=None): + cond_key=None, return_original_cond=False, bs=None, return_x=False): x = super().get_input(batch, k) if bs is not None: x = x[:bs] @@ -743,13 +859,13 @@ class LatentDiffusion(DDPM): encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() - if self.model.conditioning_key is not None: + if self.model.conditioning_key is not None and not self.force_null_conditioning: if cond_key is None: cond_key = self.cond_stage_key if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox', 'txt']: + if cond_key in ['caption', 'coordinates_bbox', "txt"]: xc = batch[cond_key] - elif cond_key == 'class_label': + elif cond_key in ['class_label', 'cls']: xc = batch else: xc = super().get_input(batch, cond_key).to(self.device) @@ -757,7 +873,6 @@ class LatentDiffusion(DDPM): xc = x if not self.cond_stage_trainable or force_c_encode: if isinstance(xc, dict) or isinstance(xc, list): - # import pudb; pudb.set_trace() c = self.get_learned_conditioning(xc) else: c = self.get_learned_conditioning(xc.to(self.device)) @@ -781,8 +896,11 @@ class LatentDiffusion(DDPM): if return_first_stage_outputs: xrec = self.decode_first_stage(z) out.extend([x, xrec]) + if return_x: + out.extend([x]) if return_original_cond: out.append(xc) + return out @torch.no_grad() @@ -794,156 +912,11 @@ class LatentDiffusion(DDPM): z = rearrange(z, 'b h w c -> b c h w').contiguous() z = 1. / self.scale_factor * z - - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] - else: - - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - # same as above but without decorator - def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1. / self.scale_factor * z - - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] - else: - - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) + return self.first_stage_model.decode(z) @torch.no_grad() def encode_first_stage(self, x): - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - df = self.split_input_params["vqf"] - self.split_input_params['original_image_size'] = x.shape[-2:] - bs, nc, h, w = x.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) - z = unfold(x) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) - o = o * weighting - - # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization - return decoded - - else: - return self.first_stage_model.encode(x) - else: - return self.first_stage_model.encode(x) + return self.first_stage_model.encode(x) def shared_step(self, batch, **kwargs): x, c = self.get_input(batch, self.first_stage_key) @@ -961,19 +934,9 @@ class LatentDiffusion(DDPM): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): - # hybrid case, cond is exptected to be a dict + # hybrid case, cond is expected to be a dict pass else: if not isinstance(cond, list): @@ -981,91 +944,7 @@ class LatentDiffusion(DDPM): key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' cond = {key: cond} - if hasattr(self, "split_input_params"): - assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - - h, w = x_noisy.shape[-2:] - - fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) - - z = unfold(x_noisy) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] - if self.cond_stage_key in ["image", "LR_image", "segmentation", - 'bbox_img'] and self.model.conditioning_key: # todo check for completeness - c_key = next(iter(cond.keys())) # get key - c = next(iter(cond.values())) # get value - assert (len(c) == 1) # todo extend to list with more than one elem - c = c[0] # get element - - c = unfold(c) - c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] - - elif self.cond_stage_key == 'coordinates_bbox': - assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' - - # assuming padding of unfold is always 0 and its dilation is always 1 - n_patches_per_row = int((w - ks[0]) / stride[0] + 1) - full_img_h, full_img_w = self.split_input_params['original_image_size'] - # as we are operating on latents, we need the factor from the original image size to the - # spatial latent size to properly rescale the crops for regenerating the bbox annotations - num_downs = self.first_stage_model.encoder.num_resolutions - 1 - rescale_latent = 2 ** (num_downs) - - # get top left postions of patches as conforming for the bbbox tokenizer, therefore we - # need to rescale the tl patch coordinates to be in between (0,1) - tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, - rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) - for patch_nr in range(z.shape[-1])] - - # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) - patch_limits = [(x_tl, y_tl, - rescale_latent * ks[0] / full_img_w, - rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] - # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] - - # tokenize crop coordinates for the bounding boxes of the respective patches - patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) - for bbox in patch_limits] # list of length l with tensors of shape (1, 2) - print(patch_limits_tknzd[0].shape) - # cut tknzd crop position from conditioning - assert isinstance(cond, dict), 'cond must be dict to be fed into model' - cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) - print(cut_cond.shape) - - adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) - adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') - print(adapted_cond.shape) - adapted_cond = self.get_learned_conditioning(adapted_cond) - print(adapted_cond.shape) - adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) - print(adapted_cond.shape) - - cond_list = [{'c_crossattn': [e]} for e in adapted_cond] - - else: - cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient - - # apply model by loop over crops - output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] - assert not isinstance(output_list[0], - tuple) # todo cant deal with multiple model outputs check this never happens - - o = torch.stack(output_list, axis=-1) - o = o * weighting - # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - x_recon = fold(o) / normalization - - else: - x_recon = self.model(x_noisy, t, **cond) + x_recon = self.model(x_noisy, t, **cond) if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] @@ -1102,6 +981,8 @@ class LatentDiffusion(DDPM): target = x_start elif self.parameterization == "eps": target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) else: raise NotImplementedError() @@ -1109,6 +990,7 @@ class LatentDiffusion(DDPM): loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: @@ -1297,7 +1179,7 @@ class LatentDiffusion(DDPM): @torch.no_grad() def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None,**kwargs): + mask=None, x0=None, shape=None, **kwargs): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: @@ -1313,26 +1195,51 @@ class LatentDiffusion(DDPM): mask=mask, x0=x0) @torch.no_grad() - def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): - + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): if ddim: ddim_sampler = DDIMSampler(self) shape = (self.channels, self.image_size, self.image_size) - samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, - shape,cond,verbose=False,**kwargs) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=False, **kwargs) else: samples, intermediates = self.sample(cond=cond, batch_size=batch_size, - return_intermediates=True,**kwargs) + return_intermediates=True, **kwargs) return samples, intermediates + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + if self.cond_stage_key in ["class_label", "cls"]: + xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) + return self.get_learned_conditioning(xc) + else: + raise NotImplementedError("todo") + if isinstance(c, list): # in case the encoder gives us a list + for i in range(len(c)): + c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + else: + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + return c @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, **kwargs): - + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() @@ -1349,11 +1256,260 @@ class LatentDiffusion(DDPM): if hasattr(self.cond_stage_model, "decode"): xc = self.cond_stage_model.decode(c) log["conditioning"] = xc - elif self.cond_stage_key in ["caption"]: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key == 'class_label': - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + elif self.cond_stage_key in ['class_label', "cls"]: + try: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + except KeyError: + # probably no "human_label" in batch + pass + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) + if self.model.conditioning_key == "crossattn-adm": + uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + + from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) + + # opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + if not self.sequential_cross_attn: + cc = torch.cat(c_crossattn, 1) + else: + cc = c_crossattn + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + self.noise_level_key = noise_level_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + if self.use_fp16: + x_low = x_low.to(memory_format=torch.contiguous_format).half() + else: + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + if self.noise_level_key is not None: + # get noise level from batch instead, e.g. when extracting a custom noise level for bsr + raise NotImplementedError('TODO') + + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + if log_mode: + # TODO: maybe disable if too expensive + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) log['conditioning'] = xc elif isimage(xc): log["conditioning"] = xc @@ -1380,9 +1536,9 @@ class LatentDiffusion(DDPM): if sample: # get denoise row - with self.ema_scope("Plotting"): - samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, - ddim_steps=ddim_steps,eta=ddim_eta) + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1390,139 +1546,350 @@ class LatentDiffusion(DDPM): denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid - if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( - self.first_stage_model, IdentityFirstStage): - # also display when quantizing x0 while sampling - with self.ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, - ddim_steps=ddim_steps,eta=ddim_eta, - quantize_denoised=True) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, - # quantize_denoised=True) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_x0_quantized"] = x_samples + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + uc[k] = c[k] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] - if inpaint: - # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] - mask = torch.ones(N, h, w).to(self.device) - # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. - mask = mask[:, None, ...] - with self.ema_scope("Plotting Inpaint"): - - samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_inpainting"] = x_samples - log["mask"] = mask - - # outpaint - with self.ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_outpainting"] = x_samples + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg if plot_progressive_rows: - with self.ema_scope("Plotting Progressives"): + with ema_scope("Plotting Progressives"): img, progressives = self.progressive_denoising(c, shape=(self.channels, self.image_size, self.image_size), batch_size=N) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} return log - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") - params = params + list(self.cond_stage_model.parameters()) - if self.learn_logvar: - print('Diffusion model optimizing logvar') - params.append(self.logvar) - from colossalai.nn.optimizer import HybridAdam - opt = HybridAdam(params, lr=lr) - # opt = torch.optim.AdamW(params, lr=lr) - if self.use_scheduler: - assert 'target' in self.scheduler_config - scheduler = instantiate_from_config(self.scheduler_config) - rank_zero_info("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] - return [opt], scheduler - return opt +class LatentFinetuneDiffusion(LatentDiffusion): + """ + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys: tuple, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight" + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") @torch.no_grad() - def to_rgb(self, x): - x = x.float() - if not hasattr(self, "colorize"): - self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) - x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. - return x + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log -class DiffusionWrapper(pl.LightningModule): - def __init__(self, diff_model_config, conditioning_key): - super().__init__() - self.diffusion_model = instantiate_from_config(diff_model_config) - self.conditioning_key = conditioning_key - assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] +class LatentInpaintDiffusion(LatentFinetuneDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): - if self.conditioning_key is None: - out = self.diffusion_model(x, t) - elif self.conditioning_key == 'concat': - xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) - elif self.conditioning_key == 'crossattn': - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == 'hybrid': - xc = torch.cat([x] + c_concat, dim=1) - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == 'adm': - cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) + def __init__(self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, **kwargs + ): + super().__init__(concat_keys, *args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + if self.use_fp16: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).half() + else: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) + log["masked_image"] = rearrange(args[0]["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + c_cat = list() + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], + keepdim=True) + cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ + torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, + low_scale_config=None, low_scale_key=None, *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + print("Initializing a low-scale model") + assert exists(low_scale_key) + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + # optionally make spatial noise_level here + c_cat = list() + noise_level = None + for ck in self.concat_keys: + cc = batch[ck] + cc = rearrange(cc, 'b h w c -> b c h w') + if exists(self.reshuffle_patch_size): + assert isinstance(self.reshuffle_patch_size, int) + cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', + p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + if exists(self.low_scale_model) and ck == self.low_scale_key: + cc, noise_level = self.low_scale_model(cc) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + if exists(noise_level): + all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} else: - raise NotImplementedError() + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds - return out - - -class Layout2ImgDiffusion(LatentDiffusion): - # TODO: move all layout-specific hacks to this class - def __init__(self, cond_stage_key, *args, **kwargs): - assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) - - def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) - - key = 'train' if self.training else 'validation' - dset = self.trainer.datamodule.datasets[key] - mapper = dset.conditional_builders[self.cond_stage_key] - - bbox_imgs = [] - map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) - for tknzd_bbox in batch[self.cond_stage_key][:N]: - bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) - bbox_imgs.append(bboximg) - - cond_img = torch.stack(bbox_imgs, dim=0) - logs['bbox_image'] = cond_img - return logs + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + return log diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 000000000..7427f38c0 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 000000000..095e5ba3c --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1154 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( + model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + ===================================================== + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + ===================================================== + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in tqdm(range(1, order), desc="DPM init order"): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), desc="DPM multistep"): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), + N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 000000000..7d137b8cf --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,87 @@ +"""SAMPLING ONLY.""" +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +MODEL_TYPES = { + "eps": "noise", + "v": "v" +} + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None \ No newline at end of file diff --git a/examples/images/diffusion/ldm/models/diffusion/plms.py b/examples/images/diffusion/ldm/models/diffusion/plms.py index 78eeb1003..7002a365d 100644 --- a/examples/images/diffusion/ldm/models/diffusion/plms.py +++ b/examples/images/diffusion/ldm/models/diffusion/plms.py @@ -6,6 +6,7 @@ from tqdm import tqdm from functools import partial from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding class PLMSSampler(object): @@ -77,6 +78,7 @@ class PLMSSampler(object): unconditional_guidance_scale=1., unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, **kwargs ): if conditioning is not None: @@ -108,6 +110,7 @@ class PLMSSampler(object): log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, ) return samples, intermediates @@ -117,7 +120,8 @@ class PLMSSampler(object): callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): device = self.model.betas.device b = shape[0] if x_T is None: @@ -155,7 +159,8 @@ class PLMSSampler(object): corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next) + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: @@ -172,7 +177,8 @@ class PLMSSampler(object): @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): b, *_, device = *x.shape, x.device def get_model_output(x, t): @@ -207,6 +213,8 @@ class PLMSSampler(object): pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature diff --git a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py new file mode 100644 index 000000000..7eff02be6 --- /dev/null +++ b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/attention.py b/examples/images/diffusion/ldm/modules/attention.py index 3401ceafd..d504d939f 100644 --- a/examples/images/diffusion/ldm/modules/attention.py +++ b/examples/images/diffusion/ldm/modules/attention.py @@ -4,24 +4,17 @@ import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat +from typing import Optional, Any + +from ldm.modules.diffusionmodules.util import checkpoint -from torch.utils import checkpoint try: - from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv - FlASH_AVAILABLE = True + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True except: - FlASH_AVAILABLE = False - -USE_FLASH = False - - -def enable_flash_attention(): - global USE_FLASH - USE_FLASH = True - if FlASH_AVAILABLE is False: - print("Please install flash attention to activate new attention kernel.\n" + - "Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'") + XFORMERS_IS_AVAILBLE = False def exists(val): @@ -93,25 +86,6 @@ def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) - context = torch.einsum('bhdn,bhen->bhde', k, v) - out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) - return self.to_out(out) - - class SpatialSelfAttention(nn.Module): def __init__(self, in_channels): super().__init__() @@ -184,85 +158,111 @@ class CrossAttention(nn.Module): ) def forward(self, x, context=None, mask=None): + h = self.heads + q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) - dim_head = q.shape[-1] / self.heads - if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \ - dim_head <= 128 and (dim_head % 8) == 0: - # print("in flash") - if q.shape[1] == k.shape[1]: - out = self._flash_attention_qkv(q, k, v) - else: - out = self._flash_attention_q_kv(q, k, v) - else: - out = self._native_attention(q, k, v, self.heads, mask) - - return self.to_out(out) - - def _native_attention(self, q, k, v, h, mask): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + if exists(mask): mask = rearrange(mask, 'b ... -> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) - # attention, what we cannot get enough of - out = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', out, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return out - def _flash_attention_qkv(self, q, k, v): - qkv = torch.stack([q, k, v], dim=2) - b = qkv.shape[0] - n = qkv.shape[1] - qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads) - out = flash_attention_qkv(qkv, self.scale, b, n) - out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads) - return out - - def _flash_attention_q_kv(self, q, k, v): - kv = torch.stack([k, v], dim=2) - b = q.shape[0] - q_seqlen = q.shape[1] - kv_seqlen = kv.shape[1] - q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads) - kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads) - out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen) - out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads) - return out + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) - self.use_checkpoint = use_checkpoint + self.checkpoint = checkpoint def forward(self, x, context=None): - - - if self.use_checkpoint: - return checkpoint(self._forward, x, context) - else: - return self._forward(x, context) + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) def _forward(self, x, context=None): - x = self.attn1(self.norm1(x)) + x + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x - class SpatialTransformer(nn.Module): @@ -272,43 +272,60 @@ class SpatialTransformer(nn.Module): and reshape to b, t, d. Then apply standard transformer action. Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs """ def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, use_checkpoint=False): + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) - - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint) + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) for d in range(depth)] ) - - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) - + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] b, c, h, w = x.shape x_in = x x = self.norm(x) - x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c') - x = x.contiguous() - for block in self.transformer_blocks: - x = block(x, context=context) - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) - x = x.contiguous() - x = self.proj_out(x) - return x + x_in \ No newline at end of file + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py index 3c28492c5..57b9a4b80 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py @@ -4,9 +4,22 @@ import torch import torch.nn as nn import numpy as np from einops import rearrange +from typing import Optional, Any -from ldm.util import instantiate_from_config -from ldm.modules.attention import LinearAttention +try: + from lightning.pytorch.utilities import rank_zero_info +except: + from pytorch_lightning.utilities import rank_zero_info + +from ldm.modules.attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") def get_timestep_embedding(timesteps, embedding_dim): @@ -141,12 +154,6 @@ class ResnetBlock(nn.Module): return x+h -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) - - class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() @@ -174,7 +181,6 @@ class AttnBlock(nn.Module): stride=1, padding=0) - def forward(self, x): h_ = x h_ = self.norm(h_) @@ -201,21 +207,100 @@ class AttnBlock(nn.Module): return x+h_ +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": + assert attn_kwargs is None return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + rank_zero_info(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) elif attn_type == "none": return nn.Identity(in_channels) else: - return LinAttnBlock(in_channels) + raise NotImplementedError() -class temb_module(nn.Module): - def __init__(self): - super().__init__() - pass class Model(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, @@ -233,8 +318,7 @@ class Model(nn.Module): self.use_timestep = use_timestep if self.use_timestep: # timestep embedding - # self.temb = nn.Module() - self.temb = temb_module() + self.temb = nn.Module() self.temb.dense = nn.ModuleList([ torch.nn.Linear(self.ch, self.temb_ch), @@ -265,8 +349,7 @@ class Model(nn.Module): block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) - # down = nn.Module() - down = Down_module() + down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions-1: @@ -275,8 +358,7 @@ class Model(nn.Module): self.down.append(down) # middle - # self.mid = nn.Module() - self.mid = Mid_module() + self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, @@ -304,8 +386,7 @@ class Model(nn.Module): block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) - # up = nn.Module() - up = Up_module() + up = nn.Module() up.block = block up.attn = attn if i_level != 0: @@ -372,21 +453,6 @@ class Model(nn.Module): def get_last_layer(self): return self.conv_out.weight -class Down_module(nn.Module): - def __init__(self): - super().__init__() - pass - -class Up_module(nn.Module): - def __init__(self): - super().__init__() - pass - -class Mid_module(nn.Module): - def __init__(self): - super().__init__() - pass - class Encoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, @@ -426,8 +492,7 @@ class Encoder(nn.Module): block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) - # down = nn.Module() - down = Down_module() + down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions-1: @@ -436,8 +501,7 @@ class Encoder(nn.Module): self.down.append(down) # middle - # self.mid = nn.Module() - self.mid = Mid_module() + self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, @@ -505,7 +569,7 @@ class Decoder(nn.Module): block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( + rank_zero_info("Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape))) # z to block_in @@ -516,8 +580,7 @@ class Decoder(nn.Module): padding=1) # middle - # self.mid = nn.Module() - self.mid = Mid_module() + self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, @@ -542,8 +605,7 @@ class Decoder(nn.Module): block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) - # up = nn.Module() - up = Up_module() + up = nn.Module() up.block = block up.attn = attn if i_level != 0: @@ -758,7 +820,7 @@ class Upsampler(nn.Module): assert out_size >= in_size num_blocks = int(np.log2(out_size//in_size))+1 factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, out_channels=in_channels) self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, @@ -777,7 +839,7 @@ class Resize(nn.Module): self.with_conv = learned self.mode = mode if self.with_conv: - print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves @@ -793,70 +855,3 @@ class Resize(nn.Module): else: x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) return x - -class FirstStagePostProcessor(nn.Module): - - def __init__(self, ch_mult:list, in_channels, - pretrained_model:nn.Module=None, - reshape=False, - n_channels=None, - dropout=0., - pretrained_config=None): - super().__init__() - if pretrained_config is None: - assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' - self.pretrained_model = pretrained_model - else: - assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' - self.instantiate_pretrained(pretrained_config) - - self.do_reshape = reshape - - if n_channels is None: - n_channels = self.pretrained_model.encoder.ch - - self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) - self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, - stride=1,padding=1) - - blocks = [] - downs = [] - ch_in = n_channels - for m in ch_mult: - blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) - ch_in = m * n_channels - downs.append(Downsample(ch_in, with_conv=False)) - - self.model = nn.ModuleList(blocks) - self.downsampler = nn.ModuleList(downs) - - - def instantiate_pretrained(self, config): - model = instantiate_from_config(config) - self.pretrained_model = model.eval() - # self.pretrained_model.train = False - for param in self.pretrained_model.parameters(): - param.requires_grad = False - - - @torch.no_grad() - def encode_with_pretrained(self,x): - c = self.pretrained_model.encode(x) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - return c - - def forward(self,x): - z_fs = self.encode_with_pretrained(x) - z = self.proj_norm(z_fs) - z = self.proj(z) - z = nonlinearity(z) - - for submodel, downmodel in zip(self.model,self.downsampler): - z = submodel(z,temb=None) - z = downmodel(z) - - if self.do_reshape: - z = rearrange(z,'b c h w -> b (h w) c') - return z - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py index 3aedc2205..cd639d936 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -1,16 +1,13 @@ from abc import abstractmethod -from functools import partial import math -from typing import Iterable import numpy as np -import torch import torch as th import torch.nn as nn import torch.nn.functional as F -from torch.utils import checkpoint from ldm.modules.diffusionmodules.util import ( + checkpoint, conv_nd, linear, avg_pool_nd, @@ -19,13 +16,11 @@ from ldm.modules.diffusionmodules.util import ( timestep_embedding, ) from ldm.modules.attention import SpatialTransformer +from ldm.util import exists # dummy replace def convert_module_to_f16(x): - # for n,p in x.named_parameter(): - # print(f"convert module {n} to_f16") - # p.data = p.data.half() pass def convert_module_to_f32(x): @@ -251,10 +246,9 @@ class ResBlock(TimestepBlock): :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ - if self.use_checkpoint: - return checkpoint(self._forward, x, emb) - else: - return self._forward(x, emb) + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) def _forward(self, x, emb): @@ -317,11 +311,8 @@ class AttentionBlock(nn.Module): self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - if self.use_checkpoint: - return checkpoint(self._forward, x) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! #return pt_checkpoint(self._forward, x) # pytorch - else: - return self._forward(x) def _forward(self, x): b, c, *spatial = x.shape @@ -474,7 +465,10 @@ class UNetModel(nn.Module): context_dim=None, # custom transformer support n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, - from_pretrained: str=None + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, ): super().__init__() if use_spatial_transformer: @@ -499,7 +493,24 @@ class UNetModel(nn.Module): self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - self.num_res_blocks = num_res_blocks + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult @@ -520,7 +531,13 @@ class UNetModel(nn.Module): ) if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() self.input_blocks = nn.ModuleList( [ @@ -534,7 +551,7 @@ class UNetModel(nn.Module): ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): + for nr in range(self.num_res_blocks[level]): layers = [ ResBlock( ch, @@ -556,17 +573,25 @@ class UNetModel(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_checkpoint=use_checkpoint, + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) ) - ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) @@ -618,8 +643,10 @@ class UNetModel(nn.Module): num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint ), ResBlock( ch, @@ -634,7 +661,7 @@ class UNetModel(nn.Module): self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): + for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ ResBlock( @@ -657,18 +684,26 @@ class UNetModel(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) ) - ) - if level and i == num_res_blocks: + if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( ResBlock( @@ -699,188 +734,6 @@ class UNetModel(nn.Module): conv_nd(dims, model_channels, n_embed, 1), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) - # if use_fp16: - # self.convert_to_fp16() - from diffusers.modeling_utils import load_state_dict - if from_pretrained is not None: - state_dict = load_state_dict(from_pretrained) - self._load_pretrained_model(state_dict) - - def _input_blocks_mapping(self, input_dict): - res_dict = {} - for key_, value_ in input_dict.items(): - id_0 = int(key_[13]) - if "resnets" in key_: - id_1 = int(key_[23]) - target_id = 3 * id_0 + 1 + id_1 - post_fix = key_[25:].replace('time_emb_proj', 'emb_layers.1')\ - .replace('norm1', 'in_layers.0')\ - .replace('norm2', 'out_layers.0')\ - .replace('conv1', 'in_layers.2')\ - .replace('conv2', 'out_layers.3')\ - .replace('conv_shortcut', 'skip_connection') - res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_ - elif "attentions" in key_: - id_1 = int(key_[26]) - target_id = 3 * id_0 + 1 + id_1 - post_fix = key_[28:] - res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_ - elif "downsamplers" in key_: - post_fix = key_[35:] - target_id = 3 * (id_0 + 1) - res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_ - return res_dict - - - def _mid_blocks_mapping(self, mid_dict): - res_dict = {} - for key_, value_ in mid_dict.items(): - if "resnets" in key_: - temp_key_ =key_.replace('time_emb_proj', 'emb_layers.1') \ - .replace('norm1', 'in_layers.0') \ - .replace('norm2', 'out_layers.0') \ - .replace('conv1', 'in_layers.2') \ - .replace('conv2', 'out_layers.3') \ - .replace('conv_shortcut', 'skip_connection')\ - .replace('middle_block.resnets.0', 'middle_block.0')\ - .replace('middle_block.resnets.1', 'middle_block.2') - res_dict[temp_key_] = value_ - elif "attentions" in key_: - res_dict[key_.replace('attentions.0', '1')] = value_ - return res_dict - - def _other_blocks_mapping(self, other_dict): - res_dict = {} - for key_, value_ in other_dict.items(): - tmp_key = key_.replace('conv_in', 'input_blocks.0.0')\ - .replace('time_embedding.linear_1', 'time_embed.0')\ - .replace('time_embedding.linear_2', 'time_embed.2')\ - .replace('conv_norm_out', 'out.0')\ - .replace('conv_out', 'out.2') - res_dict[tmp_key] = value_ - return res_dict - - - def _output_blocks_mapping(self, output_dict): - res_dict = {} - for key_, value_ in output_dict.items(): - id_0 = int(key_[14]) - if "resnets" in key_: - id_1 = int(key_[24]) - target_id = 3 * id_0 + id_1 - post_fix = key_[26:].replace('time_emb_proj', 'emb_layers.1') \ - .replace('norm1', 'in_layers.0') \ - .replace('norm2', 'out_layers.0') \ - .replace('conv1', 'in_layers.2') \ - .replace('conv2', 'out_layers.3') \ - .replace('conv_shortcut', 'skip_connection') - res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_ - elif "attentions" in key_: - id_1 = int(key_[27]) - target_id = 3 * id_0 + id_1 - post_fix = key_[29:] - res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_ - elif "upsamplers" in key_: - post_fix = key_[34:] - target_id = 3 * (id_0 + 1) - 1 - mid_str = '.2.conv.' if target_id != 2 else '.1.conv.' - res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_ - return res_dict - - def _state_key_mapping(self, state_dict: dict): - import re - res_dict = {} - input_dict = {} - mid_dict = {} - output_dict = {} - other_dict = {} - for key_, value_ in state_dict.items(): - if "down_blocks" in key_: - input_dict[key_.replace('down_blocks', 'input_blocks')] = value_ - elif "up_blocks" in key_: - output_dict[key_.replace('up_blocks', 'output_blocks')] = value_ - elif "mid_block" in key_: - mid_dict[key_.replace('mid_block', 'middle_block')] = value_ - else: - other_dict[key_] = value_ - - input_dict = self._input_blocks_mapping(input_dict) - output_dict = self._output_blocks_mapping(output_dict) - mid_dict = self._mid_blocks_mapping(mid_dict) - other_dict = self._other_blocks_mapping(other_dict) - # key_list = state_dict.keys() - # key_str = " ".join(key_list) - - # for key_, val_ in state_dict.items(): - # key_ = key_.replace("down_blocks", "input_blocks")\ - # .replace("up_blocks", 'output_blocks') - # res_dict[key_] = val_ - res_dict.update(input_dict) - res_dict.update(output_dict) - res_dict.update(mid_dict) - res_dict.update(other_dict) - - return res_dict - - def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): - state_dict = self._state_key_mapping(state_dict) - model_state_dict = self.state_dict() - loaded_keys = [k for k in state_dict.keys()] - expected_keys = list(model_state_dict.keys()) - original_loaded_keys = loaded_keys - missing_keys = list(set(expected_keys) - set(loaded_keys)) - unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( - state_dict, - model_state_dict, - original_loaded_keys, - ignore_mismatched_sizes, - ) - error_msgs = self._load_state_dict_into_model(state_dict) - return missing_keys, unexpected_keys, mismatched_keys, error_msgs - - def _load_state_dict_into_model(self, state_dict): - # Convert old format to new format if needed from a PyTorch state_dict - # copy state_dict so _load_from_state_dict can modify it - state_dict = state_dict.copy() - error_msgs = [] - - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. - def load(module: torch.nn.Module, prefix=""): - args = (state_dict, prefix, {}, True, [], [], error_msgs) - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(self) - - return error_msgs def convert_to_fp16(self): """ @@ -912,10 +765,11 @@ class UNetModel(nn.Module): ), "must specify y if and only if the model is class-conditional" hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + t_emb = t_emb.type(self.dtype) emb = self.time_embed(t_emb) if self.num_classes is not None: - assert y.shape == (x.shape[0],) + assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) h = x.type(self.dtype) @@ -926,227 +780,8 @@ class UNetModel(nn.Module): for module in self.output_blocks: h = th.cat([h, hs.pop()], dim=1) h = module(h, emb, context) - h = h.type(self.dtype) + h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) else: return self.out(h) - - -class EncoderUNetModel(nn.Module): - """ - The half UNet model with attention and timestep embedding. - For usage, see UNet. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - *args, - **kwargs - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - self.pool = pool - if pool == "adaptive": - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.AdaptiveAvgPool2d((1, 1)), - zero_module(conv_nd(dims, ch, out_channels, 1)), - nn.Flatten(), - ) - elif pool == "attention": - assert num_head_channels != -1 - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), - ) - elif pool == "spatial": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - nn.ReLU(), - nn.Linear(2048, self.out_channels), - ) - elif pool == "spatial_v2": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - normalization(2048), - nn.SiLU(), - nn.Linear(2048, self.out_channels), - ) - else: - raise NotImplementedError(f"Unexpected {pool} pooling") - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. - """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - results = [] - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = th.cat(results, axis=-1) - return self.out(h) - else: - h = h.type(self.dtype) - return self.out(h) - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py new file mode 100644 index 000000000..038166620 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial + +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ldm.util import default + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py index a7db9369c..e0621032d 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py @@ -122,7 +122,9 @@ class CheckpointFunction(torch.autograd.Function): ctx.run_function = run_function ctx.input_tensors = list(args[:length]) ctx.input_params = list(args[length:]) - + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) return output_tensors @@ -130,7 +132,8 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(): + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. @@ -148,7 +151,7 @@ class CheckpointFunction(torch.autograd.Function): return (None, None) + input_grads -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True): +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. @@ -168,10 +171,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: embedding = repeat(timesteps, 'b -> b d', d=dim) - if use_fp16: - return embedding.half() - else: - return embedding + return embedding def zero_module(module): @@ -199,16 +199,14 @@ def mean_flat(tensor): return tensor.mean(dim=list(range(1, len(tensor.shape)))) -def normalization(channels, precision=16): +def normalization(channels): """ Make a standard normalization layer. :param channels: number of input channels. :return: an nn.Module for normalization. """ - if precision == 16: - return GroupNorm16(16, channels) - else: - return GroupNorm32(32, channels) + return nn.GroupNorm(16, channels) + # return GroupNorm32(32, channels) # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. @@ -216,9 +214,6 @@ class SiLU(nn.Module): def forward(self, x): return x * torch.sigmoid(x) -class GroupNorm16(nn.GroupNorm): - def forward(self, x): - return super().forward(x.half()).type(x.dtype) class GroupNorm32(nn.GroupNorm): def forward(self, x): diff --git a/examples/images/diffusion/ldm/modules/ema.py b/examples/images/diffusion/ldm/modules/ema.py index c8c75af43..bded25019 100644 --- a/examples/images/diffusion/ldm/modules/ema.py +++ b/examples/images/diffusion/ldm/modules/ema.py @@ -10,24 +10,28 @@ class LitEma(nn.Module): self.m_name2s_name = {} self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates - else torch.tensor(-1,dtype=torch.int)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) for name, p in model.named_parameters(): if p.requires_grad: - #remove as '.'-character is not allowed in buffers - s_name = name.replace('.','') - self.m_name2s_name.update({name:s_name}) - self.register_buffer(s_name,p.clone().detach().data) + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) self.collected_params = [] - def forward(self,model): + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): decay = self.decay if self.num_updates >= 0: self.num_updates += 1 - decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay diff --git a/examples/images/diffusion/ldm/modules/encoders/modules.py b/examples/images/diffusion/ldm/modules/encoders/modules.py index 8cfc01e5d..4edd5496b 100644 --- a/examples/images/diffusion/ldm/modules/encoders/modules.py +++ b/examples/images/diffusion/ldm/modules/encoders/modules.py @@ -1,15 +1,11 @@ -import types - import torch import torch.nn as nn -from functools import partial -import clip -from einops import rearrange, repeat -from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig -import kornia -from transformers.models.clip.modeling_clip import CLIPTextTransformer +from torch.utils.checkpoint import checkpoint -from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import open_clip +from ldm.util import default, count_params class AbstractEncoder(nn.Module): @@ -20,169 +16,66 @@ class AbstractEncoder(nn.Module): raise NotImplementedError +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key='class'): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): super().__init__() self.key = key self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate - def forward(self, batch, key=None): + def forward(self, batch, key=None, disable_dropout=False): if key is None: key = self.key # this is for use in crossattn c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = c.long() c = self.embedding(c) return c + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc -class TransformerEmbedder(AbstractEncoder): - """Some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) self.device = device - self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer)) - - def forward(self, tokens): - tokens = tokens.to(self.device) # meh - z = self.transformer(tokens, return_embeddings=True) - return z - - def encode(self, x): - return self(x) - - -class BERTTokenizer(AbstractEncoder): - """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" - def __init__(self, device="cuda", vq_interface=True, max_length=77): - super().__init__() - from transformers import BertTokenizerFast # TODO: add to reuquirements - self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") - self.device = device - self.vq_interface = vq_interface - self.max_length = max_length - - def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) - return tokens - - @torch.no_grad() - def encode(self, text): - tokens = self(text) - if not self.vq_interface: - return tokens - return None, None, [None, None, tokens] - - def decode(self, text): - return text - - -class BERTEmbedder(AbstractEncoder): - """Uses the BERT tokenizr model and add some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, - device="cuda",use_tokenizer=True, embedding_dropout=0.0): - super().__init__() - self.use_tknz_fn = use_tokenizer - if self.use_tknz_fn: - self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) - self.device = device - self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer), - emb_dropout=embedding_dropout) - - def forward(self, text): - if self.use_tknz_fn: - tokens = self.tknz_fn(text)#.to(self.device) - else: - tokens = text - z = self.transformer(tokens, return_embeddings=True) - return z - - def encode(self, text): - # output of length 77 - return self(text) - - -class SpatialRescaler(nn.Module): - def __init__(self, - n_stages=1, - method='bilinear', - multiplier=0.5, - in_channels=3, - out_channels=None, - bias=False): - super().__init__() - self.n_stages = n_stages - assert self.n_stages >= 0 - assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] - self.multiplier = multiplier - self.interpolator = partial(torch.nn.functional.interpolate, mode=method) - self.remap_output = out_channels is not None - if self.remap_output: - print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') - self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) - - def forward(self,x): - for stage in range(self.n_stages): - x = self.interpolator(x, scale_factor=self.multiplier) - - - if self.remap_output: - x = self.channel_mapper(x) - return x - - def encode(self, x): - return self(x) - - -class CLIPTextModelZero(CLIPTextModel): - config_class = CLIPTextConfig - - def __init__(self, config: CLIPTextConfig): - super().__init__(config) - self.text_model = CLIPTextTransformerZero(config) - -class CLIPTextTransformerZero(CLIPTextTransformer): - def _build_causal_attention_mask(self, bsz, seq_len): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(bsz, seq_len, seq_len) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - mask = mask.unsqueeze(1) # expand mask - return mask.half() - -class FrozenCLIPEmbedder(AbstractEncoder): - """Uses the CLIP transformer encoder for text (from Hugging Face)""" - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True): - super().__init__() - self.tokenizer = CLIPTokenizer.from_pretrained(version) - - if use_fp16: - self.transformer = CLIPTextModelZero.from_pretrained(version) - else: - self.transformer = CLIPTextModel.from_pretrained(version) - - # print(self.transformer.modules()) - # print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype)) - self.device = device - self.max_length = max_length - self.freeze() + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() def freeze(self): self.transformer = self.transformer.eval() + #self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - # tokens = batch_encoding["input_ids"].to(self.device) tokens = batch_encoding["input_ids"].to(self.device) - # print("token type: {}".format(tokens.dtype)) outputs = self.transformer(input_ids=tokens) z = outputs.last_hidden_state @@ -192,17 +85,80 @@ class FrozenCLIPEmbedder(AbstractEncoder): return self(text) -class FrozenCLIPTextEmbedder(nn.Module): - """ - Uses the CLIP transformer encoder for text. - """ - def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 super().__init__() - self.model, _ = clip.load(version, jit=False, device="cpu") + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) self.device = device self.max_length = max_length - self.n_repeat = n_repeat - self.normalize = normalize + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() def freeze(self): self.model = self.model.eval() @@ -210,55 +166,48 @@ class FrozenCLIPTextEmbedder(nn.Module): param.requires_grad = False def forward(self, text): - tokens = clip.tokenize(text).to(self.device) - z = self.model.encode_text(tokens) - if self.normalize: - z = z / torch.linalg.norm(z, dim=1, keepdim=True) + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) return z - def encode(self, text): - z = self(text) - if z.ndim==2: - z = z[:, None, :] - z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) - return z - - -class FrozenClipImageEmbedder(nn.Module): - """ - Uses the CLIP image encoder. - """ - def __init__( - self, - model, - jit=False, - device='cuda' if torch.cuda.is_available() else 'cpu', - antialias=False, - ): - super().__init__() - self.model, _ = clip.load(name=model, device=device, jit=jit) - - self.antialias = antialias - - self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) - self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) - - def preprocess(self, x): - # normalize to [0,1] - x = kornia.geometry.resize(x, (224, 224), - interpolation='bicubic',align_corners=True, - antialias=self.antialias) - x = (x + 1.) / 2. - # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) return x - def forward(self, x): - # x is assumed to be in range [-1,1] - return self.model.encode_image(self.preprocess(x)) + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] -if __name__ == "__main__": - from ldm.util import count_params - model = FrozenCLIPEmbedder() - count_params(model, verbose=True) \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/flash_attention.py b/examples/images/diffusion/ldm/modules/flash_attention.py deleted file mode 100644 index 2a7a73879..000000000 --- a/examples/images/diffusion/ldm/modules/flash_attention.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -Fused Attention -=============== -This is a Triton implementation of the Flash Attention algorithm -(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton) -""" - -import torch -try: - from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func -except ImportError: - raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention') - - - -def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len): - """ - Arguments: - qkv: (batch*seq, 3, nheads, headdim) - batch_size: int. - seq_len: int. - sm_scale: float. The scaling of QK^T before applying softmax. - Return: - out: (total, nheads, headdim). - """ - max_s = seq_len - cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, - device=qkv.device) - out = flash_attn_unpadded_qkvpacked_func( - qkv, cu_seqlens, max_s, 0.0, - softmax_scale=sm_scale, causal=False - ) - return out - - -def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen): - """ - Arguments: - q: (batch*seq, nheads, headdim) - kv: (batch*seq, 2, nheads, headdim) - batch_size: int. - seq_len: int. - sm_scale: float. The scaling of QK^T before applying softmax. - Return: - out: (total, nheads, headdim). - """ - cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) - cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device) - out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale) - return out diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py index 9e1f82399..808c7f882 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -25,7 +25,6 @@ import ldm.modules.image_degradation.utils_image as util # -------------------------------------------- """ - def modcrop_np(img, sf): ''' Args: @@ -254,7 +253,7 @@ def srmd_degradation(x, k, sf=3): year={2018} } ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x @@ -277,7 +276,7 @@ def dpsr_degradation(x, k, sf=3): } ''' x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') return x @@ -290,7 +289,7 @@ def classical_degradation(x, k, sf=3): Return: downsampled LR image ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] @@ -335,7 +334,7 @@ def add_blur(img, sf=4): k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') return img @@ -497,7 +496,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -531,7 +530,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # todo no isp_model? -def degradation_bsrgan_variant(image, sf=4, isp_model=None): +def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): """ This is the degradation model of BSRGAN from the paper "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" @@ -589,7 +588,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -617,6 +616,8 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) + if up: + image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then example = {"image": image} return example diff --git a/examples/images/diffusion/ldm/modules/losses/__init__.py b/examples/images/diffusion/ldm/modules/losses/__init__.py deleted file mode 100644 index 876d7c5bd..000000000 --- a/examples/images/diffusion/ldm/modules/losses/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/losses/contperceptual.py b/examples/images/diffusion/ldm/modules/losses/contperceptual.py deleted file mode 100644 index 672c1e32a..000000000 --- a/examples/images/diffusion/ldm/modules/losses/contperceptual.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.nn as nn - -from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? - - -class LPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_loss="hinge"): - - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.kl_weight = kl_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - # output log variance - self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm - ).apply(weights_init) - self.discriminator_iter_start = disc_start - self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward(self, inputs, reconstructions, posteriors, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", - weights=None): - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - - nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if weights is not None: - weighted_nll_loss = weights*nll_loss - weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - kl_loss = posteriors.kl() - kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - if self.disc_factor > 0.0: - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - else: - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss - - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log - diff --git a/examples/images/diffusion/ldm/modules/losses/vqperceptual.py b/examples/images/diffusion/ldm/modules/losses/vqperceptual.py deleted file mode 100644 index f69981769..000000000 --- a/examples/images/diffusion/ldm/modules/losses/vqperceptual.py +++ /dev/null @@ -1,167 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -from einops import repeat - -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init -from taming.modules.losses.lpips import LPIPS -from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss - - -def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): - assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] - loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) - loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) - loss_real = (weights * loss_real).sum() / weights.sum() - loss_fake = (weights * loss_fake).sum() / weights.sum() - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - -def adopt_weight(weight, global_step, threshold=0, value=0.): - if global_step < threshold: - weight = value - return weight - - -def measure_perplexity(predicted_indices, n_embed): - # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py - # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally - encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) - avg_probs = encodings.mean(0) - perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() - cluster_use = torch.sum(avg_probs > 0) - return perplexity, cluster_use - -def l1(x, y): - return torch.abs(x-y) - - -def l2(x, y): - return torch.pow((x-y), 2) - - -class VQLPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", - pixel_loss="l1"): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - assert perceptual_loss in ["lpips", "clips", "dists"] - assert pixel_loss in ["l1", "l2"] - self.codebook_weight = codebook_weight - self.pixel_weight = pixelloss_weight - if perceptual_loss == "lpips": - print(f"{self.__class__.__name__}: Running with LPIPS.") - self.perceptual_loss = LPIPS().eval() - else: - raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") - self.perceptual_weight = perceptual_weight - - if pixel_loss == "l1": - self.pixel_loss = l1 - else: - self.pixel_loss = l2 - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf - ).apply(weights_init) - self.discriminator_iter_start = disc_start - if disc_loss == "hinge": - self.disc_loss = hinge_d_loss - elif disc_loss == "vanilla": - self.disc_loss = vanilla_d_loss - else: - raise ValueError(f"Unknown GAN loss '{disc_loss}'.") - print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - self.n_classes = n_classes - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", predicted_indices=None): - if not exists(codebook_loss): - codebook_loss = torch.tensor([0.]).to(inputs.device) - #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - else: - p_loss = torch.tensor([0.0]) - - nll_loss = rec_loss - #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - nll_loss = torch.mean(nll_loss) - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() - - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - if predicted_indices is not None: - assert self.n_classes is not None - with torch.no_grad(): - perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) - log[f"{split}/perplexity"] = perplexity - log[f"{split}/cluster_usage"] = cluster_usage - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log diff --git a/examples/images/diffusion/ldm/modules/midas/__init__.py b/examples/images/diffusion/ldm/modules/midas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/modules/midas/api.py b/examples/images/diffusion/ldm/modules/midas/api.py new file mode 100644 index 000000000..b58ebbffd --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/api.py @@ -0,0 +1,170 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from ldm.modules.midas.midas.dpt_depth import DPTDepthModel +from ldm.modules.midas.midas.midas_net import MidasNet +from ldm.modules.midas.midas.midas_net_custom import MidasNet_small +from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet + + +ISL_PATHS = { + "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array + # NOTE: we expect that the correct transform has been called during dataloading. + with torch.no_grad(): + prediction = self.model(x) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=x.shape[2:], + mode="bicubic", + align_corners=False, + ) + assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) + return prediction + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/__init__.py b/examples/images/diffusion/ldm/modules/midas/midas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py new file mode 100644 index 000000000..5cf430239 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py new file mode 100644 index 000000000..2145d18fa --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py new file mode 100644 index 000000000..4e9aab5d2 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py new file mode 100644 index 000000000..8a9549778 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py new file mode 100644 index 000000000..50e4acb5e --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py new file mode 100644 index 000000000..350cbc116 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/examples/images/diffusion/ldm/modules/midas/midas/vit.py b/examples/images/diffusion/ldm/modules/midas/midas/vit.py new file mode 100644 index 000000000..ea46b1be8 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py new file mode 100644 index 000000000..9a9d3b5b6 --- /dev/null +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/examples/images/diffusion/ldm/modules/x_transformer.py b/examples/images/diffusion/ldm/modules/x_transformer.py deleted file mode 100644 index 5fc15bf9c..000000000 --- a/examples/images/diffusion/ldm/modules/x_transformer.py +++ /dev/null @@ -1,641 +0,0 @@ -"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" -import torch -from torch import nn, einsum -import torch.nn.functional as F -from functools import partial -from inspect import isfunction -from collections import namedtuple -from einops import rearrange, repeat, reduce - -# constants - -DEFAULT_DIM_HEAD = 64 - -Intermediates = namedtuple('Intermediates', [ - 'pre_softmax_attn', - 'post_softmax_attn' -]) - -LayerIntermediates = namedtuple('Intermediates', [ - 'hiddens', - 'attn_intermediates' -]) - - -class AbsolutePositionalEmbedding(nn.Module): - def __init__(self, dim, max_seq_len): - super().__init__() - self.emb = nn.Embedding(max_seq_len, dim) - self.init_() - - def init_(self): - nn.init.normal_(self.emb.weight, std=0.02) - - def forward(self, x): - n = torch.arange(x.shape[1], device=x.device) - return self.emb(n)[None, :, :] - - -class FixedPositionalEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) - - def forward(self, x, seq_dim=1, offset=0): - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset - sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) - emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) - return emb[None, :, :] - - -# helpers - -def exists(val): - return val is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def always(val): - def inner(*args, **kwargs): - return val - return inner - - -def not_equals(val): - def inner(x): - return x != val - return inner - - -def equals(val): - def inner(x): - return x == val - return inner - - -def max_neg_value(tensor): - return -torch.finfo(tensor.dtype).max - - -# keyword argument helpers - -def pick_and_pop(keys, d): - values = list(map(lambda key: d.pop(key), keys)) - return dict(zip(keys, values)) - - -def group_dict_by_key(cond, d): - return_val = [dict(), dict()] - for key in d.keys(): - match = bool(cond(key)) - ind = int(not match) - return_val[ind][key] = d[key] - return (*return_val,) - - -def string_begins_with(prefix, str): - return str.startswith(prefix) - - -def group_by_key_prefix(prefix, d): - return group_dict_by_key(partial(string_begins_with, prefix), d) - - -def groupby_prefix_and_trim(prefix, d): - kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) - return kwargs_without_prefix, kwargs - - -# classes -class Scale(nn.Module): - def __init__(self, value, fn): - super().__init__() - self.value = value - self.fn = fn - - def forward(self, x, **kwargs): - x, *rest = self.fn(x, **kwargs) - return (x * self.value, *rest) - - -class Rezero(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - self.g = nn.Parameter(torch.zeros(1)) - - def forward(self, x, **kwargs): - x, *rest = self.fn(x, **kwargs) - return (x * self.g, *rest) - - -class ScaleNorm(nn.Module): - def __init__(self, dim, eps=1e-5): - super().__init__() - self.scale = dim ** -0.5 - self.eps = eps - self.g = nn.Parameter(torch.ones(1)) - - def forward(self, x): - norm = torch.norm(x, dim=-1, keepdim=True) * self.scale - return x / norm.clamp(min=self.eps) * self.g - - -class RMSNorm(nn.Module): - def __init__(self, dim, eps=1e-8): - super().__init__() - self.scale = dim ** -0.5 - self.eps = eps - self.g = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - norm = torch.norm(x, dim=-1, keepdim=True) * self.scale - return x / norm.clamp(min=self.eps) * self.g - - -class Residual(nn.Module): - def forward(self, x, residual): - return x + residual - - -class GRUGating(nn.Module): - def __init__(self, dim): - super().__init__() - self.gru = nn.GRUCell(dim, dim) - - def forward(self, x, residual): - gated_output = self.gru( - rearrange(x, 'b n d -> (b n) d'), - rearrange(residual, 'b n d -> (b n) d') - ) - - return gated_output.reshape_as(x) - - -# feedforward - -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -# attention. -class Attention(nn.Module): - def __init__( - self, - dim, - dim_head=DEFAULT_DIM_HEAD, - heads=8, - causal=False, - mask=None, - talking_heads=False, - sparse_topk=None, - use_entmax15=False, - num_mem_kv=0, - dropout=0., - on_attn=False - ): - super().__init__() - if use_entmax15: - raise NotImplementedError("Check out entmax activation instead of softmax activation!") - self.scale = dim_head ** -0.5 - self.heads = heads - self.causal = causal - self.mask = mask - - inner_dim = dim_head * heads - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_k = nn.Linear(dim, inner_dim, bias=False) - self.to_v = nn.Linear(dim, inner_dim, bias=False) - self.dropout = nn.Dropout(dropout) - - # talking heads - self.talking_heads = talking_heads - if talking_heads: - self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) - self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) - - # explicit topk sparse attention - self.sparse_topk = sparse_topk - - # entmax - #self.attn_fn = entmax15 if use_entmax15 else F.softmax - self.attn_fn = F.softmax - - # add memory key / values - self.num_mem_kv = num_mem_kv - if num_mem_kv > 0: - self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) - self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) - - # attention on attention - self.attn_on_attn = on_attn - self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) - - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - rel_pos=None, - sinusoidal_emb=None, - prev_attn=None, - mem=None - ): - b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device - kv_input = default(context, x) - - q_input = x - k_input = kv_input - v_input = kv_input - - if exists(mem): - k_input = torch.cat((mem, k_input), dim=-2) - v_input = torch.cat((mem, v_input), dim=-2) - - if exists(sinusoidal_emb): - # in shortformer, the query would start at a position offset depending on the past cached memory - offset = k_input.shape[-2] - q_input.shape[-2] - q_input = q_input + sinusoidal_emb(q_input, offset=offset) - k_input = k_input + sinusoidal_emb(k_input) - - q = self.to_q(q_input) - k = self.to_k(k_input) - v = self.to_v(v_input) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - input_mask = None - if any(map(exists, (mask, context_mask))): - q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) - k_mask = q_mask if not exists(context) else context_mask - k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) - q_mask = rearrange(q_mask, 'b i -> b () i ()') - k_mask = rearrange(k_mask, 'b j -> b () () j') - input_mask = q_mask * k_mask - - if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) - k = torch.cat((mem_k, k), dim=-2) - v = torch.cat((mem_v, v), dim=-2) - if exists(input_mask): - input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) - - dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale - mask_value = max_neg_value(dots) - - if exists(prev_attn): - dots = dots + prev_attn - - pre_softmax_attn = dots - - if talking_heads: - dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() - - if exists(rel_pos): - dots = rel_pos(dots) - - if exists(input_mask): - dots.masked_fill_(~input_mask, mask_value) - del input_mask - - if self.causal: - i, j = dots.shape[-2:] - r = torch.arange(i, device=device) - mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') - mask = F.pad(mask, (j - i, 0), value=False) - dots.masked_fill_(mask, mask_value) - del mask - - if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: - top, _ = dots.topk(self.sparse_topk, dim=-1) - vk = top[..., -1].unsqueeze(-1).expand_as(dots) - mask = dots < vk - dots.masked_fill_(mask, mask_value) - del mask - - attn = self.attn_fn(dots, dim=-1) - post_softmax_attn = attn - - attn = self.dropout(attn) - - if talking_heads: - attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() - - out = einsum('b h i j, b h j d -> b h i d', attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - - intermediates = Intermediates( - pre_softmax_attn=pre_softmax_attn, - post_softmax_attn=post_softmax_attn - ) - - return self.to_out(out), intermediates - - -class AttentionLayers(nn.Module): - def __init__( - self, - dim, - depth, - heads=8, - causal=False, - cross_attend=False, - only_cross=False, - use_scalenorm=False, - use_rmsnorm=False, - use_rezero=False, - rel_pos_num_buckets=32, - rel_pos_max_distance=128, - position_infused_attn=False, - custom_layers=None, - sandwich_coef=None, - par_ratio=None, - residual_attn=False, - cross_residual_attn=False, - macaron=False, - pre_norm=True, - gate_residual=False, - **kwargs - ): - super().__init__() - ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) - attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) - - dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) - - self.dim = dim - self.depth = depth - self.layers = nn.ModuleList([]) - - self.has_pos_emb = position_infused_attn - self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None - self.rotary_pos_emb = always(None) - - assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' - self.rel_pos = None - - self.pre_norm = pre_norm - - self.residual_attn = residual_attn - self.cross_residual_attn = cross_residual_attn - - norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm - norm_class = RMSNorm if use_rmsnorm else norm_class - norm_fn = partial(norm_class, dim) - - norm_fn = nn.Identity if use_rezero else norm_fn - branch_fn = Rezero if use_rezero else None - - if cross_attend and not only_cross: - default_block = ('a', 'c', 'f') - elif cross_attend and only_cross: - default_block = ('c', 'f') - else: - default_block = ('a', 'f') - - if macaron: - default_block = ('f',) + default_block - - if exists(custom_layers): - layer_types = custom_layers - elif exists(par_ratio): - par_depth = depth * len(default_block) - assert 1 < par_ratio <= par_depth, 'par ratio out of range' - default_block = tuple(filter(not_equals('f'), default_block)) - par_attn = par_depth // par_ratio - depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper - par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * (par_width - len(default_block)) - par_head = par_block * par_attn - layer_types = par_head + ('f',) * (par_depth - len(par_head)) - elif exists(sandwich_coef): - assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef - else: - layer_types = default_block * depth - - self.layer_types = layer_types - self.num_attn_layers = len(list(filter(equals('a'), layer_types))) - - for layer_type in self.layer_types: - if layer_type == 'a': - layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) - elif layer_type == 'c': - layer = Attention(dim, heads=heads, **attn_kwargs) - elif layer_type == 'f': - layer = FeedForward(dim, **ff_kwargs) - layer = layer if not macaron else Scale(0.5, layer) - else: - raise Exception(f'invalid layer type {layer_type}') - - if isinstance(layer, Attention) and exists(branch_fn): - layer = branch_fn(layer) - - if gate_residual: - residual_fn = GRUGating(dim) - else: - residual_fn = Residual() - - self.layers.append(nn.ModuleList([ - norm_fn(), - layer, - residual_fn - ])) - - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - mems=None, - return_hiddens=False - ): - hiddens = [] - intermediates = [] - prev_attn = None - prev_cross_attn = None - - mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers - - for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): - is_last = ind == (len(self.layers) - 1) - - if layer_type == 'a': - hiddens.append(x) - layer_mem = mems.pop(0) - - residual = x - - if self.pre_norm: - x = norm(x) - - if layer_type == 'a': - out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, - prev_attn=prev_attn, mem=layer_mem) - elif layer_type == 'c': - out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) - elif layer_type == 'f': - out = block(x) - - x = residual_fn(out, residual) - - if layer_type in ('a', 'c'): - intermediates.append(inter) - - if layer_type == 'a' and self.residual_attn: - prev_attn = inter.pre_softmax_attn - elif layer_type == 'c' and self.cross_residual_attn: - prev_cross_attn = inter.pre_softmax_attn - - if not self.pre_norm and not is_last: - x = norm(x) - - if return_hiddens: - intermediates = LayerIntermediates( - hiddens=hiddens, - attn_intermediates=intermediates - ) - - return x, intermediates - - return x - - -class Encoder(AttentionLayers): - def __init__(self, **kwargs): - assert 'causal' not in kwargs, 'cannot set causality on encoder' - super().__init__(causal=False, **kwargs) - - - -class TransformerWrapper(nn.Module): - def __init__( - self, - *, - num_tokens, - max_seq_len, - attn_layers, - emb_dim=None, - max_mem_len=0., - emb_dropout=0., - num_memory_tokens=None, - tie_embedding=False, - use_pos_emb=True - ): - super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' - - dim = attn_layers.dim - emb_dim = default(emb_dim, dim) - - self.max_seq_len = max_seq_len - self.max_mem_len = max_mem_len - self.num_tokens = num_tokens - - self.token_emb = nn.Embedding(num_tokens, emb_dim) - self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) - self.emb_dropout = nn.Dropout(emb_dropout) - - self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() - self.attn_layers = attn_layers - self.norm = nn.LayerNorm(dim) - - self.init_() - - self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() - - # memory tokens (like [cls]) from Memory Transformers paper - num_memory_tokens = default(num_memory_tokens, 0) - self.num_memory_tokens = num_memory_tokens - if num_memory_tokens > 0: - self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) - - # let funnel encoder know number of memory tokens, if specified - if hasattr(attn_layers, 'num_memory_tokens'): - attn_layers.num_memory_tokens = num_memory_tokens - - def init_(self): - nn.init.normal_(self.token_emb.weight, std=0.02) - - def forward( - self, - x, - return_embeddings=False, - mask=None, - return_mems=False, - return_attn=False, - mems=None, - **kwargs - ): - b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens - x = self.token_emb(x) - x += self.pos_emb(x) - x = self.emb_dropout(x) - - x = self.project_emb(x) - - if num_mem > 0: - mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) - x = torch.cat((mem, x), dim=1) - - # auto-handle masking after appending memory tokens - if exists(mask): - mask = F.pad(mask, (num_mem, 0), value=True) - - x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) - x = self.norm(x) - - mem, x = x[:, :num_mem], x[:, num_mem:] - - out = self.to_logits(x) if not return_embeddings else x - - if return_mems: - hiddens = intermediates.hiddens - new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens - new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) - return out, new_mems - - if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) - return out, attn_maps - - return out - diff --git a/examples/images/diffusion/ldm/util.py b/examples/images/diffusion/ldm/util.py index 8ba38853e..8c09ca1c7 100644 --- a/examples/images/diffusion/ldm/util.py +++ b/examples/images/diffusion/ldm/util.py @@ -1,14 +1,8 @@ import importlib import torch +from torch import optim import numpy as np -from collections import abc -from einops import rearrange -from functools import partial - -import multiprocessing as mp -from threading import Thread -from queue import Queue from inspect import isfunction from PIL import Image, ImageDraw, ImageFont @@ -45,7 +39,7 @@ def ismap(x): def isimage(x): - if not isinstance(x, torch.Tensor): + if not isinstance(x,torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) @@ -71,7 +65,7 @@ def mean_flat(tensor): def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") return total_params @@ -93,111 +87,111 @@ def get_obj_from_str(string, reload=False): return getattr(importlib.import_module(module, package=None), cls) -def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): - # create dummy dataset instance +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) - # run prefetching - if idx_to_fn: - res = func(data, worker_id=idx) - else: - res = func(data) - Q.put([idx, res]) - Q.put("Done") + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() -def parallel_data_prefetch( - func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False -): - # if target_data_type not in ["ndarray", "list"]: - # raise ValueError( - # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." - # ) - if isinstance(data, np.ndarray) and target_data_type == "list": - raise ValueError("list expected but function got ndarray.") - elif isinstance(data, abc.Iterable): - if isinstance(data, dict): - print( - f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' - ) - data = list(data.values()) - if target_data_type == "ndarray": - data = np.asarray(data) - else: - data = list(data) - else: - raise TypeError( - f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." - ) + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] - if cpu_intensive: - Q = mp.Queue(1000) - proc = mp.Process - else: - Q = Queue(1000) - proc = Thread - # spawn processes - if target_data_type == "ndarray": - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate(np.array_split(data, n_proc)) - ] - else: - step = ( - int(len(data) / n_proc + 1) - if len(data) % n_proc != 0 - else int(len(data) / n_proc) - ) - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate( - [data[i: i + step] for i in range(0, len(data), step)] - ) - ] - processes = [] - for i in range(n_proc): - p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) - processes += [p] + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) - # start processes - print(f"Start prefetching...") - import time + state = self.state[p] - start = time.time() - gather_res = [[] for _ in range(n_proc)] - try: - for p in processes: - p.start() + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() - k = 0 - while k < n_proc: - # get result - res = Q.get() - if res == "Done": - k += 1 - else: - gather_res[res[0]] = res[1] + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) - except Exception as e: - print("Exception: ", e) - for p in processes: - p.terminate() + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) - raise e - finally: - for p in processes: - p.join() - print(f"Prefetching complete. [{time.time() - start} sec.]") + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) - if target_data_type == 'ndarray': - if not isinstance(gather_res[0], np.ndarray): - return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) - # order outputs - return np.concatenate(gather_res, axis=0) - elif target_data_type == 'list': - out = [] - for r in gather_res: - out.extend(r) - return out - else: - return gather_res + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py index f968227e5..87d495123 100644 --- a/examples/images/diffusion/main.py +++ b/examples/images/diffusion/main.py @@ -1,54 +1,48 @@ -import argparse, os, sys, datetime, glob, importlib, csv -import numpy as np +import argparse +import csv +import datetime +import glob +import importlib +import os +import sys import time + +import numpy as np import torch import torchvision -import lightning.pytorch as pl -from packaging import version -from omegaconf import OmegaConf -from torch.utils.data import random_split, DataLoader, Dataset, Subset +try: + import lightning.pytorch as pl +except: + import pytorch_lightning as pl + from functools import partial + +from omegaconf import OmegaConf +from packaging import version from PIL import Image -# from lightning.pytorch.strategies.colossalai import ColossalAIStrategy -# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam from prefetch_generator import BackgroundGenerator +from torch.utils.data import DataLoader, Dataset, Subset, random_split -from lightning.pytorch import seed_everything -from lightning.pytorch.trainer import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor -from lightning.pytorch.utilities.rank_zero import rank_zero_only -from lightning.pytorch.utilities import rank_zero_info -from diffusers.models.unet_2d import UNet2DModel - -from clip.model import Bottleneck -from transformers.models.clip.modeling_clip import CLIPTextTransformer +try: + from lightning.pytorch import seed_everything + from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint + from lightning.pytorch.trainer import Trainer + from lightning.pytorch.utilities import rank_zero_info, rank_zero_only + LIGHTNING_PACK_NAME = "lightning.pytorch." +except: + from pytorch_lightning import seed_everything + from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint + from pytorch_lightning.trainer import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + LIGHTNING_PACK_NAME = "pytorch_lightning." from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.util import instantiate_from_config -import clip -from einops import rearrange, repeat -from transformers import CLIPTokenizer, CLIPTextModel -import kornia -from ldm.modules.x_transformer import * -from ldm.modules.encoders.modules import * -from taming.modules.diffusionmodules.model import ResnetBlock -from taming.modules.transformer.mingpt import * -from taming.modules.transformer.permuter import * +# from ldm.modules.attention import enable_flash_attentions -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import AutoencoderKL -from ldm.models.autoencoder import * -from ldm.models.diffusion.ddim import * -from ldm.modules.diffusionmodules.openaimodel import * -from ldm.modules.diffusionmodules.model import * -from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module -from ldm.modules.attention import enable_flash_attention - class DataLoaderX(DataLoader): def __iter__(self): @@ -56,6 +50,7 @@ class DataLoaderX(DataLoader): def get_parser(**parser_kwargs): + def str2bool(v): if isinstance(v, bool): return v @@ -91,7 +86,7 @@ def get_parser(**parser_kwargs): nargs="*", metavar="base_config.yaml", help="paths to base configs. Loaded from left-to-right. " - "Parameters can be overwritten or added with command-line options of the form `--key value`.", + "Parameters can be overwritten or added with command-line options of the form `--key value`.", default=list(), ) parser.add_argument( @@ -111,11 +106,7 @@ def get_parser(**parser_kwargs): nargs="?", help="disable test", ) - parser.add_argument( - "-p", - "--project", - help="name of new or path to existing project" - ) + parser.add_argument("-p", "--project", help="name of new or path to existing project") parser.add_argument( "-d", "--debug", @@ -210,8 +201,17 @@ def worker_init_fn(_): class DataModuleFromConfig(pl.LightningDataModule): - def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, - wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, + + def __init__(self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, shuffle_val_dataloader=False): super().__init__() self.batch_size = batch_size @@ -237,9 +237,7 @@ class DataModuleFromConfig(pl.LightningDataModule): instantiate_from_config(data_cfg) def setup(self, stage=None): - self.datasets = dict( - (k, instantiate_from_config(self.dataset_configs[k])) - for k in self.dataset_configs) + self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) @@ -250,9 +248,11 @@ class DataModuleFromConfig(pl.LightningDataModule): init_fn = worker_init_fn else: init_fn = None - return DataLoaderX(self.datasets["train"], batch_size=self.batch_size, - num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, - worker_init_fn=init_fn) + return DataLoaderX(self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn) def _val_dataloader(self, shuffle=False): if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: @@ -260,10 +260,10 @@ class DataModuleFromConfig(pl.LightningDataModule): else: init_fn = None return DataLoaderX(self.datasets["validation"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn, - shuffle=shuffle) + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) def _test_dataloader(self, shuffle=False): is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) @@ -275,19 +275,25 @@ class DataModuleFromConfig(pl.LightningDataModule): # do not shuffle dataloader for iterable dataset shuffle = shuffle and (not is_iterable_dataset) - return DataLoaderX(self.datasets["test"], batch_size=self.batch_size, - num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) + return DataLoaderX(self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle) def _predict_dataloader(self, shuffle=False): if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None - return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size, - num_workers=self.num_workers, worker_init_fn=init_fn) + return DataLoaderX(self.datasets["predict"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn) class SetupCallback(Callback): + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): super().__init__() self.resume = resume @@ -317,8 +323,7 @@ class SetupCallback(Callback): os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) print("Project config") print(OmegaConf.to_yaml(self.config)) - OmegaConf.save(self.config, - os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) + OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) print("Lightning config") print(OmegaConf.to_yaml(self.lightning_config)) @@ -338,8 +343,16 @@ class SetupCallback(Callback): class ImageLogger(Callback): - def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, - rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, + + def __init__(self, + batch_frequency, + max_images, + clamp=True, + increase_log_steps=True, + rescale=True, + disabled=False, + log_on_batch_idx=False, + log_first_step=False, log_images_kwargs=None): super().__init__() self.rescale = rescale @@ -348,7 +361,7 @@ class ImageLogger(Callback): self.logger_log_images = { pl.loggers.CSVLogger: self._testtube, } - self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] + self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] if not increase_log_steps: self.log_steps = [self.batch_freq] self.clamp = clamp @@ -361,39 +374,30 @@ class ImageLogger(Callback): def _testtube(self, pl_module, images, batch_idx, split): for k in images: grid = torchvision.utils.make_grid(images[k]) - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" - pl_module.logger.experiment.add_image( - tag, grid, - global_step=pl_module.global_step) + pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step) @rank_zero_only - def log_local(self, save_dir, split, images, - global_step, current_epoch, batch_idx): + def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): root = os.path.join(save_dir, "images", split) for k in images: grid = torchvision.utils.make_grid(images[k], nrow=4) if self.rescale: - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) - filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( - k, - global_step, - current_epoch, - batch_idx) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) path = os.path.join(root, filename) os.makedirs(os.path.split(path)[0], exist_ok=True) Image.fromarray(grid).save(path) def log_img(self, pl_module, batch, batch_idx, split="train"): check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step - if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 - hasattr(pl_module, "log_images") and - callable(pl_module.log_images) and - self.max_images > 0): + if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 + hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): logger = type(pl_module.logger) is_train = pl_module.training @@ -411,8 +415,8 @@ class ImageLogger(Callback): if self.clamp: images[k] = torch.clamp(images[k], -1., 1.) - self.log_local(pl_module.logger.save_dir, split, images, - pl_module.global_step, pl_module.current_epoch, batch_idx) + self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, + batch_idx) logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) logger_log_images(pl_module, images, pl_module.global_step, split) @@ -421,8 +425,8 @@ class ImageLogger(Callback): pl_module.train() def check_frequency(self, check_idx): - if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( - check_idx > 0 or self.log_first_step): + if ((check_idx % self.batch_freq) == 0 or + (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step): try: self.log_steps.pop(0) except IndexError as e: @@ -461,7 +465,7 @@ class CUDACallback(Callback): def on_train_epoch_end(self, trainer, pl_module): torch.cuda.synchronize(trainer.strategy.root_device.index) - max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20 + max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20 epoch_time = time.time() - self.start_time try: @@ -528,13 +532,9 @@ if __name__ == "__main__": opt, unknown = parser.parse_known_args() if opt.name and opt.resume: - raise ValueError( - "-n/--name and -r/--resume cannot be specified both." - "If you want to resume training in a new log folder, " - "use -n/--name in combination with --resume_from_checkpoint" - ) - if opt.flash: - enable_flash_attention() + raise ValueError("-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint") if opt.resume: if not os.path.exists(opt.resume): raise ValueError("Cannot find {}".format(opt.resume)) @@ -578,7 +578,7 @@ if __name__ == "__main__": lightning_config = config.pop("lightning", OmegaConf.create()) # merge trainer cli with config trainer_config = lightning_config.get("trainer", OmegaConf.create()) - + for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) @@ -601,7 +601,7 @@ if __name__ == "__main__": else: config.model["params"].update({"use_fp16": False}) print("Using FP16 = {}".format(config.model["params"]["use_fp16"])) - + model = instantiate_from_config(config.model) # trainer and callbacks trainer_kwargs = dict() @@ -610,7 +610,7 @@ if __name__ == "__main__": # default logger configs default_logger_cfgs = { "wandb": { - "target": "lightning.pytorch.loggers.WandbLogger", + "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger", "params": { "name": nowname, "save_dir": logdir, @@ -618,9 +618,9 @@ if __name__ == "__main__": "id": nowname, } }, - "tensorboard":{ - "target": "lightning.pytorch.loggers.TensorBoardLogger", - "params":{ + "tensorboard": { + "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger", + "params": { "save_dir": logdir, "name": "diff_tb", "log_graph": True @@ -640,9 +640,10 @@ if __name__ == "__main__": if "strategy" in trainer_config: strategy_cfg = trainer_config["strategy"] print("Using strategy: {}".format(strategy_cfg["target"])) + strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"] else: strategy_cfg = { - "target": "lightning.pytorch.strategies.DDPStrategy", + "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy", "params": { "find_unused_parameters": False } @@ -654,7 +655,7 @@ if __name__ == "__main__": # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { - "target": "lightning.pytorch.callbacks.ModelCheckpoint", + "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint", "params": { "dirpath": ckptdir, "filename": "{epoch:06}", @@ -670,7 +671,7 @@ if __name__ == "__main__": if "modelcheckpoint" in lightning_config: modelckpt_cfg = lightning_config.modelcheckpoint else: - modelckpt_cfg = OmegaConf.create() + modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") if version.parse(pl.__version__) < version.parse('1.4.0'): @@ -702,7 +703,7 @@ if __name__ == "__main__": "target": "main.LearningRateMonitor", "params": { "logging_interval": "step", - # "log_momentum": True + # "log_momentum": True } }, "cuda_callback": { @@ -721,17 +722,17 @@ if __name__ == "__main__": print( 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') default_metrics_over_trainsteps_ckpt_dict = { - 'metrics_over_trainsteps_checkpoint': - {"target": 'lightning.pytorch.callbacks.ModelCheckpoint', - 'params': { - "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), - "filename": "{epoch:06}-{step:09}", - "verbose": True, - 'save_top_k': -1, - 'every_n_train_steps': 10000, - 'save_weights_only': True - } - } + 'metrics_over_trainsteps_checkpoint': { + "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint', + 'params': { + "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + 'save_top_k': -1, + 'every_n_train_steps': 10000, + 'save_weights_only': True + } + } } default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) @@ -744,7 +745,7 @@ if __name__ == "__main__": trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) - trainer.logdir = logdir ### + trainer.logdir = logdir ### # data data = instantiate_from_config(config.data) @@ -772,14 +773,13 @@ if __name__ == "__main__": if opt.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr print( - "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( - model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" + .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) else: model.learning_rate = base_lr print("++++ NOT USING LR SCALING ++++") print(f"Setting learning rate to {model.learning_rate:.2e}") - # allow checkpointing via USR1 def melk(*args, **kwargs): # run all checkpoint hooks @@ -788,13 +788,11 @@ if __name__ == "__main__": ckpt_path = os.path.join(ckptdir, "last.ckpt") trainer.save_checkpoint(ckpt_path) - def divein(*args, **kwargs): if trainer.global_rank == 0: - import pudb; + import pudb pudb.set_trace() - import signal signal.signal(signal.SIGUSR1, melk) @@ -803,8 +801,6 @@ if __name__ == "__main__": # run if opt.train: try: - for name, m in model.named_parameters(): - print(name) trainer.fit(model, data) except Exception: melk() diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt index 01d6560ca..5a83b2aa3 100644 --- a/examples/images/diffusion/requirements.txt +++ b/examples/images/diffusion/requirements.txt @@ -1,22 +1,17 @@ -albumentations==0.4.3 -diffusers +albumentations==1.3.0 +opencv-python pudb==2019.2 -datasets -invisible-watermark +prefetch_generator imageio==2.9.0 imageio-ffmpeg==0.4.2 +torchmetrics==0.6 omegaconf==2.1.1 -multiprocess -lightning==1.8.1 test-tube>=0.7.5 streamlit>=0.73.1 einops==0.3.0 -torch-fidelity==0.3.0 transformers==4.19.2 -torchmetrics==0.6.0 -kornia==0.6 -opencv-python==4.6.0.66 -prefetch_generator --e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers --e git+https://github.com/openai/CLIP.git@main#egg=clip +webdataset==0.2.5 +open-clip-torch==2.7.0 +gradio==3.11 +datasets -e . diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py index 9fc46deb8..e8ccfa259 100644 --- a/examples/images/diffusion/scripts/img2img.py +++ b/examples/images/diffusion/scripts/img2img.py @@ -1,6 +1,6 @@ """make variations of input image""" -import argparse, os, sys, glob +import argparse, os import PIL import torch import numpy as np @@ -12,12 +12,16 @@ from einops import rearrange, repeat from torchvision.utils import make_grid from torch import autocast from contextlib import nullcontext -import time -from lightning.pytorch import seed_everything +try: + from lightning.pytorch import seed_everything +except: + from pytorch_lightning import seed_everything +from imwatermark import WatermarkEncoder + +from scripts.txt2img import put_watermark from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler def chunk(it, size): @@ -49,12 +53,12 @@ def load_img(path): image = Image.open(path).convert("RGB") w, h = image.size print(f"loaded input image of size ({w}, {h}) from {path}") - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) - return 2.*image - 1. + return 2. * image - 1. def main(): @@ -83,18 +87,6 @@ def main(): default="outputs/img2img-samples" ) - parser.add_argument( - "--skip_grid", - action='store_true', - help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", - ) - - parser.add_argument( - "--skip_save", - action='store_true', - help="do not save indiviual samples. For speed measurements.", - ) - parser.add_argument( "--ddim_steps", type=int, @@ -102,11 +94,6 @@ def main(): help="number of ddim sampling steps", ) - parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", - ) parser.add_argument( "--fixed_code", action='store_true', @@ -125,6 +112,7 @@ def main(): default=1, help="sample this often", ) + parser.add_argument( "--C", type=int, @@ -137,31 +125,35 @@ def main(): default=8, help="downsampling factor, most often 8 or 16", ) + parser.add_argument( "--n_samples", type=int, default=2, help="how many samples to produce for each given prompt. A.k.a batch size", ) + parser.add_argument( "--n_rows", type=int, default=0, help="rows in the grid (default: n_samples)", ) + parser.add_argument( "--scale", type=float, - default=5.0, + default=9.0, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", ) parser.add_argument( "--strength", type=float, - default=0.75, + default=0.8, help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", ) + parser.add_argument( "--from-file", type=str, @@ -170,13 +162,12 @@ def main(): parser.add_argument( "--config", type=str, - default="configs/stable-diffusion/v1-inference.yaml", + default="configs/stable-diffusion/v2-inference.yaml", help="path to config which constructs model", ) parser.add_argument( "--ckpt", type=str, - default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model", ) parser.add_argument( @@ -202,15 +193,16 @@ def main(): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) - if opt.plms: - raise NotImplementedError("PLMS sampler not (yet) supported") - sampler = PLMSSampler(model) - else: - sampler = DDIMSampler(model) + sampler = DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) outpath = opt.outdir + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size if not opt.from_file: @@ -244,7 +236,6 @@ def main(): with torch.no_grad(): with precision_scope("cuda"): with model.ema_scope(): - tic = time.time() all_samples = list() for n in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"): @@ -256,37 +247,35 @@ def main(): c = model.get_learned_conditioning(prompts) # encode (scaled latent) - z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) # decode it samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc,) + unconditional_conditioning=uc, ) x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - if not opt.skip_save: - for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 all_samples.append(x_samples) - if not opt.skip_grid: - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) - # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 - toc = time.time() - - print(f"Your samples are ready and waiting for you here: \n{outpath} \n" - f" \nEnjoy.") + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") if __name__ == "__main__": diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py index ffebdf7ba..15993008f 100644 --- a/examples/images/diffusion/scripts/txt2img.py +++ b/examples/images/diffusion/scripts/txt2img.py @@ -1,50 +1,33 @@ -import argparse, os, sys, glob +import argparse, os import cv2 import torch import numpy as np from omegaconf import OmegaConf from PIL import Image from tqdm import tqdm, trange -from imwatermark import WatermarkEncoder from itertools import islice from einops import rearrange from torchvision.utils import make_grid -import time -from lightning.pytorch import seed_everything +try: + from lightning.pytorch import seed_everything +except: + from pytorch_lightning import seed_everything from torch import autocast -from contextlib import contextmanager, nullcontext +from contextlib import nullcontext +from imwatermark import WatermarkEncoder from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from transformers import AutoFeatureExtractor - - -# load safety model -safety_model_id = "CompVis/stable-diffusion-safety-checker" -safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) -safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) - +torch.set_grad_enabled(False) def chunk(it, size): it = iter(it) return iter(lambda: tuple(islice(it, size)), ()) -def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - def load_model_from_config(config, ckpt, verbose=False): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") @@ -65,43 +48,13 @@ def load_model_from_config(config, ckpt, verbose=False): return model -def put_watermark(img, wm_encoder=None): - if wm_encoder is not None: - img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) - img = wm_encoder.encode(img, 'dwtDct') - img = Image.fromarray(img[:, :, ::-1]) - return img - - -def load_replacement(x): - try: - hwc = x.shape - y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) - y = (np.array(y)/255.0).astype(x.dtype) - assert y.shape == x.shape - return y - except Exception: - return x - - -def check_safety(x_image): - safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") - x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) - assert x_checked_image.shape[0] == len(has_nsfw_concept) - for i in range(len(has_nsfw_concept)): - if has_nsfw_concept[i]: - x_checked_image[i] = load_replacement(x_checked_image[i]) - return x_checked_image, has_nsfw_concept - - -def main(): +def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument( "--prompt", type=str, nargs="?", - default="a painting of a virus monster playing guitar", + default="a professional photograph of an astronaut riding a triceratops", help="the prompt to render" ) parser.add_argument( @@ -112,17 +65,7 @@ def main(): default="outputs/txt2img-samples" ) parser.add_argument( - "--skip_grid", - action='store_true', - help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", - ) - parser.add_argument( - "--skip_save", - action='store_true', - help="do not save individual samples. For speed measurements.", - ) - parser.add_argument( - "--ddim_steps", + "--steps", type=int, default=50, help="number of ddim sampling steps", @@ -133,14 +76,14 @@ def main(): help="use plms sampling", ) parser.add_argument( - "--laion400m", + "--dpm", action='store_true', - help="uses the LAION400M model", + help="use DPM (2) sampler", ) parser.add_argument( "--fixed_code", action='store_true', - help="if enabled, uses the same starting code across samples ", + help="if enabled, uses the same starting code across all samples ", ) parser.add_argument( "--ddim_eta", @@ -151,7 +94,7 @@ def main(): parser.add_argument( "--n_iter", type=int, - default=2, + default=3, help="sample this often", ) parser.add_argument( @@ -176,13 +119,13 @@ def main(): "--f", type=int, default=8, - help="downsampling factor", + help="downsampling factor, most often 8 or 16", ) parser.add_argument( "--n_samples", type=int, default=3, - help="how many samples to produce for each given prompt. A.k.a. batch size", + help="how many samples to produce for each given prompt. A.k.a batch size", ) parser.add_argument( "--n_rows", @@ -193,24 +136,23 @@ def main(): parser.add_argument( "--scale", type=float, - default=7.5, + default=9.0, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", ) parser.add_argument( "--from-file", type=str, - help="if specified, load prompts from this file", + help="if specified, load prompts from this file, separated by newlines", ) parser.add_argument( "--config", type=str, - default="configs/stable-diffusion/v1-inference.yaml", + default="configs/stable-diffusion/v2-inference.yaml", help="path to config which constructs model", ) parser.add_argument( "--ckpt", type=str, - default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model", ) parser.add_argument( @@ -226,14 +168,25 @@ def main(): choices=["full", "autocast"], default="autocast" ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="repeat each prompt in file this often", + ) opt = parser.parse_args() + return opt - if opt.laion400m: - print("Falling back to LAION 400M model...") - opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" - opt.ckpt = "models/ldm/text2img-large/model.ckpt" - opt.outdir = "outputs/txt2img-samples-laion400m" +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def main(opt): seed_everything(opt.seed) config = OmegaConf.load(f"{opt.config}") @@ -244,6 +197,8 @@ def main(): if opt.plms: sampler = PLMSSampler(model) + elif opt.dpm: + sampler = DPMSolverSampler(model) else: sampler = DDIMSampler(model) @@ -251,7 +206,7 @@ def main(): outpath = opt.outdir print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") - wm = "StableDiffusionV1" + wm = "SDV2" wm_encoder = WatermarkEncoder() wm_encoder.set_watermark('bytes', wm.encode('utf-8')) @@ -266,10 +221,12 @@ def main(): print(f"reading prompts from {opt.from_file}") with open(opt.from_file, "r") as f: data = f.read().splitlines() + data = [p for p in data for i in range(opt.repeat)] data = list(chunk(data, batch_size)) sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) + sample_count = 0 base_count = len(os.listdir(sample_path)) grid_count = len(os.listdir(outpath)) - 1 @@ -277,68 +234,59 @@ def main(): if opt.fixed_code: start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) - precision_scope = autocast if opt.precision=="autocast" else nullcontext - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - tic = time.time() - all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) + precision_scope = autocast if opt.precision == "autocast" else nullcontext + with torch.no_grad(), \ + precision_scope("cuda"), \ + model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + sample_count += 1 - x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + all_samples.append(x_samples) - if not opt.skip_save: - for x_sample in x_checked_image_torch: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - img = Image.fromarray(x_sample.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) - if not opt.skip_grid: - all_samples.append(x_checked_image_torch) - - if not opt.skip_grid: - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) - - # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - img = Image.fromarray(grid.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 - - toc = time.time() + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") if __name__ == "__main__": - main() + opt = parse_args() + main(opt) diff --git a/examples/images/diffusion/train.sh b/examples/images/diffusion/train.sh index 63abcadbf..ed9ae4b75 100755 --- a/examples/images/diffusion/train.sh +++ b/examples/images/diffusion/train.sh @@ -1,4 +1,5 @@ -HF_DATASETS_OFFLINE=1 -TRANSFORMERS_OFFLINE=1 +# HF_DATASETS_OFFLINE=1 +# TRANSFORMERS_OFFLINE=1 +# DIFFUSERS_OFFLINE=1 -python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml +python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml