mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #2120 from Fazziekey/example/stablediffusion-v2
[example] support stable diffusion v2
pull/2127/head
commit
6c4c6a0409
|
@ -1,4 +1,5 @@
|
||||||
# Stable Diffusion with Colossal-AI
|
# ColoDiffusion: Stable Diffusion with Colossal-AI
|
||||||
|
|
||||||
*[Colossal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower-cost solution for pretraining and
|
*[Colossal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower-cost solution for pretraining and
|
||||||
fine-tuning of AIGC (AI-Generated Content) applications such as the [stable-diffusion](https://github.com/CompVis/stable-diffusion) model from [Stability AI](https://stability.ai/).*
|
fine-tuning of AIGC (AI-Generated Content) applications such as the [stable-diffusion](https://github.com/CompVis/stable-diffusion) model from [Stability AI](https://stability.ai/).*
|
||||||
|
|
||||||
|
@ -6,6 +7,7 @@ We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to
|
||||||
, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs.
|
, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs.
|
||||||
|
|
||||||
## Stable Diffusion
|
## Stable Diffusion
|
||||||
|
|
||||||
[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion
|
[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion
|
||||||
model.
|
model.
|
||||||
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
|
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
|
||||||
|
@ -23,6 +25,7 @@ this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on te
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
A suitable [conda](https://conda.io/) environment named `ldm` can be created
|
A suitable [conda](https://conda.io/) environment named `ldm` can be created
|
||||||
and activated with:
|
and activated with:
|
||||||
|
|
||||||
|
@ -34,14 +37,24 @@ conda activate ldm
|
||||||
You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
|
You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
|
||||||
|
|
||||||
```
|
```
|
||||||
conda install pytorch torchvision -c pytorch
|
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
|
||||||
pip install transformers==4.19.2 diffusers invisible-watermark
|
pip install transformers==4.19.2 diffusers invisible-watermark
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website
|
### Install Lightning
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
|
git clone https://github.com/1SAA/lightning.git
|
||||||
|
cd lightning
|
||||||
|
git checkout strategy/colossalai
|
||||||
|
export PACKAGE_NAME=pytorch
|
||||||
|
pip install .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Install [Colossal-AI v0.1.12](https://colossalai.org/download/) From Our Official Website
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
|
||||||
```
|
```
|
||||||
|
|
||||||
> The version is pinned because of an interface incompatibility introduced by the latest update of [Lightning](https://github.com/Lightning-AI/lightning); this will be fixed in the near future.
|
> The version is pinned because of an interface incompatibility introduced by the latest update of [Lightning](https://github.com/Lightning-AI/lightning); this will be fixed in the near future.
|
||||||
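To confirm that the pinned packages above resolved correctly, a quick sanity check can help (a sketch; the distribution names are taken from the install commands above):

```python
# Sanity-check sketch: print the versions that actually got installed.
from importlib.metadata import version
import torch

print(torch.__version__)             # expected 1.12.x (+cu113)
print(version("colossalai"))         # expected 0.1.12
print(version("pytorch-lightning"))  # the Lightning fork built with PACKAGE_NAME=pytorch
```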
|
@ -49,6 +62,7 @@ pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
|
||||||
## Download the pretrained model checkpoint
|
## Download the pretrained model checkpoint
|
||||||
|
|
||||||
### stable-diffusion-v1-4
|
### stable-diffusion-v1-4
|
||||||
|
|
||||||
Our default model config uses the weights from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style)
|
Our default model config uses the weights from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -57,6 +71,7 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
|
||||||
```
|
```
|
||||||
|
|
||||||
### stable-diffusion-v1-5 from RunwayML
|
### stable-diffusion-v1-5 from RunwayML
|
||||||
|
|
||||||
If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from RunwayML:
|
If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from RunwayML:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -64,23 +79,24 @@ git lfs install
|
||||||
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
|
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Dataset
|
## Dataset
|
||||||
|
|
||||||
The dataset is from [LAION-5B](https://laion.ai/blog/laion-5b/), a subset of [LAION](https://laion.ai/);
|
The dataset is from [LAION-5B](https://laion.ai/blog/laion-5b/), a subset of [LAION](https://laion.ai/);
|
||||||
you should change `data.file_path` in `configs/train_colossalai.yaml` to point to your local copy.
|
you should change `data.file_path` in `configs/train_colossalai.yaml` to point to your local copy.
|
||||||
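As a concrete illustration (a sketch only: it assumes the `data.params.train.params.file_path` nesting used by the configs in this example and the `omegaconf` package pinned in `environment.yaml`; the local path is hypothetical), the field can also be rewritten programmatically:

```python
# Sketch: point the training data at a local LAION subset without hand-editing the YAML.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/train_colossalai.yaml")
cfg.data.params.train.params.file_path = "/path/to/your/laion_subset/"  # hypothetical local path
OmegaConf.save(cfg, "configs/train_colossalai_local.yaml")
```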
|
|
||||||
## Training
|
## Training
|
||||||
|
|
||||||
We provide the script `train.sh` to run the training task, and two strategies in `configs`: `train_colossalai.yaml`
|
We provide the script `train.sh` to run the training task, and two strategies in `configs`: `train_colossalai.yaml` and `train_ddp.yaml`
|
||||||
|
|
||||||
For example, you can run training with the ColossalAI strategy by
|
For example, you can run training with the ColossalAI strategy by
|
||||||
```
|
```
|
||||||
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
|
python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
- you can change `--logdir` to choose where the log information and the last checkpoint are saved
|
- you can change `--logdir` to choose where the log information and the last checkpoint are saved
|
||||||
|
|
||||||
### Training config
|
### Training config
|
||||||
|
|
||||||
You can change the training config in the YAML file:
|
You can change the training config in the YAML file:
|
||||||
|
|
||||||
- accelerator: accelerator type, default 'gpu'
|
- accelerator: accelerator type, default 'gpu'
|
||||||
|
@ -88,27 +104,25 @@ You can change the trainging config in the yaml file
|
||||||
- max_epochs: max training epochs
|
- max_epochs: max training epochs
|
||||||
- precision: whether to use fp16 for training, default 16; you must use fp16 if you want to apply the ColossalAI strategy (a quick sketch of how these settings combine is given below)
|
- precision: whether to use fp16 for training, default 16; you must use fp16 if you want to apply the ColossalAI strategy (a quick sketch of how these settings combine is given below)
|
||||||
|
|
||||||
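Putting these settings together, here is a quick back-of-the-envelope sketch of the effective batch size per optimizer step (the numbers are examples taken from the provided configs; your YAML may differ):

```python
# Sketch: effective (global) batch size implied by the trainer settings.
batch_size = 16               # data.params.batch_size
devices = 2                   # lightning.trainer.devices (number of GPUs)
accumulate_grad_batches = 1   # lightning.trainer.accumulate_grad_batches
print(batch_size * devices * accumulate_grad_batches)  # -> 32 samples per optimizer step
```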
## Example
|
## Finetune Example
|
||||||
|
### Training on the Teyvat Dataset
|
||||||
|
|
||||||
### Training on cifar10
|
We provide a finetuning example on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, whose captions are generated by BLIP.
|
||||||
|
|
||||||
We provide the finetuning example on CIFAR10 dataset
|
You can run it with the config `configs/Teyvat/train_colossalai_teyvat.yaml`:
|
||||||
|
|
||||||
You can run by config `train_colossalai_cifar10.yaml`
|
|
||||||
```
|
```
|
||||||
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
|
python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Inference
|
## Inference
|
||||||
You can find your trained `last.ckpt` and the training config `project.yaml` in your `--logdir`, and run inference by
|
You can find your trained `last.ckpt` and the training config `project.yaml` in your `--logdir`, and run inference by
|
||||||
```
|
```
|
||||||
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
|
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
|
||||||
--outdir ./output \
|
--outdir ./output \
|
||||||
--config /path/to/logdir/configs/project.yaml \
|
--config /path/to/logdir/configs/project.yaml \
|
||||||
--ckpt /path/to/logdir/checkpoints/last.ckpt
|
--ckpt /path/to/logdir/checkpoints/last.ckpt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
```commandline
|
```commandline
|
||||||
usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA]
|
usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA]
|
||||||
[--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT]
|
[--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT]
|
||||||
|
@ -144,7 +158,6 @@ optional arguments:
|
||||||
evaluate at this precision
|
evaluate at this precision
|
||||||
```
|
```
|
||||||
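If you only want to sample from the pretrained v1-4 weights downloaded above and do not need this repo's training pipeline, a hedged alternative is the `diffusers` package already listed in the requirements (a sketch that assumes a recent `diffusers` release and a CUDA GPU; it is not part of this repo's scripts):

```python
# Sketch: text-to-image sampling via diffusers, independent of scripts/txt2img.py.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```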
|
|
||||||
|
|
||||||
## Comments
|
## Comments
|
||||||
|
|
||||||
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
|
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-4
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
parameterization: "v"
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
|
@ -0,0 +1,67 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-4
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
|
@ -0,0 +1,158 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-05
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: hybrid
|
||||||
|
scale_factor: 0.18215
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
finetune_keys: null
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 9
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: ldm.data.laion.WebDataModuleFromConfig
|
||||||
|
params:
|
||||||
|
tar_base: null # for concat as in LAION-A
|
||||||
|
p_unsafe_threshold: 0.1
|
||||||
|
filter_word_list: "data/filters.yaml"
|
||||||
|
max_pwatermark: 0.45
|
||||||
|
batch_size: 8
|
||||||
|
num_workers: 6
|
||||||
|
multinode: True
|
||||||
|
min_size: 512
|
||||||
|
train:
|
||||||
|
shards:
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
|
||||||
|
shuffle: 10000
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.RandomCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
postprocess:
|
||||||
|
target: ldm.data.laion.AddMask
|
||||||
|
params:
|
||||||
|
mode: "512train-large"
|
||||||
|
p_drop: 0.25
|
||||||
|
# NOTE use enough shards to avoid empty validation loops in workers
|
||||||
|
validation:
|
||||||
|
shards:
|
||||||
|
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
|
||||||
|
shuffle: 0
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.CenterCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
postprocess:
|
||||||
|
target: ldm.data.laion.AddMask
|
||||||
|
params:
|
||||||
|
mode: "512train-large"
|
||||||
|
p_drop: 0.25
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
find_unused_parameters: True
|
||||||
|
modelcheckpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 5000
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
metrics_over_trainsteps_checkpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 10000
|
||||||
|
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
enable_autocast: False
|
||||||
|
disabled: False
|
||||||
|
batch_frequency: 1000
|
||||||
|
max_images: 4
|
||||||
|
increase_log_steps: False
|
||||||
|
log_first_step: False
|
||||||
|
log_images_kwargs:
|
||||||
|
use_ema_scope: False
|
||||||
|
inpaint: False
|
||||||
|
plot_progressive_rows: False
|
||||||
|
plot_diffusion_rows: False
|
||||||
|
N: 4
|
||||||
|
unconditional_guidance_scale: 5.0
|
||||||
|
unconditional_guidance_label: [""]
|
||||||
|
ddim_steps: 50 # todo check these out for depth2img,
|
||||||
|
ddim_eta: 0.0 # todo check these out for depth2img,
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
val_check_interval: 5000000
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
accumulate_grad_batches: 1
|
|
@ -0,0 +1,72 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-07
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: hybrid
|
||||||
|
scale_factor: 0.18215
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
finetune_keys: null
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
depth_stage_config:
|
||||||
|
target: ldm.modules.midas.api.MiDaSInference
|
||||||
|
params:
|
||||||
|
model_type: "dpt_hybrid"
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 5
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
|
@ -0,0 +1,75 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
|
||||||
|
params:
|
||||||
|
parameterization: "v"
|
||||||
|
low_scale_key: "lr"
|
||||||
|
linear_start: 0.0001
|
||||||
|
linear_end: 0.02
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 128
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: "hybrid-adm"
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.08333
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
low_scale_config:
|
||||||
|
target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
|
||||||
|
params:
|
||||||
|
noise_schedule_config: # image space
|
||||||
|
linear_start: 0.0001
|
||||||
|
linear_end: 0.02
|
||||||
|
max_noise_level: 350
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
|
||||||
|
image_size: 128
|
||||||
|
in_channels: 7
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 256
|
||||||
|
attention_resolutions: [ 2,4,8]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 2, 4]
|
||||||
|
disable_self_attentions: [True, True, True, False]
|
||||||
|
disable_middle_self_attn: False
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
ddconfig:
|
||||||
|
# attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
|
||||||
|
double_z: True
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
|
@ -0,0 +1,25 @@
|
||||||
|
# Dataset Card for Teyvat BLIP captions
|
||||||
|
Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion).
|
||||||
|
|
||||||
|
BLIP-generated captions for character images from the [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters) and the [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2).
|
||||||
|
|
||||||
|
For each row the dataset contains `image` and `text` keys. `image` is a varying-size PIL PNG, and `text` is the accompanying text caption. Only a train split is provided.
|
||||||
|
|
||||||
|
The `text` includes the tags `Teyvat`, `Name`, `Element`, `Weapon`, `Region`, `Model type`, and `Description`; the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP).
|
||||||
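A minimal sketch for inspecting the data with the Hugging Face `datasets` library (field names follow the description above):

```python
# Sketch: load the Teyvat captions dataset and look at one row.
from datasets import load_dataset

ds = load_dataset("Fazzie/Teyvat", split="train")  # only a train split is provided
sample = ds[0]
print(sample["text"])   # e.g. "Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, ..."
sample["image"].show()  # a varying-size PIL image
```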
|
## Examples
|
||||||
|
|
||||||
|
<img src = "https://huggingface.co/datasets/Fazzie/Teyvat/resolve/main/data/Ganyu_001.png" title = "Ganyu_001.png" style="max-width: 20%;" >
|
||||||
|
|
||||||
|
> Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, Region:Liyue, Model type:Medium Female, Description:an anime character with blue hair and blue eyes
|
||||||
|
|
||||||
|
<img src = "https://huggingface.co/datasets/Fazzie/Teyvat/resolve/main/data/Ganyu_002.png" title = "Ganyu_002.png" style="max-width: 20%;" >
|
||||||
|
|
||||||
|
> Teyvat, Name:Ganyu, Element:Cryo, Weapon:Bow, Region:Liyue, Model type:Medium Female, Description:an anime character with blue hair and blue eyes
|
||||||
|
|
||||||
|
<img src = "https://huggingface.co/datasets/Fazzie/Teyvat/resolve/main/data/Keqing_003.png" title = "Keqing_003.png" style="max-width: 20%;" >
|
||||||
|
|
||||||
|
> Teyvat, Name:Keqing, Element:Electro, Weapon:Sword, Region:Liyue, Model type:Medium Female, Description:a anime girl with long white hair and blue eyes
|
||||||
|
|
||||||
|
<img src = "https://huggingface.co/datasets/Fazzie/Teyvat/resolve/main/data/Keqing_004.png" title = "Keqing_004.png" style="max-width: 20%;" >
|
||||||
|
|
||||||
|
> Teyvat, Name:Keqing, Element:Electro, Weapon:Sword, Region:Liyue, Model type:Medium Female, Description:an anime character wearing a purple dress and cat ears
|
|
@ -1,7 +1,8 @@
|
||||||
model:
|
model:
|
||||||
base_learning_rate: 1.0e-04
|
base_learning_rate: 1.0e-4
|
||||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
params:
|
params:
|
||||||
|
parameterization: "v"
|
||||||
linear_start: 0.00085
|
linear_start: 0.00085
|
||||||
linear_end: 0.0120
|
linear_end: 0.0120
|
||||||
num_timesteps_cond: 1
|
num_timesteps_cond: 1
|
||||||
|
@ -11,11 +12,11 @@ model:
|
||||||
cond_stage_key: txt
|
cond_stage_key: txt
|
||||||
image_size: 64
|
image_size: 64
|
||||||
channels: 4
|
channels: 4
|
||||||
cond_stage_trainable: false # Note: different from the one we trained before
|
cond_stage_trainable: false
|
||||||
conditioning_key: crossattn
|
conditioning_key: crossattn
|
||||||
monitor: val/loss_simple_ema
|
monitor: val/loss_simple_ema
|
||||||
scale_factor: 0.18215
|
scale_factor: 0.18215
|
||||||
use_ema: False
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
|
|
||||||
scheduler_config: # 10000 warmup steps
|
scheduler_config: # 10000 warmup steps
|
||||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
@ -26,31 +27,33 @@ model:
|
||||||
f_max: [ 1.e-4 ]
|
f_max: [ 1.e-4 ]
|
||||||
f_min: [ 1.e-10 ]
|
f_min: [ 1.e-10 ]
|
||||||
|
|
||||||
|
|
||||||
unet_config:
|
unet_config:
|
||||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
params:
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
image_size: 32 # unused
|
image_size: 32 # unused
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
|
||||||
in_channels: 4
|
in_channels: 4
|
||||||
out_channels: 4
|
out_channels: 4
|
||||||
model_channels: 320
|
model_channels: 320
|
||||||
attention_resolutions: [ 4, 2, 1 ]
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
channel_mult: [ 1, 2, 4, 4 ]
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
num_heads: 8
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
use_spatial_transformer: True
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
transformer_depth: 1
|
transformer_depth: 1
|
||||||
context_dim: 768
|
context_dim: 1024
|
||||||
use_checkpoint: False
|
|
||||||
legacy: False
|
legacy: False
|
||||||
|
|
||||||
first_stage_config:
|
first_stage_config:
|
||||||
target: ldm.models.autoencoder.AutoencoderKL
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
params:
|
params:
|
||||||
embed_dim: 4
|
embed_dim: 4
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
|
||||||
monitor: val/rec_loss
|
monitor: val/rec_loss
|
||||||
ddconfig:
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
double_z: true
|
double_z: true
|
||||||
z_channels: 4
|
z_channels: 4
|
||||||
resolution: 256
|
resolution: 256
|
||||||
|
@ -69,9 +72,10 @@ model:
|
||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
||||||
cond_stage_config:
|
cond_stage_config:
|
||||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
params:
|
params:
|
||||||
use_fp16: True
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
data:
|
data:
|
||||||
target: main.DataModuleFromConfig
|
target: main.DataModuleFromConfig
|
||||||
|
@ -86,37 +90,37 @@ data:
|
||||||
- target: torchvision.transforms.Resize
|
- target: torchvision.transforms.Resize
|
||||||
params:
|
params:
|
||||||
size: 512
|
size: 512
|
||||||
# - target: torchvision.transforms.RandomCrop
|
- target: torchvision.transforms.RandomCrop
|
||||||
# params:
|
params:
|
||||||
# size: 256
|
size: 512
|
||||||
# - target: torchvision.transforms.RandomHorizontalFlip
|
- target: torchvision.transforms.RandomHorizontalFlip
|
||||||
|
|
||||||
lightning:
|
lightning:
|
||||||
trainer:
|
trainer:
|
||||||
accelerator: 'gpu'
|
accelerator: 'gpu'
|
||||||
devices: 2
|
devices: 2
|
||||||
log_gpu_memory: all
|
log_gpu_memory: all
|
||||||
max_epochs: 10
|
max_epochs: 2
|
||||||
precision: 16
|
precision: 16
|
||||||
auto_select_gpus: False
|
auto_select_gpus: False
|
||||||
strategy:
|
strategy:
|
||||||
target: lightning.pytorch.strategies.ColossalAIStrategy
|
target: strategies.ColossalAIStrategy
|
||||||
params:
|
params:
|
||||||
use_chunk: False
|
use_chunk: True
|
||||||
enable_distributed_storage: True,
|
enable_distributed_storage: True
|
||||||
placement_policy: cuda
|
placement_policy: auto
|
||||||
force_outputs_fp32: False
|
force_outputs_fp32: true
|
||||||
|
|
||||||
log_every_n_steps: 2
|
log_every_n_steps: 2
|
||||||
logger: True
|
logger: True
|
||||||
default_root_dir: "/tmp/diff_log/"
|
default_root_dir: "/tmp/diff_log/"
|
||||||
profiler: pytorch
|
# profiler: pytorch
|
||||||
|
|
||||||
logger_config:
|
logger_config:
|
||||||
wandb:
|
wandb:
|
||||||
target: lightning.pytorch.loggers.WandbLogger
|
target: loggers.WandbLogger
|
||||||
params:
|
params:
|
||||||
name: nowname
|
name: nowname
|
||||||
save_dir: "/tmp/diff_log/"
|
save_dir: "/tmp/diff_log/"
|
||||||
offline: opt.debug
|
offline: opt.debug
|
||||||
id: nowname
|
id: nowname
|
|
@ -1,21 +1,22 @@
|
||||||
model:
|
model:
|
||||||
base_learning_rate: 1.0e-04
|
base_learning_rate: 1.0e-4
|
||||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
params:
|
params:
|
||||||
|
parameterization: "v"
|
||||||
linear_start: 0.00085
|
linear_start: 0.00085
|
||||||
linear_end: 0.0120
|
linear_end: 0.0120
|
||||||
num_timesteps_cond: 1
|
num_timesteps_cond: 1
|
||||||
log_every_t: 200
|
log_every_t: 200
|
||||||
timesteps: 1000
|
timesteps: 1000
|
||||||
first_stage_key: image
|
first_stage_key: image
|
||||||
cond_stage_key: caption
|
cond_stage_key: txt
|
||||||
image_size: 64
|
image_size: 64
|
||||||
channels: 4
|
channels: 4
|
||||||
cond_stage_trainable: false # Note: different from the one we trained before
|
cond_stage_trainable: false
|
||||||
conditioning_key: crossattn
|
conditioning_key: crossattn
|
||||||
monitor: val/loss_simple_ema
|
monitor: val/loss_simple_ema
|
||||||
scale_factor: 0.18215
|
scale_factor: 0.18215
|
||||||
use_ema: False
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
|
|
||||||
scheduler_config: # 10000 warmup steps
|
scheduler_config: # 10000 warmup steps
|
||||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
@ -26,31 +27,33 @@ model:
|
||||||
f_max: [ 1.e-4 ]
|
f_max: [ 1.e-4 ]
|
||||||
f_min: [ 1.e-10 ]
|
f_min: [ 1.e-10 ]
|
||||||
|
|
||||||
|
|
||||||
unet_config:
|
unet_config:
|
||||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
params:
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
image_size: 32 # unused
|
image_size: 32 # unused
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
|
||||||
in_channels: 4
|
in_channels: 4
|
||||||
out_channels: 4
|
out_channels: 4
|
||||||
model_channels: 320
|
model_channels: 320
|
||||||
attention_resolutions: [ 4, 2, 1 ]
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
channel_mult: [ 1, 2, 4, 4 ]
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
num_heads: 8
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
use_spatial_transformer: True
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
transformer_depth: 1
|
transformer_depth: 1
|
||||||
context_dim: 768
|
context_dim: 1024
|
||||||
use_checkpoint: False
|
|
||||||
legacy: False
|
legacy: False
|
||||||
|
|
||||||
first_stage_config:
|
first_stage_config:
|
||||||
target: ldm.models.autoencoder.AutoencoderKL
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
params:
|
params:
|
||||||
embed_dim: 4
|
embed_dim: 4
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
|
||||||
monitor: val/rec_loss
|
monitor: val/rec_loss
|
||||||
ddconfig:
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
double_z: true
|
double_z: true
|
||||||
z_channels: 4
|
z_channels: 4
|
||||||
resolution: 256
|
resolution: 256
|
||||||
|
@ -69,9 +72,10 @@ model:
|
||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
||||||
cond_stage_config:
|
cond_stage_config:
|
||||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
params:
|
params:
|
||||||
use_fp16: True
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
data:
|
data:
|
||||||
target: main.DataModuleFromConfig
|
target: main.DataModuleFromConfig
|
||||||
|
@ -87,30 +91,30 @@ data:
|
||||||
|
|
||||||
lightning:
|
lightning:
|
||||||
trainer:
|
trainer:
|
||||||
accelerator: 'gpu'
|
accelerator: 'gpu'
|
||||||
devices: 4
|
devices: 1
|
||||||
log_gpu_memory: all
|
log_gpu_memory: all
|
||||||
max_epochs: 2
|
max_epochs: 2
|
||||||
precision: 16
|
precision: 16
|
||||||
auto_select_gpus: False
|
auto_select_gpus: False
|
||||||
strategy:
|
strategy:
|
||||||
target: lightning.pytorch.strategies.ColossalAIStrategy
|
target: strategies.ColossalAIStrategy
|
||||||
params:
|
params:
|
||||||
use_chunk: False
|
use_chunk: True
|
||||||
enable_distributed_storage: True,
|
enable_distributed_storage: True
|
||||||
placement_policy: cuda
|
placement_policy: auto
|
||||||
force_outputs_fp32: False
|
force_outputs_fp32: true
|
||||||
|
|
||||||
log_every_n_steps: 2
|
log_every_n_steps: 2
|
||||||
logger: True
|
logger: True
|
||||||
default_root_dir: "/tmp/diff_log/"
|
default_root_dir: "/tmp/diff_log/"
|
||||||
profiler: pytorch
|
# profiler: pytorch
|
||||||
|
|
||||||
logger_config:
|
logger_config:
|
||||||
wandb:
|
wandb:
|
||||||
target: lightning.pytorch.loggers.WandbLogger
|
target: loggers.WandbLogger
|
||||||
params:
|
params:
|
||||||
name: nowname
|
name: nowname
|
||||||
save_dir: "/tmp/diff_log/"
|
save_dir: "/tmp/diff_log/"
|
||||||
offline: opt.debug
|
offline: opt.debug
|
||||||
id: nowname
|
id: nowname
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
model:
|
model:
|
||||||
base_learning_rate: 1.0e-04
|
base_learning_rate: 1.0e-4
|
||||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
params:
|
params:
|
||||||
|
parameterization: "v"
|
||||||
linear_start: 0.00085
|
linear_start: 0.00085
|
||||||
linear_end: 0.0120
|
linear_end: 0.0120
|
||||||
num_timesteps_cond: 1
|
num_timesteps_cond: 1
|
||||||
|
@ -11,11 +12,11 @@ model:
|
||||||
cond_stage_key: txt
|
cond_stage_key: txt
|
||||||
image_size: 64
|
image_size: 64
|
||||||
channels: 4
|
channels: 4
|
||||||
cond_stage_trainable: false # Note: different from the one we trained before
|
cond_stage_trainable: false
|
||||||
conditioning_key: crossattn
|
conditioning_key: crossattn
|
||||||
monitor: val/loss_simple_ema
|
monitor: val/loss_simple_ema
|
||||||
scale_factor: 0.18215
|
scale_factor: 0.18215
|
||||||
use_ema: False
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
|
|
||||||
scheduler_config: # 10000 warmup steps
|
scheduler_config: # 10000 warmup steps
|
||||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
@ -26,31 +27,33 @@ model:
|
||||||
f_max: [ 1.e-4 ]
|
f_max: [ 1.e-4 ]
|
||||||
f_min: [ 1.e-10 ]
|
f_min: [ 1.e-10 ]
|
||||||
|
|
||||||
|
|
||||||
unet_config:
|
unet_config:
|
||||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
params:
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
image_size: 32 # unused
|
image_size: 32 # unused
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
|
||||||
in_channels: 4
|
in_channels: 4
|
||||||
out_channels: 4
|
out_channels: 4
|
||||||
model_channels: 320
|
model_channels: 320
|
||||||
attention_resolutions: [ 4, 2, 1 ]
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
channel_mult: [ 1, 2, 4, 4 ]
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
num_heads: 8
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
use_spatial_transformer: True
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
transformer_depth: 1
|
transformer_depth: 1
|
||||||
context_dim: 768
|
context_dim: 1024
|
||||||
use_checkpoint: False
|
|
||||||
legacy: False
|
legacy: False
|
||||||
|
|
||||||
first_stage_config:
|
first_stage_config:
|
||||||
target: ldm.models.autoencoder.AutoencoderKL
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
params:
|
params:
|
||||||
embed_dim: 4
|
embed_dim: 4
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
|
||||||
monitor: val/rec_loss
|
monitor: val/rec_loss
|
||||||
ddconfig:
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
double_z: true
|
double_z: true
|
||||||
z_channels: 4
|
z_channels: 4
|
||||||
resolution: 256
|
resolution: 256
|
||||||
|
@ -69,9 +72,10 @@ model:
|
||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
||||||
cond_stage_config:
|
cond_stage_config:
|
||||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
params:
|
params:
|
||||||
use_fp16: True
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
data:
|
data:
|
||||||
target: main.DataModuleFromConfig
|
target: main.DataModuleFromConfig
|
||||||
|
@ -94,30 +98,30 @@ data:
|
||||||
|
|
||||||
lightning:
|
lightning:
|
||||||
trainer:
|
trainer:
|
||||||
accelerator: 'gpu'
|
accelerator: 'gpu'
|
||||||
devices: 2
|
devices: 1
|
||||||
log_gpu_memory: all
|
log_gpu_memory: all
|
||||||
max_epochs: 2
|
max_epochs: 2
|
||||||
precision: 16
|
precision: 16
|
||||||
auto_select_gpus: False
|
auto_select_gpus: False
|
||||||
strategy:
|
strategy:
|
||||||
target: lightning.pytorch.strategies.ColossalAIStrategy
|
target: strategies.ColossalAIStrategy
|
||||||
params:
|
params:
|
||||||
use_chunk: False
|
use_chunk: True
|
||||||
enable_distributed_storage: True,
|
enable_distributed_storage: True
|
||||||
placement_policy: cuda
|
placement_policy: auto
|
||||||
force_outputs_fp32: False
|
force_outputs_fp32: true
|
||||||
|
|
||||||
log_every_n_steps: 2
|
log_every_n_steps: 2
|
||||||
logger: True
|
logger: True
|
||||||
default_root_dir: "/tmp/diff_log/"
|
default_root_dir: "/tmp/diff_log/"
|
||||||
profiler: pytorch
|
# profiler: pytorch
|
||||||
|
|
||||||
logger_config:
|
logger_config:
|
||||||
wandb:
|
wandb:
|
||||||
target: lightning.pytorch.loggers.WandbLogger
|
target: loggers.WandbLogger
|
||||||
params:
|
params:
|
||||||
name: nowname
|
name: nowname
|
||||||
save_dir: "/tmp/diff_log/"
|
save_dir: "/tmp/diff_log/"
|
||||||
offline: opt.debug
|
offline: opt.debug
|
||||||
id: nowname
|
id: nowname
|
||||||
|
|
|
@ -1,56 +1,59 @@
|
||||||
model:
|
model:
|
||||||
base_learning_rate: 1.0e-04
|
base_learning_rate: 1.0e-4
|
||||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
params:
|
params:
|
||||||
|
parameterization: "v"
|
||||||
linear_start: 0.00085
|
linear_start: 0.00085
|
||||||
linear_end: 0.0120
|
linear_end: 0.0120
|
||||||
num_timesteps_cond: 1
|
num_timesteps_cond: 1
|
||||||
log_every_t: 200
|
log_every_t: 200
|
||||||
timesteps: 1000
|
timesteps: 1000
|
||||||
first_stage_key: image
|
first_stage_key: image
|
||||||
cond_stage_key: caption
|
cond_stage_key: txt
|
||||||
image_size: 32
|
image_size: 64
|
||||||
channels: 4
|
channels: 4
|
||||||
cond_stage_trainable: false # Note: different from the one we trained before
|
cond_stage_trainable: false
|
||||||
conditioning_key: crossattn
|
conditioning_key: crossattn
|
||||||
monitor: val/loss_simple_ema
|
monitor: val/loss_simple_ema
|
||||||
scale_factor: 0.18215
|
scale_factor: 0.18215
|
||||||
use_ema: False
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
|
|
||||||
scheduler_config: # 10000 warmup steps
|
scheduler_config: # 10000 warmup steps
|
||||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
params:
|
params:
|
||||||
warm_up_steps: [ 100 ]
|
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
f_start: [ 1.e-6 ]
|
f_start: [ 1.e-6 ]
|
||||||
f_max: [ 1.e-4 ]
|
f_max: [ 1.e-4 ]
|
||||||
f_min: [ 1.e-10 ]
|
f_min: [ 1.e-10 ]
|
||||||
|
|
||||||
|
|
||||||
unet_config:
|
unet_config:
|
||||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
params:
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
image_size: 32 # unused
|
image_size: 32 # unused
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
|
||||||
in_channels: 4
|
in_channels: 4
|
||||||
out_channels: 4
|
out_channels: 4
|
||||||
model_channels: 320
|
model_channels: 320
|
||||||
attention_resolutions: [ 4, 2, 1 ]
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
channel_mult: [ 1, 2, 4, 4 ]
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
num_heads: 8
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
use_spatial_transformer: True
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
transformer_depth: 1
|
transformer_depth: 1
|
||||||
context_dim: 768
|
context_dim: 1024
|
||||||
use_checkpoint: False
|
|
||||||
legacy: False
|
legacy: False
|
||||||
|
|
||||||
first_stage_config:
|
first_stage_config:
|
||||||
target: ldm.models.autoencoder.AutoencoderKL
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
params:
|
params:
|
||||||
embed_dim: 4
|
embed_dim: 4
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
|
||||||
monitor: val/rec_loss
|
monitor: val/rec_loss
|
||||||
ddconfig:
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
double_z: true
|
double_z: true
|
||||||
z_channels: 4
|
z_channels: 4
|
||||||
resolution: 256
|
resolution: 256
|
||||||
|
@ -69,32 +72,39 @@ model:
|
||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
||||||
cond_stage_config:
|
cond_stage_config:
|
||||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
params:
|
params:
|
||||||
use_fp16: True
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
data:
|
data:
|
||||||
target: main.DataModuleFromConfig
|
target: main.DataModuleFromConfig
|
||||||
params:
|
params:
|
||||||
batch_size: 64
|
batch_size: 16
|
||||||
wrap: False
|
num_workers: 4
|
||||||
train:
|
train:
|
||||||
target: ldm.data.base.Txt2ImgIterableBaseDataset
|
target: ldm.data.teyvat.hf_dataset
|
||||||
params:
|
params:
|
||||||
file_path: "/data/scratch/diffuser/laion_part0/"
|
path: Fazzie/Teyvat
|
||||||
world_size: 1
|
image_transforms:
|
||||||
rank: 0
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
- target: torchvision.transforms.RandomCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
- target: torchvision.transforms.RandomHorizontalFlip
|
||||||
|
|
||||||
lightning:
|
lightning:
|
||||||
trainer:
|
trainer:
|
||||||
accelerator: 'gpu'
|
accelerator: 'gpu'
|
||||||
devices: 4
|
devices: 2
|
||||||
log_gpu_memory: all
|
log_gpu_memory: all
|
||||||
max_epochs: 2
|
max_epochs: 2
|
||||||
precision: 16
|
precision: 16
|
||||||
auto_select_gpus: False
|
auto_select_gpus: False
|
||||||
strategy:
|
strategy:
|
||||||
target: lightning.pytorch.strategies.DDPStrategy
|
target: strategies.DDPStrategy
|
||||||
params:
|
params:
|
||||||
find_unused_parameters: False
|
find_unused_parameters: False
|
||||||
log_every_n_steps: 2
|
log_every_n_steps: 2
|
||||||
|
@ -105,9 +115,9 @@ lightning:
|
||||||
|
|
||||||
logger_config:
|
logger_config:
|
||||||
wandb:
|
wandb:
|
||||||
target: lightning.pytorch.loggers.WandbLogger
|
target: loggers.WandbLogger
|
||||||
params:
|
params:
|
||||||
name: nowname
|
name: nowname
|
||||||
save_dir: "/tmp/diff_log/"
|
save_dir: "/data2/tmp/diff_log/"
|
||||||
offline: opt.debug
|
offline: opt.debug
|
||||||
id: nowname
|
id: nowname
|
||||||
|
|
|
@ -1,57 +1,59 @@
|
||||||
model:
|
model:
|
||||||
base_learning_rate: 1.0e-04
|
base_learning_rate: 1.0e-4
|
||||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
params:
|
params:
|
||||||
|
parameterization: "v"
|
||||||
linear_start: 0.00085
|
linear_start: 0.00085
|
||||||
linear_end: 0.0120
|
linear_end: 0.0120
|
||||||
num_timesteps_cond: 1
|
num_timesteps_cond: 1
|
||||||
log_every_t: 200
|
log_every_t: 200
|
||||||
timesteps: 1000
|
timesteps: 1000
|
||||||
first_stage_key: image
|
first_stage_key: image
|
||||||
cond_stage_key: caption
|
cond_stage_key: txt
|
||||||
image_size: 32
|
image_size: 64
|
||||||
channels: 4
|
channels: 4
|
||||||
cond_stage_trainable: false # Note: different from the one we trained before
|
cond_stage_trainable: false
|
||||||
conditioning_key: crossattn
|
conditioning_key: crossattn
|
||||||
monitor: val/loss_simple_ema
|
monitor: val/loss_simple_ema
|
||||||
scale_factor: 0.18215
|
scale_factor: 0.18215
|
||||||
use_ema: False
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
check_nan_inf: False
|
|
||||||
|
|
||||||
scheduler_config: # 10000 warmup steps
|
scheduler_config: # 10000 warmup steps
|
||||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
params:
|
params:
|
||||||
warm_up_steps: [ 10000 ]
|
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
f_start: [ 1.e-6 ]
|
f_start: [ 1.e-6 ]
|
||||||
f_max: [ 1.e-4 ]
|
f_max: [ 1.e-4 ]
|
||||||
f_min: [ 1.e-10 ]
|
f_min: [ 1.e-10 ]
|
||||||
|
|
||||||
|
|
||||||
unet_config:
|
unet_config:
|
||||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
params:
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
image_size: 32 # unused
|
image_size: 32 # unused
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
|
||||||
in_channels: 4
|
in_channels: 4
|
||||||
out_channels: 4
|
out_channels: 4
|
||||||
model_channels: 320
|
model_channels: 320
|
||||||
attention_resolutions: [ 4, 2, 1 ]
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
channel_mult: [ 1, 2, 4, 4 ]
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
num_heads: 8
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
use_spatial_transformer: True
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
transformer_depth: 1
|
transformer_depth: 1
|
||||||
context_dim: 768
|
context_dim: 1024
|
||||||
use_checkpoint: False
|
|
||||||
legacy: False
|
legacy: False
|
||||||
|
|
||||||
first_stage_config:
|
first_stage_config:
|
||||||
target: ldm.models.autoencoder.AutoencoderKL
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
params:
|
params:
|
||||||
embed_dim: 4
|
embed_dim: 4
|
||||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
|
||||||
monitor: val/rec_loss
|
monitor: val/rec_loss
|
||||||
ddconfig:
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
double_z: true
|
double_z: true
|
||||||
z_channels: 4
|
z_channels: 4
|
||||||
resolution: 256
|
resolution: 256
|
||||||
|
@ -70,9 +72,10 @@ model:
|
||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
||||||
cond_stage_config:
|
cond_stage_config:
|
||||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
params:
|
params:
|
||||||
use_fp16: True
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
data:
|
data:
|
||||||
target: main.DataModuleFromConfig
|
target: main.DataModuleFromConfig
|
||||||
|
@ -88,34 +91,30 @@ data:
|
||||||
|
|
||||||
lightning:
|
lightning:
|
||||||
trainer:
|
trainer:
|
||||||
accelerator: 'gpu'
|
accelerator: 'gpu'
|
||||||
devices: 4
|
devices: 1
|
||||||
log_gpu_memory: all
|
log_gpu_memory: all
|
||||||
max_epochs: 2
|
max_epochs: 2
|
||||||
precision: 16
|
precision: 16
|
||||||
auto_select_gpus: False
|
auto_select_gpus: False
|
||||||
strategy:
|
strategy:
|
||||||
target: lightning.pytorch.strategies.ColossalAIStrategy
|
target: strategies.ColossalAIStrategy
|
||||||
params:
|
params:
|
||||||
use_chunk: False
|
use_chunk: True
|
||||||
enable_distributed_storage: True,
|
enable_distributed_storage: True
|
||||||
placement_policy: cuda
|
placement_policy: auto
|
||||||
force_outputs_fp32: False
|
force_outputs_fp32: true
|
||||||
initial_scale: 65536
|
|
||||||
min_scale: 1
|
|
||||||
max_scale: 65536
|
|
||||||
# max_scale: 4294967296
|
|
||||||
|
|
||||||
log_every_n_steps: 2
|
log_every_n_steps: 2
|
||||||
logger: True
|
logger: True
|
||||||
default_root_dir: "/tmp/diff_log/"
|
default_root_dir: "/tmp/diff_log/"
|
||||||
profiler: pytorch
|
# profiler: pytorch
|
||||||
|
|
||||||
logger_config:
|
logger_config:
|
||||||
wandb:
|
wandb:
|
||||||
target: lightning.pytorch.loggers.WandbLogger
|
target: loggers.WandbLogger
|
||||||
params:
|
params:
|
||||||
name: nowname
|
name: nowname
|
||||||
save_dir: "/tmp/diff_log/"
|
save_dir: "/tmp/diff_log/"
|
||||||
offline: opt.debug
|
offline: opt.debug
|
||||||
id: nowname
|
id: nowname
|
||||||
|
|
|
@ -6,28 +6,25 @@ dependencies:
|
||||||
- python=3.9.12
|
- python=3.9.12
|
||||||
- pip=20.3
|
- pip=20.3
|
||||||
- cudatoolkit=11.3
|
- cudatoolkit=11.3
|
||||||
- pytorch=1.11.0
|
- pytorch=1.12.1
|
||||||
- torchvision=0.12.0
|
- torchvision=0.13.1
|
||||||
- numpy=1.19.2
|
- numpy=1.23.1
|
||||||
- pip:
|
- pip:
|
||||||
- albumentations==0.4.3
|
- albumentations==1.3.0
|
||||||
- datasets
|
|
||||||
- diffusers
|
|
||||||
- opencv-python==4.6.0.66
|
- opencv-python==4.6.0.66
|
||||||
- pudb==2019.2
|
|
||||||
- invisible-watermark
|
|
||||||
- imageio==2.9.0
|
- imageio==2.9.0
|
||||||
- imageio-ffmpeg==0.4.2
|
- imageio-ffmpeg==0.4.2
|
||||||
- lightning==1.8.1
|
|
||||||
- omegaconf==2.1.1
|
- omegaconf==2.1.1
|
||||||
- test-tube>=0.7.5
|
- test-tube>=0.7.5
|
||||||
- streamlit>=0.73.1
|
- streamlit==1.12.1
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- torch-fidelity==0.3.0
|
|
||||||
- transformers==4.19.2
|
- transformers==4.19.2
|
||||||
- torchmetrics==0.7.0
|
- webdataset==0.2.5
|
||||||
- kornia==0.6
|
- kornia==0.6
|
||||||
|
- open_clip_torch==2.0.2
|
||||||
|
- invisible-watermark>=0.1.5
|
||||||
|
- streamlit-drawable-canvas==0.8.0
|
||||||
|
- torchmetrics==0.7.0
|
||||||
- prefetch_generator
|
- prefetch_generator
|
||||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
- datasets
|
||||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
|
||||||
- -e .
|
- -e .
|
||||||
|
|
|
@@ -1,64 +1,68 @@
 import torch
-import lightning.pytorch as pl
+try:
+    import lightning.pytorch as pl
+except:
+    import pytorch_lightning as pl
 
 import torch.nn.functional as F
 from contextlib import contextmanager
 
-from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
-
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
 
 from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
 
 
-class VQModel(pl.LightningModule):
+class AutoencoderKL(pl.LightningModule):
     def __init__(self,
                  ddconfig,
                  lossconfig,
-                 n_embed,
                  embed_dim,
                  ckpt_path=None,
                  ignore_keys=[],
                  image_key="image",
                  colorize_nlabels=None,
                  monitor=None,
-                 batch_resize_range=None,
-                 scheduler_config=None,
-                 lr_g_factor=1.0,
-                 remap=None,
-                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
-                 use_ema=False
+                 ema_decay=None,
+                 learn_logvar=False
                  ):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.n_embed = n_embed
+        self.learn_logvar = learn_logvar
         self.image_key = image_key
         self.encoder = Encoder(**ddconfig)
         self.decoder = Decoder(**ddconfig)
         self.loss = instantiate_from_config(lossconfig)
-        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
-                                        remap=remap,
-                                        sane_index_shape=sane_index_shape)
-        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
         if colorize_nlabels is not None:
             assert type(colorize_nlabels)==int
             self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
         if monitor is not None:
             self.monitor = monitor
-        self.batch_resize_range = batch_resize_range
-        if self.batch_resize_range is not None:
-            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
 
-        self.use_ema = use_ema
+        self.use_ema = ema_decay is not None
         if self.use_ema:
-            self.model_ema = LitEma(self)
+            self.ema_decay = ema_decay
+            assert 0. < ema_decay < 1.
+            self.model_ema = LitEma(self, decay=ema_decay)
             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
-        self.scheduler_config = scheduler_config
-        self.lr_g_factor = lr_g_factor
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
 
     @contextmanager
     def ema_scope(self, context=None):
@@ -75,353 +79,10 @@ class VQModel(pl.LightningModule):
         if context is not None:
             print(f"{context}: Restored training weights")
 
-    def init_from_ckpt(self, path, ignore_keys=list()):
-        sd = torch.load(path, map_location="cpu")["state_dict"]
-        keys = list(sd.keys())
-        for k in keys:
-            for ik in ignore_keys:
-                if k.startswith(ik):
-                    print("Deleting key {} from state_dict.".format(k))
-                    del sd[k]
-        missing, unexpected = self.load_state_dict(sd, strict=False)
-        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
-        if len(missing) > 0:
-            print(f"Missing Keys: {missing}")
-            print(f"Unexpected Keys: {unexpected}")
-
     def on_train_batch_end(self, *args, **kwargs):
         if self.use_ema:
             self.model_ema(self)
 
-    def encode(self, x):
-        h = self.encoder(x)
-        h = self.quant_conv(h)
-        quant, emb_loss, info = self.quantize(h)
-        return quant, emb_loss, info
-
-    def encode_to_prequant(self, x):
-        h = self.encoder(x)
-        h = self.quant_conv(h)
-        return h
-
-    def decode(self, quant):
-        quant = self.post_quant_conv(quant)
-        dec = self.decoder(quant)
-        return dec
-
-    def decode_code(self, code_b):
-        quant_b = self.quantize.embed_code(code_b)
-        dec = self.decode(quant_b)
-        return dec
-
-    def forward(self, input, return_pred_indices=False):
-        quant, diff, (_,_,ind) = self.encode(input)
-        dec = self.decode(quant)
-        if return_pred_indices:
-            return dec, diff, ind
-        return dec, diff
-
-    def get_input(self, batch, k):
-        x = batch[k]
-        if len(x.shape) == 3:
-            x = x[..., None]
-        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
-        if self.batch_resize_range is not None:
-            lower_size = self.batch_resize_range[0]
-            upper_size = self.batch_resize_range[1]
-            if self.global_step <= 4:
-                # do the first few batches with max size to avoid later oom
-                new_resize = upper_size
-            else:
-                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
-            if new_resize != x.shape[2]:
-                x = F.interpolate(x, size=new_resize, mode="bicubic")
-            x = x.detach()
-        return x
-
-    def training_step(self, batch, batch_idx, optimizer_idx):
-        # https://github.com/pytorch/pytorch/issues/37142
-        # try not to fool the heuristics
-        x = self.get_input(batch, self.image_key)
-        xrec, qloss, ind = self(x, return_pred_indices=True)
-
-        if optimizer_idx == 0:
-            # autoencode
-            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
-                                            last_layer=self.get_last_layer(), split="train",
-                                            predicted_indices=ind)
-
-            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
-            return aeloss
-
-        if optimizer_idx == 1:
-            # discriminator
-            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
-                                                last_layer=self.get_last_layer(), split="train")
-            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
-            return discloss
-
-    def validation_step(self, batch, batch_idx):
-        log_dict = self._validation_step(batch, batch_idx)
-        with self.ema_scope():
-            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
-        return log_dict
-
-    def _validation_step(self, batch, batch_idx, suffix=""):
-        x = self.get_input(batch, self.image_key)
-        xrec, qloss, ind = self(x, return_pred_indices=True)
-        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
-                                        self.global_step,
-                                        last_layer=self.get_last_layer(),
-                                        split="val"+suffix,
-                                        predicted_indices=ind
-                                        )
-
-        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
-                                            self.global_step,
-                                            last_layer=self.get_last_layer(),
-                                            split="val"+suffix,
-                                            predicted_indices=ind
-                                            )
-        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
-        self.log(f"val{suffix}/rec_loss", rec_loss,
-                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
-        self.log(f"val{suffix}/aeloss", aeloss,
-                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
-        if version.parse(pl.__version__) >= version.parse('1.4.0'):
-            del log_dict_ae[f"val{suffix}/rec_loss"]
-        self.log_dict(log_dict_ae)
-        self.log_dict(log_dict_disc)
-        return self.log_dict
-
-    def configure_optimizers(self):
-        lr_d = self.learning_rate
-        lr_g = self.lr_g_factor*self.learning_rate
-        print("lr_d", lr_d)
-        print("lr_g", lr_g)
-        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
-                                  list(self.decoder.parameters())+
-                                  list(self.quantize.parameters())+
-                                  list(self.quant_conv.parameters())+
-                                  list(self.post_quant_conv.parameters()),
-                                  lr=lr_g, betas=(0.5, 0.9))
-        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
-                                    lr=lr_d, betas=(0.5, 0.9))
-
-        if self.scheduler_config is not None:
-            scheduler = instantiate_from_config(self.scheduler_config)
-
-            print("Setting up LambdaLR scheduler...")
-            scheduler = [
-                {
-                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
-                    'interval': 'step',
-                    'frequency': 1
-                },
-                {
-                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
-                    'interval': 'step',
-                    'frequency': 1
-                },
-            ]
-            return [opt_ae, opt_disc], scheduler
-        return [opt_ae, opt_disc], []
-
-    def get_last_layer(self):
-        return self.decoder.conv_out.weight
-
-    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
-        log = dict()
-        x = self.get_input(batch, self.image_key)
-        x = x.to(self.device)
-        if only_inputs:
-            log["inputs"] = x
-            return log
-        xrec, _ = self(x)
-        if x.shape[1] > 3:
-            # colorize with random projection
-            assert xrec.shape[1] > 3
-            x = self.to_rgb(x)
-            xrec = self.to_rgb(xrec)
-        log["inputs"] = x
-        log["reconstructions"] = xrec
-        if plot_ema:
-            with self.ema_scope():
-                xrec_ema, _ = self(x)
-                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
-                log["reconstructions_ema"] = xrec_ema
-        return log
-
-    def to_rgb(self, x):
-        assert self.image_key == "segmentation"
-        if not hasattr(self, "colorize"):
-            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
-        x = F.conv2d(x, weight=self.colorize)
-        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
-        return x
-
-
-class VQModelInterface(VQModel):
-    def __init__(self, embed_dim, *args, **kwargs):
-        super().__init__(embed_dim=embed_dim, *args, **kwargs)
-        self.embed_dim = embed_dim
-
-    def encode(self, x):
-        h = self.encoder(x)
-        h = self.quant_conv(h)
-        return h
-
-    def decode(self, h, force_not_quantize=False):
-        # also go through quantization layer
-        if not force_not_quantize:
-            quant, emb_loss, info = self.quantize(h)
-        else:
-            quant = h
-        quant = self.post_quant_conv(quant)
-        dec = self.decoder(quant)
-        return dec
-
-
-class AutoencoderKL(pl.LightningModule):
-    def __init__(self,
-                 ddconfig,
-                 lossconfig,
-                 embed_dim,
-                 ckpt_path=None,
-                 ignore_keys=[],
-                 image_key="image",
-                 colorize_nlabels=None,
-                 monitor=None,
-                 from_pretrained: str=None
-                 ):
-        super().__init__()
-        self.image_key = image_key
-        self.encoder = Encoder(**ddconfig)
-        self.decoder = Decoder(**ddconfig)
-        self.loss = instantiate_from_config(lossconfig)
-        assert ddconfig["double_z"]
-        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
-        self.embed_dim = embed_dim
-        if colorize_nlabels is not None:
-            assert type(colorize_nlabels)==int
-            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
-        if monitor is not None:
-            self.monitor = monitor
-        if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
-        from diffusers.modeling_utils import load_state_dict
-        if from_pretrained is not None:
-            state_dict = load_state_dict(from_pretrained)
-            self._load_pretrained_model(state_dict)
-
-    def _state_key_mapping(self, state_dict: dict):
-        import re
-        res_dict = {}
-        key_list = state_dict.keys()
-        key_str = " ".join(key_list)
-        up_block_pattern = re.compile('upsamplers')
-        p1 = re.compile('mid.block_[0-9]')
-        p2 = re.compile('decoder.up.[0-9]')
-        up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1)
-        for key_, val_ in state_dict.items():
-            key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\
-                .replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\
-                .replace('mid.attentions.0.key', 'mid.attn_1.k')\
-                .replace('mid.attentions.0.query', 'mid.attn_1.q') \
-                .replace('mid.attentions.0.value', 'mid.attn_1.v') \
-                .replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \
-                .replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\
-                .replace('upsamplers.0', 'upsample')\
-                .replace('downsamplers.0', 'downsample')\
-                .replace('conv_shortcut', 'nin_shortcut')\
-                .replace('conv_norm_out', 'norm_out')
-
-            mid_list = re.findall(p1, key_)
-            if len(mid_list) != 0:
-                mid_str = mid_list[0]
-                mid_id = int(mid_str[-1]) + 1
-                key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id))
-
-            up_list = re.findall(p2, key_)
-            if len(up_list) != 0:
-                up_str = up_list[0]
-                up_id = up_blocks_count - 1 -int(up_str[-1])
-                key_ = key_.replace(up_str, up_str[:-1] + str(up_id))
-            res_dict[key_] = val_
-        return res_dict
-
-    def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
-        state_dict = self._state_key_mapping(state_dict)
-        model_state_dict = self.state_dict()
-        loaded_keys = [k for k in state_dict.keys()]
-        expected_keys = list(model_state_dict.keys())
-        original_loaded_keys = loaded_keys
-        missing_keys = list(set(expected_keys) - set(loaded_keys))
-        unexpected_keys = list(set(loaded_keys) - set(expected_keys))
-
-        def _find_mismatched_keys(
-            state_dict,
-            model_state_dict,
-            loaded_keys,
-            ignore_mismatched_sizes,
-        ):
-            mismatched_keys = []
-            if ignore_mismatched_sizes:
-                for checkpoint_key in loaded_keys:
-                    model_key = checkpoint_key
-
-                    if (
-                        model_key in model_state_dict
-                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
-                    ):
-                        mismatched_keys.append(
-                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                        )
-                        del state_dict[checkpoint_key]
-            return mismatched_keys
-        if state_dict is not None:
-            # Whole checkpoint
-            mismatched_keys = _find_mismatched_keys(
-                state_dict,
-                model_state_dict,
-                original_loaded_keys,
-                ignore_mismatched_sizes,
-            )
-            error_msgs = self._load_state_dict_into_model(state_dict)
-        return missing_keys, unexpected_keys, mismatched_keys, error_msgs
-
-    def _load_state_dict_into_model(self, state_dict):
-        # Convert old format to new format if needed from a PyTorch state_dict
-        # copy state_dict so _load_from_state_dict can modify it
-        state_dict = state_dict.copy()
-        error_msgs = []
-
-        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-        # so we need to apply the function recursively.
-        def load(module: torch.nn.Module, prefix=""):
-            args = (state_dict, prefix, {}, True, [], [], error_msgs)
-            module._load_from_state_dict(*args)
-
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + ".")
-
-        load(self)
-
-        return error_msgs
-
-    def init_from_ckpt(self, path, ignore_keys=list()):
-        sd = torch.load(path, map_location="cpu")["state_dict"]
-        keys = list(sd.keys())
-        for k in keys:
-            for ik in ignore_keys:
-                if k.startswith(ik):
-                    print("Deleting key {} from state_dict.".format(k))
-                    del sd[k]
-        self.load_state_dict(sd, strict=False)
-        print(f"Restored from {path}")
-
     def encode(self, x):
         h = self.encoder(x)
         moments = self.quant_conv(h)
@@ -471,25 +132,33 @@ class AutoencoderKL(pl.LightningModule):
             return discloss
 
     def validation_step(self, batch, batch_idx):
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, postfix=""):
         inputs = self.get_input(batch, self.image_key)
         reconstructions, posterior = self(inputs)
         aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
-                                        last_layer=self.get_last_layer(), split="val")
+                                        last_layer=self.get_last_layer(), split="val"+postfix)
 
         discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
-                                            last_layer=self.get_last_layer(), split="val")
+                                            last_layer=self.get_last_layer(), split="val"+postfix)
 
-        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
         self.log_dict(log_dict_ae)
         self.log_dict(log_dict_disc)
         return self.log_dict
 
     def configure_optimizers(self):
         lr = self.learning_rate
-        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
-                                  list(self.decoder.parameters())+
-                                  list(self.quant_conv.parameters())+
-                                  list(self.post_quant_conv.parameters()),
+        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+        if self.learn_logvar:
+            print(f"{self.__class__.__name__}: Learning logvar")
+            ae_params_list.append(self.loss.logvar)
+        opt_ae = torch.optim.Adam(ae_params_list,
                                   lr=lr, betas=(0.5, 0.9))
         opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                     lr=lr, betas=(0.5, 0.9))

@@ -499,7 +168,7 @@ class AutoencoderKL(pl.LightningModule):
         return self.decoder.conv_out.weight
 
     @torch.no_grad()
-    def log_images(self, batch, only_inputs=False, **kwargs):
+    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
         log = dict()
         x = self.get_input(batch, self.image_key)
         x = x.to(self.device)

@@ -512,6 +181,15 @@ class AutoencoderKL(pl.LightningModule):
                 xrec = self.to_rgb(xrec)
             log["samples"] = self.decode(torch.randn_like(posterior.sample()))
             log["reconstructions"] = xrec
+            if log_ema or self.use_ema:
+                with self.ema_scope():
+                    xrec_ema, posterior_ema = self(x)
+                    if x.shape[1] > 3:
+                        # colorize with random projection
+                        assert xrec_ema.shape[1] > 3
+                        xrec_ema = self.to_rgb(xrec_ema)
+                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+                    log["reconstructions_ema"] = xrec_ema
         log["inputs"] = x
         return log
 

@@ -526,7 +204,7 @@ class AutoencoderKL(pl.LightningModule):
 
 class IdentityFirstStage(torch.nn.Module):
     def __init__(self, *args, vq_interface=False, **kwargs):
-        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
+        self.vq_interface = vq_interface
        super().__init__()
 
     def encode(self, x, *args, **kwargs):

@@ -542,3 +220,4 @@ class IdentityFirstStage(torch.nn.Module):
 
     def forward(self, x, *args, **kwargs):
         return x
+
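The hunks above (apparently from `ldm/models/autoencoder.py`) replace the `use_ema` flag with an `ema_decay` value and keep evaluating under an `ema_scope()` context manager that temporarily swaps in the averaged weights. Below is a minimal, self-contained sketch of that mechanism with a toy module; `TinyEMA` is a stand-in written for illustration, not the repo's `LitEma` class.

```
# Minimal EMA sketch (assumption: illustrative stand-in, not the repo's LitEma).
import copy
import torch
import torch.nn as nn
from contextlib import contextmanager

class TinyEMA:
    """Keeps an exponential moving average of a module's parameters."""
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.named_parameters()}

    def update(self, model: nn.Module):
        with torch.no_grad():
            for k, v in model.named_parameters():
                self.shadow[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)

    def copy_to(self, model: nn.Module):
        for k, v in model.named_parameters():
            v.data.copy_(self.shadow[k])

@contextmanager
def ema_scope(model: nn.Module, ema: TinyEMA):
    backup = copy.deepcopy(model.state_dict())  # stash training weights
    ema.copy_to(model)                          # evaluate with EMA weights
    try:
        yield
    finally:
        model.load_state_dict(backup)           # restore training weights on exit

net = nn.Linear(4, 4)
ema = TinyEMA(net, decay=0.99)
for _ in range(10):
    net.weight.data += 0.01  # pretend-training step
    ema.update(net)
with ema_scope(net, ema):
    print(net.weight.sum())  # inside the scope the smoothed weights are active
```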
@@ -3,10 +3,8 @@
 import torch
 import numpy as np
 from tqdm import tqdm
-from functools import partial
 
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
-    extract_into_tensor
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
 
 
 class DDIMSampler(object):

@@ -74,15 +72,24 @@ class DDIMSampler(object):
                x_T=None,
                log_every_t=100,
                unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,
-               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
 
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

@@ -107,6 +114,8 @@ class DDIMSampler(object):
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates
 

@@ -116,7 +125,8 @@ class DDIMSampler(object):
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:

@@ -145,12 +155,18 @@ class DDIMSampler(object):
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img
 
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
-                                      unconditional_conditioning=unconditional_conditioning)
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

@@ -164,20 +180,44 @@ class DDIMSampler(object):
    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device
 
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
-            e_t = self.model.apply_model(x, t, c)
+            model_output = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [torch.cat([
+                            unconditional_conditioning[k][i],
+                            c[k][i]]) for i in range(len(c[k]))]
+                    else:
+                        c_in[k] = torch.cat([
+                            unconditional_conditioning[k],
+                            c[k]])
+            elif isinstance(c, list):
+                c_in = list()
+                assert isinstance(unconditional_conditioning, list)
+                for i in range(len(c)):
+                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
 
        if score_corrector is not None:
-            assert self.model.parameterization == "eps"
+            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
 
        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas

@@ -191,9 +231,17 @@ class DDIMSampler(object):
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
 
        # current prediction for x_0
-        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
 
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature

@@ -202,6 +250,53 @@ class DDIMSampler(object):
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
 
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback: callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction

@@ -220,7 +315,7 @@ class DDIMSampler(object):
 
    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
-               use_original_steps=False):
+               use_original_steps=False, callback=None):
 
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

@@ -237,4 +332,5 @@ class DDIMSampler(object):
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
+            if callback: callback(i)
        return x_dec
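The new `parameterization == "v"` branches above follow the v-prediction formulation used by Stable Diffusion 2.x: the network predicts v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0, and eps / x0 are recovered from it. A small self-contained sketch of those two conversions follows; it illustrates the math with simplified shapes and is not the repo's exact `predict_eps_from_z_and_v` / `predict_start_from_z_and_v` code.

```
# v-prediction conversions (illustrative sketch, simplified broadcasting).
import torch

def eps_from_z_and_v(x_t, v, alpha_bar_t):
    # eps = sqrt(alpha_bar) * v + sqrt(1 - alpha_bar) * x_t
    return alpha_bar_t.sqrt() * v + (1 - alpha_bar_t).sqrt() * x_t

def x0_from_z_and_v(x_t, v, alpha_bar_t):
    # x0 = sqrt(alpha_bar) * x_t - sqrt(1 - alpha_bar) * v
    return alpha_bar_t.sqrt() * x_t - (1 - alpha_bar_t).sqrt() * v

x_t = torch.randn(1, 4, 8, 8)
v = torch.randn_like(x_t)
alpha_bar_t = torch.tensor(0.7)
print(eps_from_z_and_v(x_t, v, alpha_bar_t).shape, x0_from_z_and_v(x_t, v, alpha_bar_t).shape)
```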
(File diff suppressed because it is too large.)

@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler

(File diff suppressed because it is too large.)
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {
+    "eps": "noise",
+    "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.model = model
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+
+        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+        device = self.model.betas.device
+        if x_T is None:
+            img = torch.randn(size, device=device)
+        else:
+            img = x_T
+
+        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+        model_fn = model_wrapper(
+            lambda x, t, c: self.model.apply_model(x, t, c),
+            ns,
+            model_type=MODEL_TYPES[self.model.parameterization],
+            guidance_type="classifier-free",
+            condition=conditioning,
+            unconditional_condition=unconditional_conditioning,
+            guidance_scale=unconditional_guidance_scale,
+        )
+
+        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+        return x.to(device), None
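For context, a usage sketch of the new sampler follows. It mirrors how the DDIM and PLMS samplers are driven elsewhere in this codebase, but `model`, `c`, and `uc` are assumed to be a loaded LatentDiffusion model and its conditioning tensors, which are not defined here; treat this as a hedged illustration rather than the example's actual script.

```
# Hypothetical usage sketch for DPMSolverSampler (model, c, uc are assumed inputs).
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

sampler = DPMSolverSampler(model)            # model: a loaded LatentDiffusion instance
samples, _ = sampler.sample(
    S=25,                                    # DPM-Solver converges in relatively few steps
    batch_size=4,
    shape=(4, 64, 64),                       # latent C, H, W for 512x512 images
    conditioning=c,                          # text conditioning from the cond stage model
    unconditional_conditioning=uc,           # empty-prompt conditioning
    unconditional_guidance_scale=7.5,        # classifier-free guidance scale
)
images = model.decode_first_stage(samples)   # decode latents back to image space
```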
@@ -6,6 +6,7 @@ from tqdm import tqdm
 from functools import partial
 
 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
 
 
 class PLMSSampler(object):

@@ -77,6 +78,7 @@ class PLMSSampler(object):
                unconditional_guidance_scale=1.,
                unconditional_conditioning=None,
                # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
               **kwargs
               ):
        if conditioning is not None:

@@ -108,6 +110,7 @@ class PLMSSampler(object):
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates
 

@@ -117,7 +120,8 @@ class PLMSSampler(object):
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:

@@ -155,7 +159,8 @@ class PLMSSampler(object):
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
-                                      old_eps=old_eps, t_next=ts_next)
+                                      old_eps=old_eps, t_next=ts_next,
+                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:

@@ -172,7 +177,8 @@ class PLMSSampler(object):
    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device
 
        def get_model_output(x, t):

@@ -207,6 +213,8 @@ class PLMSSampler(object):
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+            if dynamic_threshold is not None:
+                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    return x[(...,) + (None,) * dims_to_append]
+
+
+def norm_thresholding(x0, value):
+    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+    return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+    # b c h w
+    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+    return x0 * (value / s)
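A quick, self-contained illustration of what `norm_thresholding` does: samples whose per-sample RMS norm exceeds `value` are scaled back down to that norm, while smaller ones pass through unchanged. The tensors below are arbitrary test data; the import path matches the one added to the PLMS sampler above.

```
# Demonstration of norm_thresholding on random data (not part of the diff).
import torch
from ldm.models.diffusion.sampling_util import norm_thresholding

x0 = torch.randn(2, 4, 8, 8) * 3.0                     # deliberately large-magnitude latents
rms_before = x0.pow(2).flatten(1).mean(1).sqrt()
x0_clipped = norm_thresholding(x0, value=1.0)
rms_after = x0_clipped.pow(2).flatten(1).mean(1).sqrt()
print(rms_before, rms_after)                           # rms_after is capped at ~1.0
```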
@ -4,24 +4,17 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import checkpoint
|
||||||
|
|
||||||
from torch.utils import checkpoint
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv
|
import xformers
|
||||||
FlASH_AVAILABLE = True
|
import xformers.ops
|
||||||
|
XFORMERS_IS_AVAILBLE = True
|
||||||
except:
|
except:
|
||||||
FlASH_AVAILABLE = False
|
XFORMERS_IS_AVAILBLE = False
|
||||||
|
|
||||||
USE_FLASH = False
|
|
||||||
|
|
||||||
|
|
||||||
def enable_flash_attention():
|
|
||||||
global USE_FLASH
|
|
||||||
USE_FLASH = True
|
|
||||||
if FlASH_AVAILABLE is False:
|
|
||||||
print("Please install flash attention to activate new attention kernel.\n" +
|
|
||||||
"Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'")
|
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
|
@ -93,25 +86,6 @@ def Normalize(in_channels):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
class LinearAttention(nn.Module):
|
|
||||||
def __init__(self, dim, heads=4, dim_head=32):
|
|
||||||
super().__init__()
|
|
||||||
self.heads = heads
|
|
||||||
hidden_dim = dim_head * heads
|
|
||||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
|
||||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
qkv = self.to_qkv(x)
|
|
||||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
|
||||||
k = k.softmax(dim=-1)
|
|
||||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
|
||||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
|
||||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
|
||||||
return self.to_out(out)
|
|
||||||
|
|
||||||
|
|
||||||
class SpatialSelfAttention(nn.Module):
|
class SpatialSelfAttention(nn.Module):
|
||||||
def __init__(self, in_channels):
|
def __init__(self, in_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -184,85 +158,111 @@ class CrossAttention(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
dim_head = q.shape[-1] / self.heads
|
|
||||||
|
|
||||||
if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \
|
|
||||||
dim_head <= 128 and (dim_head % 8) == 0:
|
|
||||||
# print("in flash")
|
|
||||||
if q.shape[1] == k.shape[1]:
|
|
||||||
out = self._flash_attention_qkv(q, k, v)
|
|
||||||
else:
|
|
||||||
out = self._flash_attention_q_kv(q, k, v)
|
|
||||||
else:
|
|
||||||
out = self._native_attention(q, k, v, self.heads, mask)
|
|
||||||
|
|
||||||
return self.to_out(out)
|
|
||||||
|
|
||||||
def _native_attention(self, q, k, v, h, mask):
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
del q, k
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
mask = rearrange(mask, 'b ... -> b (...)')
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
max_neg_value = -torch.finfo(sim.dtype).max
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
sim.masked_fill_(~mask, max_neg_value)
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
# attention, what we cannot get enough of
|
|
||||||
out = sim.softmax(dim=-1)
|
|
||||||
out = einsum('b i j, b j d -> b i d', out, v)
|
|
||||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _flash_attention_qkv(self, q, k, v):
|
# attention, what we cannot get enough of
|
||||||
qkv = torch.stack([q, k, v], dim=2)
|
sim = sim.softmax(dim=-1)
|
||||||
b = qkv.shape[0]
|
|
||||||
n = qkv.shape[1]
|
out = einsum('b i j, b j d -> b i d', sim, v)
|
||||||
qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads)
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
out = flash_attention_qkv(qkv, self.scale, b, n)
|
return self.to_out(out)
|
||||||
out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
|
|
||||||
return out
|
|
||||||
|
class MemoryEfficientCrossAttention(nn.Module):
|
||||||
def _flash_attention_q_kv(self, q, k, v):
|
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||||
kv = torch.stack([k, v], dim=2)
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||||
b = q.shape[0]
|
super().__init__()
|
||||||
q_seqlen = q.shape[1]
|
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||||
kv_seqlen = kv.shape[1]
|
f"{heads} heads.")
|
||||||
q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads)
|
inner_dim = dim_head * heads
|
||||||
kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads)
|
context_dim = default(context_dim, query_dim)
|
||||||
out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen)
|
|
||||||
out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
|
self.heads = heads
|
||||||
return out
|
self.dim_head = dim_head
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||||
|
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
|
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||||
|
self.attention_op: Optional[Any] = None
|
||||||
|
|
||||||
|
def forward(self, x, context=None, mask=None):
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
b, _, _ = q.shape
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.unsqueeze(3)
|
||||||
|
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||||
|
.contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
# actually compute the attention, what we cannot get enough of
|
||||||
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
raise NotImplementedError
|
||||||
|
out = (
|
||||||
|
out.unsqueeze(0)
|
||||||
|
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
||||||
|
)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False):
|
ATTENTION_MODES = {
|
||||||
|
"softmax": CrossAttention, # vanilla attention
|
||||||
|
"softmax-xformers": MemoryEfficientCrossAttention
|
||||||
|
}
|
||||||
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||||
|
disable_self_attn=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
||||||
|
assert attn_mode in self.ATTENTION_MODES
|
||||||
|
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||||
|
self.disable_self_attn = disable_self_attn
|
||||||
|
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||||
|
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
||||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
||||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
self.norm2 = nn.LayerNorm(dim)
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
self.norm3 = nn.LayerNorm(dim)
|
self.norm3 = nn.LayerNorm(dim)
|
||||||
self.use_checkpoint = use_checkpoint
|
self.checkpoint = checkpoint
|
||||||
|
|
||||||
def forward(self, x, context=None):
|
def forward(self, x, context=None):
|
||||||
|
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||||
|
|
||||||
if self.use_checkpoint:
|
|
||||||
return checkpoint(self._forward, x, context)
|
|
||||||
else:
|
|
||||||
return self._forward(x, context)
|
|
||||||
|
|
||||||
def _forward(self, x, context=None):
|
def _forward(self, x, context=None):
|
||||||
x = self.attn1(self.norm1(x)) + x
|
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
||||||
x = self.attn2(self.norm2(x), context=context) + x
|
x = self.attn2(self.norm2(x), context=context) + x
|
||||||
x = self.ff(self.norm3(x)) + x
|
x = self.ff(self.norm3(x)) + x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
 class SpatialTransformer(nn.Module):
@@ -272,43 +272,60 @@ class SpatialTransformer(nn.Module):
     and reshape to b, t, d.
     Then apply standard transformer action.
     Finally, reshape to image
+    NEW: use_linear for more efficiency instead of the 1x1 convs
     """
     def __init__(self, in_channels, n_heads, d_head,
-                 depth=1, dropout=0., context_dim=None, use_checkpoint=False):
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False, use_linear=False,
+                 use_checkpoint=True):
         super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = Normalize(in_channels)
+        if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)

         self.transformer_blocks = nn.ModuleList(
-            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint)
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
                 for d in range(depth)]
         )
+        if not use_linear:
             self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                   in_channels,
                                                   kernel_size=1,
                                                   stride=1,
                                                   padding=0))
+        else:
+            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+        self.use_linear = use_linear

     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
         b, c, h, w = x.shape
         x_in = x
         x = self.norm(x)
-        x = self.proj_in(x)
-        x = rearrange(x, 'b c h w -> b (h w) c')
-        x = x.contiguous()
-        for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
-        x = x.contiguous()
-        x = self.proj_out(x)
-        return x + x_in
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i])
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in

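A small sketch of the new `use_linear` switch; the shapes below are assumptions. Both variants keep the spatial layout of the feature map, only `proj_in`/`proj_out` change from 1x1 convolutions to `nn.Linear` layers applied after flattening.

```
import torch
from ldm.modules.attention import SpatialTransformer

feat = torch.randn(1, 320, 32, 32)       # hypothetical UNet feature map
ctx = torch.randn(1, 77, 1024)           # wrapped into a per-depth list by the new forward

conv_variant = SpatialTransformer(320, 8, 40, depth=1, context_dim=1024, use_linear=False)
linear_variant = SpatialTransformer(320, 8, 40, depth=1, context_dim=1024, use_linear=True)

assert conv_variant(feat, context=ctx).shape == feat.shape
assert linear_variant(feat, context=ctx).shape == feat.shape
```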
@@ -4,9 +4,22 @@ import torch
 import torch.nn as nn
 import numpy as np
 from einops import rearrange
+from typing import Optional, Any
+
-from ldm.util import instantiate_from_config
-from ldm.modules.attention import LinearAttention
+try:
+    from lightning.pytorch.utilities import rank_zero_info
+except:
+    from pytorch_lightning.utilities import rank_zero_info
+
+from ldm.modules.attention import MemoryEfficientCrossAttention
+
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+    print("No module 'xformers'. Proceeding without it.")
+

 def get_timestep_embedding(timesteps, embedding_dim):
@@ -141,12 +154,6 @@ class ResnetBlock(nn.Module):
         return x+h

-
-class LinAttnBlock(LinearAttention):
-    """to match AttnBlock usage"""
-    def __init__(self, in_channels):
-        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
-

 class AttnBlock(nn.Module):
     def __init__(self, in_channels):
         super().__init__()
@@ -174,7 +181,6 @@ class AttnBlock(nn.Module):
                                         stride=1,
                                         padding=0)
-
     def forward(self, x):
         h_ = x
         h_ = self.norm(h_)
@@ -201,21 +207,100 @@ class AttnBlock(nn.Module):

         return x+h_

+class MemoryEfficientAttnBlock(nn.Module):
+    """
+    Uses xformers efficient implementation,
+    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    Note: this is a single-head self-attention operation
+    """
+    #
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None):
+        b, c, h, w = x.shape
+        x = rearrange(x, 'b c h w -> b (h w) c')
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+        return x + out
+
+
-def make_attn(in_channels, attn_type="vanilla"):
-    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
-    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+        attn_type = "vanilla-xformers"
+    rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
+        assert attn_kwargs is None
         return AttnBlock(in_channels)
+    elif attn_type == "vanilla-xformers":
+        rank_zero_info(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        return MemoryEfficientAttnBlock(in_channels)
+    elif type == "memory-efficient-cross-attn":
+        attn_kwargs["query_dim"] = in_channels
+        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
     elif attn_type == "none":
         return nn.Identity(in_channels)
     else:
-        return LinAttnBlock(in_channels)
+        raise NotImplementedError()

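The new block above delegates to `xformers.ops.memory_efficient_attention`. The sketch below, which assumes xformers and a CUDA device are available, illustrates that it computes ordinary scaled-dot-product attention, just without materialising the full attention matrix; the tolerance is illustrative.

```
import torch
import xformers.ops

B, L, C = 2, 1024, 64                                    # hypothetical single-head shapes
q, k, v = (torch.randn(B, L, C, device="cuda", dtype=torch.float16) for _ in range(3))

out_xf = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

# Reference softmax attention with the default 1/sqrt(C) scaling.
attn = torch.softmax(q.float() @ k.float().transpose(-2, -1) / C ** 0.5, dim=-1)
out_ref = (attn @ v.float()).half()
print(torch.allclose(out_xf, out_ref, atol=1e-2))
```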
-class temb_module(nn.Module):
-    def __init__(self):
-        super().__init__()
-        pass
-

 class Model(nn.Module):
     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
@@ -233,8 +318,7 @@ class Model(nn.Module):
         self.use_timestep = use_timestep
         if self.use_timestep:
             # timestep embedding
-            # self.temb = nn.Module()
-            self.temb = temb_module()
+            self.temb = nn.Module()
             self.temb.dense = nn.ModuleList([
                 torch.nn.Linear(self.ch,
                                 self.temb_ch),
@@ -265,8 +349,7 @@ class Model(nn.Module):
                 block_in = block_out
                 if curr_res in attn_resolutions:
                     attn.append(make_attn(block_in, attn_type=attn_type))
-            # down = nn.Module()
-            down = Down_module()
+            down = nn.Module()
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions-1:
@@ -275,8 +358,7 @@ class Model(nn.Module):
             self.down.append(down)

         # middle
-        # self.mid = nn.Module()
-        self.mid = Mid_module()
+        self.mid = nn.Module()
         self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
@@ -304,8 +386,7 @@ class Model(nn.Module):
                 block_in = block_out
                 if curr_res in attn_resolutions:
                     attn.append(make_attn(block_in, attn_type=attn_type))
-            # up = nn.Module()
-            up = Up_module()
+            up = nn.Module()
             up.block = block
             up.attn = attn
             if i_level != 0:
@@ -372,21 +453,6 @@ class Model(nn.Module):
     def get_last_layer(self):
         return self.conv_out.weight

-
-class Down_module(nn.Module):
-    def __init__(self):
-        super().__init__()
-        pass
-
-class Up_module(nn.Module):
-    def __init__(self):
-        super().__init__()
-        pass
-
-class Mid_module(nn.Module):
-    def __init__(self):
-        super().__init__()
-        pass
-

 class Encoder(nn.Module):
     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
@@ -426,8 +492,7 @@ class Encoder(nn.Module):
                 block_in = block_out
                 if curr_res in attn_resolutions:
                     attn.append(make_attn(block_in, attn_type=attn_type))
-            # down = nn.Module()
-            down = Down_module()
+            down = nn.Module()
             down.block = block
             down.attn = attn
             if i_level != self.num_resolutions-1:
@@ -436,8 +501,7 @@ class Encoder(nn.Module):
             self.down.append(down)

         # middle
-        # self.mid = nn.Module()
-        self.mid = Mid_module()
+        self.mid = nn.Module()
         self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
@@ -505,7 +569,7 @@ class Decoder(nn.Module):
         block_in = ch*ch_mult[self.num_resolutions-1]
         curr_res = resolution // 2**(self.num_resolutions-1)
         self.z_shape = (1,z_channels,curr_res,curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(
+        rank_zero_info("Working with z of shape {} = {} dimensions.".format(
             self.z_shape, np.prod(self.z_shape)))

         # z to block_in
@@ -516,8 +580,7 @@ class Decoder(nn.Module):
                          padding=1)

         # middle
-        # self.mid = nn.Module()
-        self.mid = Mid_module()
+        self.mid = nn.Module()
         self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
@@ -542,8 +605,7 @@ class Decoder(nn.Module):
                 block_in = block_out
                 if curr_res in attn_resolutions:
                     attn.append(make_attn(block_in, attn_type=attn_type))
-            # up = nn.Module()
-            up = Up_module()
+            up = nn.Module()
             up.block = block
             up.attn = attn
             if i_level != 0:
@@ -758,7 +820,7 @@ class Upsampler(nn.Module):
         assert out_size >= in_size
         num_blocks = int(np.log2(out_size//in_size))+1
         factor_up = 1.+ (out_size % in_size)
-        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+        rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
         self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
                                        out_channels=in_channels)
         self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
@@ -777,7 +839,7 @@ class Resize(nn.Module):
         self.with_conv = learned
         self.mode = mode
         if self.with_conv:
-            print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+            rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
             raise NotImplementedError()
         assert in_channels is not None
         # no asymmetric padding in torch conv, must do it ourselves
@@ -793,70 +855,3 @@ class Resize(nn.Module):
         else:
             x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
         return x

-class FirstStagePostProcessor(nn.Module):
-
-    def __init__(self, ch_mult:list, in_channels,
-                 pretrained_model:nn.Module=None,
-                 reshape=False,
-                 n_channels=None,
-                 dropout=0.,
-                 pretrained_config=None):
-        super().__init__()
-        if pretrained_config is None:
-            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
-            self.pretrained_model = pretrained_model
-        else:
-            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
-            self.instantiate_pretrained(pretrained_config)
-
-        self.do_reshape = reshape
-
-        if n_channels is None:
-            n_channels = self.pretrained_model.encoder.ch
-
-        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
-        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
-                              stride=1,padding=1)
-
-        blocks = []
-        downs = []
-        ch_in = n_channels
-        for m in ch_mult:
-            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
-            ch_in = m * n_channels
-            downs.append(Downsample(ch_in, with_conv=False))
-
-        self.model = nn.ModuleList(blocks)
-        self.downsampler = nn.ModuleList(downs)
-
-    def instantiate_pretrained(self, config):
-        model = instantiate_from_config(config)
-        self.pretrained_model = model.eval()
-        # self.pretrained_model.train = False
-        for param in self.pretrained_model.parameters():
-            param.requires_grad = False
-
-    @torch.no_grad()
-    def encode_with_pretrained(self,x):
-        c = self.pretrained_model.encode(x)
-        if isinstance(c, DiagonalGaussianDistribution):
-            c = c.mode()
-        return c
-
-    def forward(self,x):
-        z_fs = self.encode_with_pretrained(x)
-        z = self.proj_norm(z_fs)
-        z = self.proj(z)
-        z = nonlinearity(z)
-
-        for submodel, downmodel in zip(self.model,self.downsampler):
-            z = submodel(z,temb=None)
-            z = downmodel(z)
-
-        if self.do_reshape:
-            z = rearrange(z,'b c h w -> b (h w) c')
-        return z

@@ -1,16 +1,13 @@
 from abc import abstractmethod
-from functools import partial
 import math
-from typing import Iterable

 import numpy as np
-import torch
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils import checkpoint

 from ldm.modules.diffusionmodules.util import (
+    checkpoint,
     conv_nd,
     linear,
     avg_pool_nd,
@@ -19,13 +16,11 @@ from ldm.modules.diffusionmodules.util import (
     timestep_embedding,
 )
 from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists


 # dummy replace
 def convert_module_to_f16(x):
-    # for n,p in x.named_parameter():
-    #     print(f"convert module {n} to_f16")
-    #     p.data = p.data.half()
     pass

 def convert_module_to_f32(x):
@@ -251,10 +246,9 @@ class ResBlock(TimestepBlock):
         :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
-        if self.use_checkpoint:
-            return checkpoint(self._forward, x, emb)
-        else:
-            return self._forward(x, emb)
+        return checkpoint(
+            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+        )

     def _forward(self, x, emb):
@@ -317,11 +311,8 @@ class AttentionBlock(nn.Module):
         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

     def forward(self, x):
-        if self.use_checkpoint:
-            return checkpoint(self._forward, x)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
         #return pt_checkpoint(self._forward, x)  # pytorch
-        else:
-            return self._forward(x)

     def _forward(self, x):
         b, c, *spatial = x.shape

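These hunks move ResBlock and AttentionBlock back onto the repo's own `checkpoint` helper instead of `torch.utils.checkpoint`. A small sketch of the two calling conventions, under the assumption that both helpers are importable as below.

```
import torch
from torch.utils.checkpoint import checkpoint as pt_checkpoint
from ldm.modules.diffusionmodules.util import checkpoint as ldm_checkpoint

def f(a, b):
    return (a * b).sum()

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

y1 = pt_checkpoint(f, a, b)               # PyTorch: tensors passed positionally
y2 = ldm_checkpoint(f, (a, b), (), True)  # repo helper: (func, inputs, params, enable_flag)
assert torch.allclose(y1, y2)
```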
@@ -474,7 +465,10 @@ class UNetModel(nn.Module):
         context_dim=None,                 # custom transformer support
         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
-        from_pretrained: str=None
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -499,7 +493,24 @@ class UNetModel(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
@@ -520,7 +531,13 @@ class UNetModel(nn.Module):
         )

         if self.num_classes is not None:
-            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            else:
+                raise ValueError()

         self.input_blocks = nn.ModuleList(
             [
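A hypothetical constructor call illustrating the new per-level options. The numbers echo typical Stable Diffusion 2.x UNet settings but are not taken from this PR's configs.

```
from ldm.modules.diffusionmodules.openaimodel import UNetModel

model = UNetModel(
    image_size=64, in_channels=4, out_channels=4, model_channels=320,
    num_res_blocks=[2, 2, 2, 2],             # may now be a per-level list instead of one int
    attention_resolutions=[4, 2, 1],
    channel_mult=[1, 2, 4, 4],
    num_head_channels=64,
    use_spatial_transformer=True,
    use_linear_in_transformer=True,          # new flag, forwarded to SpatialTransformer
    transformer_depth=1,
    context_dim=1024,
    legacy=False,
)
```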
@@ -534,7 +551,7 @@ class UNetModel(nn.Module):
         ch = model_channels
         ds = 1
         for level, mult in enumerate(channel_mult):
-            for _ in range(num_res_blocks):
+            for nr in range(self.num_res_blocks[level]):
                 layers = [
                     ResBlock(
                         ch,
@@ -556,17 +573,25 @@ class UNetModel(nn.Module):
                     if legacy:
                         #num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_checkpoint=use_checkpoint,
-                        )
-                    )
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
                 input_block_chans.append(ch)
@@ -618,8 +643,10 @@ class UNetModel(nn.Module):
                 num_heads=num_heads,
                 num_head_channels=dim_head,
                 use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(
-                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint
             ),
             ResBlock(
                 ch,
@@ -634,7 +661,7 @@ class UNetModel(nn.Module):

         self.output_blocks = nn.ModuleList([])
         for level, mult in list(enumerate(channel_mult))[::-1]:
-            for i in range(num_res_blocks + 1):
+            for i in range(self.num_res_blocks[level] + 1):
                 ich = input_block_chans.pop()
                 layers = [
                     ResBlock(
@@ -657,18 +684,26 @@ class UNetModel(nn.Module):
                     if legacy:
                         #num_heads = 1
                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads_upsample,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
-                        )
-                    )
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
-                if level and i == num_res_blocks:
+                if level and i == self.num_res_blocks[level]:
                     out_ch = ch
                     layers.append(
                         ResBlock(
@@ -699,188 +734,6 @@ class UNetModel(nn.Module):
             conv_nd(dims, model_channels, n_embed, 1),
             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
         )
-        # if use_fp16:
-        #     self.convert_to_fp16()
-        from diffusers.modeling_utils import load_state_dict
-        if from_pretrained is not None:
-            state_dict = load_state_dict(from_pretrained)
-            self._load_pretrained_model(state_dict)
-
-    def _input_blocks_mapping(self, input_dict):
-        res_dict = {}
-        for key_, value_ in input_dict.items():
-            id_0 = int(key_[13])
-            if "resnets" in key_:
-                id_1 = int(key_[23])
-                target_id = 3 * id_0 + 1 + id_1
-                post_fix = key_[25:].replace('time_emb_proj', 'emb_layers.1')\
-                    .replace('norm1', 'in_layers.0')\
-                    .replace('norm2', 'out_layers.0')\
-                    .replace('conv1', 'in_layers.2')\
-                    .replace('conv2', 'out_layers.3')\
-                    .replace('conv_shortcut', 'skip_connection')
-                res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_
-            elif "attentions" in key_:
-                id_1 = int(key_[26])
-                target_id = 3 * id_0 + 1 + id_1
-                post_fix = key_[28:]
-                res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_
-            elif "downsamplers" in key_:
-                post_fix = key_[35:]
-                target_id = 3 * (id_0 + 1)
-                res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_
-        return res_dict
-
-    def _mid_blocks_mapping(self, mid_dict):
-        res_dict = {}
-        for key_, value_ in mid_dict.items():
-            if "resnets" in key_:
-                temp_key_ = key_.replace('time_emb_proj', 'emb_layers.1') \
-                    .replace('norm1', 'in_layers.0') \
-                    .replace('norm2', 'out_layers.0') \
-                    .replace('conv1', 'in_layers.2') \
-                    .replace('conv2', 'out_layers.3') \
-                    .replace('conv_shortcut', 'skip_connection')\
-                    .replace('middle_block.resnets.0', 'middle_block.0')\
-                    .replace('middle_block.resnets.1', 'middle_block.2')
-                res_dict[temp_key_] = value_
-            elif "attentions" in key_:
-                res_dict[key_.replace('attentions.0', '1')] = value_
-        return res_dict
-
-    def _other_blocks_mapping(self, other_dict):
-        res_dict = {}
-        for key_, value_ in other_dict.items():
-            tmp_key = key_.replace('conv_in', 'input_blocks.0.0')\
-                .replace('time_embedding.linear_1', 'time_embed.0')\
-                .replace('time_embedding.linear_2', 'time_embed.2')\
-                .replace('conv_norm_out', 'out.0')\
-                .replace('conv_out', 'out.2')
-            res_dict[tmp_key] = value_
-        return res_dict
-
-    def _output_blocks_mapping(self, output_dict):
-        res_dict = {}
-        for key_, value_ in output_dict.items():
-            id_0 = int(key_[14])
-            if "resnets" in key_:
-                id_1 = int(key_[24])
-                target_id = 3 * id_0 + id_1
-                post_fix = key_[26:].replace('time_emb_proj', 'emb_layers.1') \
-                    .replace('norm1', 'in_layers.0') \
-                    .replace('norm2', 'out_layers.0') \
-                    .replace('conv1', 'in_layers.2') \
-                    .replace('conv2', 'out_layers.3') \
-                    .replace('conv_shortcut', 'skip_connection')
-                res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_
-            elif "attentions" in key_:
-                id_1 = int(key_[27])
-                target_id = 3 * id_0 + id_1
-                post_fix = key_[29:]
-                res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_
-            elif "upsamplers" in key_:
-                post_fix = key_[34:]
-                target_id = 3 * (id_0 + 1) - 1
-                mid_str = '.2.conv.' if target_id != 2 else '.1.conv.'
-                res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_
-        return res_dict
-
-    def _state_key_mapping(self, state_dict: dict):
-        import re
-        res_dict = {}
-        input_dict = {}
-        mid_dict = {}
-        output_dict = {}
-        other_dict = {}
-        for key_, value_ in state_dict.items():
-            if "down_blocks" in key_:
-                input_dict[key_.replace('down_blocks', 'input_blocks')] = value_
-            elif "up_blocks" in key_:
-                output_dict[key_.replace('up_blocks', 'output_blocks')] = value_
-            elif "mid_block" in key_:
-                mid_dict[key_.replace('mid_block', 'middle_block')] = value_
-            else:
-                other_dict[key_] = value_
-
-        input_dict = self._input_blocks_mapping(input_dict)
-        output_dict = self._output_blocks_mapping(output_dict)
-        mid_dict = self._mid_blocks_mapping(mid_dict)
-        other_dict = self._other_blocks_mapping(other_dict)
-        # key_list = state_dict.keys()
-        # key_str = " ".join(key_list)
-
-        # for key_, val_ in state_dict.items():
-        #     key_ = key_.replace("down_blocks", "input_blocks")\
-        #         .replace("up_blocks", 'output_blocks')
-        #     res_dict[key_] = val_
-        res_dict.update(input_dict)
-        res_dict.update(output_dict)
-        res_dict.update(mid_dict)
-        res_dict.update(other_dict)
-
-        return res_dict
-
-    def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
-        state_dict = self._state_key_mapping(state_dict)
-        model_state_dict = self.state_dict()
-        loaded_keys = [k for k in state_dict.keys()]
-        expected_keys = list(model_state_dict.keys())
-        original_loaded_keys = loaded_keys
-        missing_keys = list(set(expected_keys) - set(loaded_keys))
-        unexpected_keys = list(set(loaded_keys) - set(expected_keys))
-
-        def _find_mismatched_keys(
-            state_dict,
-            model_state_dict,
-            loaded_keys,
-            ignore_mismatched_sizes,
-        ):
-            mismatched_keys = []
-            if ignore_mismatched_sizes:
-                for checkpoint_key in loaded_keys:
-                    model_key = checkpoint_key
-
-                    if (
-                        model_key in model_state_dict
-                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
-                    ):
-                        mismatched_keys.append(
-                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                        )
-                        del state_dict[checkpoint_key]
-            return mismatched_keys
-        if state_dict is not None:
-            # Whole checkpoint
-            mismatched_keys = _find_mismatched_keys(
-                state_dict,
-                model_state_dict,
-                original_loaded_keys,
-                ignore_mismatched_sizes,
-            )
-            error_msgs = self._load_state_dict_into_model(state_dict)
-        return missing_keys, unexpected_keys, mismatched_keys, error_msgs
-
-    def _load_state_dict_into_model(self, state_dict):
-        # Convert old format to new format if needed from a PyTorch state_dict
-        # copy state_dict so _load_from_state_dict can modify it
-        state_dict = state_dict.copy()
-        error_msgs = []
-
-        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-        # so we need to apply the function recursively.
-        def load(module: torch.nn.Module, prefix=""):
-            args = (state_dict, prefix, {}, True, [], [], error_msgs)
-            module._load_from_state_dict(*args)
-
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + ".")
-
-        load(self)
-
-        return error_msgs
-
     def convert_to_fp16(self):
         """
@@ -912,10 +765,11 @@ class UNetModel(nn.Module):
         ), "must specify y if and only if the model is class-conditional"
         hs = []
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        t_emb = t_emb.type(self.dtype)
         emb = self.time_embed(t_emb)

         if self.num_classes is not None:
-            assert y.shape == (x.shape[0],)
+            assert y.shape[0] == x.shape[0]
             emb = emb + self.label_emb(y)

         h = x.type(self.dtype)
@@ -926,227 +780,8 @@ class UNetModel(nn.Module):
         for module in self.output_blocks:
             h = th.cat([h, hs.pop()], dim=1)
             h = module(h, emb, context)
-        h = h.type(self.dtype)
+        h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)
         else:
             return self.out(h)

-class EncoderUNetModel(nn.Module):
-    """
-    The half UNet model with attention and timestep embedding.
-    For usage, see UNet.
-    """
-
-    def __init__(
-        self,
-        image_size,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
-        dropout=0,
-        channel_mult=(1, 2, 4, 8),
-        conv_resample=True,
-        dims=2,
-        use_checkpoint=False,
-        use_fp16=False,
-        num_heads=1,
-        num_head_channels=-1,
-        num_heads_upsample=-1,
-        use_scale_shift_norm=False,
-        resblock_updown=False,
-        use_new_attention_order=False,
-        pool="adaptive",
-        *args,
-        **kwargs
-    ):
-        super().__init__()
-
-        if num_heads_upsample == -1:
-            num_heads_upsample = num_heads
-
-        self.in_channels = in_channels
-        self.model_channels = model_channels
-        self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
-        self.attention_resolutions = attention_resolutions
-        self.dropout = dropout
-        self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
-        self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.num_heads = num_heads
-        self.num_head_channels = num_head_channels
-        self.num_heads_upsample = num_heads_upsample
-
-        time_embed_dim = model_channels * 4
-        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
-            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
-        )
-
-        self.input_blocks = nn.ModuleList(
-            [
-                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
-                )
-            ]
-        )
-        self._feature_size = model_channels
-        input_block_chans = [model_channels]
-        ch = model_channels
-        ds = 1
-        for level, mult in enumerate(channel_mult):
-            for _ in range(num_res_blocks):
-                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
-                ]
-                ch = mult * model_channels
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=num_head_channels,
-                            use_new_attention_order=use_new_attention_order,
-                        )
-                    )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
-                input_block_chans.append(ch)
-            if level != len(channel_mult) - 1:
-                out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
-                        if resblock_updown
-                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
-                        )
-                    )
-                )
-                ch = out_ch
-                input_block_chans.append(ch)
-                ds *= 2
-                self._feature_size += ch
-
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=num_head_channels,
-                use_new_attention_order=use_new_attention_order,
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-        )
-        self._feature_size += ch
-        self.pool = pool
-        if pool == "adaptive":
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                nn.AdaptiveAvgPool2d((1, 1)),
-                zero_module(conv_nd(dims, ch, out_channels, 1)),
-                nn.Flatten(),
-            )
-        elif pool == "attention":
-            assert num_head_channels != -1
-            self.out = nn.Sequential(
-                normalization(ch),
-                nn.SiLU(),
-                AttentionPool2d(
-                    (image_size // ds), ch, num_head_channels, out_channels
-                ),
-            )
-        elif pool == "spatial":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                nn.ReLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        elif pool == "spatial_v2":
-            self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
-                normalization(2048),
-                nn.SiLU(),
-                nn.Linear(2048, self.out_channels),
-            )
-        else:
-            raise NotImplementedError(f"Unexpected {pool} pooling")
-
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-
-    def forward(self, x, timesteps):
-        """
-        Apply the model to an input batch.
-        :param x: an [N x C x ...] Tensor of inputs.
-        :param timesteps: a 1-D batch of timesteps.
-        :return: an [N x K] Tensor of outputs.
-        """
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-
-        results = []
-        h = x.type(self.dtype)
-        for module in self.input_blocks:
-            h = module(h, emb)
-            if self.pool.startswith("spatial"):
-                results.append(h.type(x.dtype).mean(dim=(2, 3)))
-        h = self.middle_block(h, emb)
-        if self.pool.startswith("spatial"):
-            results.append(h.type(x.dtype).mean(dim=(2, 3)))
-            h = th.cat(results, axis=-1)
-            return self.out(h)
-        else:
-            h = h.type(self.dtype)
-            return self.out(h)

@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+    # for concatenating a downsampled image to the latent representation
+    def __init__(self, noise_schedule_config=None):
+        super(AbstractLowScaleModel, self).__init__()
+        if noise_schedule_config is not None:
+            self.register_schedule(**noise_schedule_config)
+
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def forward(self, x):
+        return x, None
+
+    def decode(self, x):
+        return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+    # no noise level conditioning
+    def __init__(self):
+        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+        self.max_noise_level = 0
+
+    def forward(self, x):
+        # fix to constant noise level
+        return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+        super().__init__(noise_schedule_config=noise_schedule_config)
+        self.max_noise_level = max_noise_level
+
+    def forward(self, x, noise_level=None):
+        if noise_level is None:
+            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        else:
+            assert isinstance(noise_level, torch.Tensor)
+        z = self.q_sample(x, noise_level)
+        return z, noise_level

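The new low-scale models implement the standard forward-diffusion identity x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps via q_sample. The sketch below assumes the file lands at ldm/modules/diffusionmodules/upscaling.py, as in upstream Stable Diffusion 2; the schedule values are illustrative, not taken from a shipped config.

```
import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config=dict(beta_schedule="linear", timesteps=1000,
                               linear_start=1e-4, linear_end=2e-2),
    max_noise_level=350,
)
x_low = torch.randn(2, 4, 64, 64)        # hypothetical low-res conditioning latents
z, noise_level = aug(x_low)              # noised input plus the sampled per-item timestep
print(z.shape, noise_level)
```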
@@ -122,7 +122,9 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.run_function = run_function
         ctx.input_tensors = list(args[:length])
         ctx.input_params = list(args[length:])
+        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+                                   "dtype": torch.get_autocast_gpu_dtype(),
+                                   "cache_enabled": torch.is_autocast_cache_enabled()}
         with torch.no_grad():
             output_tensors = ctx.run_function(*ctx.input_tensors)
         return output_tensors
@@ -130,7 +132,8 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *output_grads):
         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad():
+        with torch.enable_grad(), \
+                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
             # Fixes a bug where the first op in run_function modifies the
             # Tensor storage in place, which is not allowed for detach()'d
             # Tensors.
@@ -148,7 +151,7 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + input_grads


-def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     Create sinusoidal timestep embeddings.
     :param timesteps: a 1-D Tensor of N indices, one per batch element.
@@ -168,10 +171,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_
         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     else:
         embedding = repeat(timesteps, 'b -> b d', d=dim)
-    if use_fp16:
-        return embedding.half()
-    else:
-        return embedding
+    return embedding


 def zero_module(module):
@@ -199,16 +199,14 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))


-def normalization(channels, precision=16):
+def normalization(channels):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    if precision == 16:
-        return GroupNorm16(16, channels)
-    else:
-        return GroupNorm32(32, channels)
+    return nn.GroupNorm(16, channels)
+    # return GroupNorm32(32, channels)


 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
@@ -216,9 +214,6 @@ class SiLU(nn.Module):
     def forward(self, x):
         return x * torch.sigmoid(x)

-class GroupNorm16(nn.GroupNorm):
-    def forward(self, x):
-        return super().forward(x.half()).type(x.dtype)

 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):

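With the use_fp16 branch gone, timestep_embedding always returns the float dtype it computes in; the caller now casts it via t_emb.type(self.dtype), as shown in the UNetModel.forward hunk above. A quick check of the returned dtype and shape, assuming the import path below.

```
import torch
from ldm.modules.diffusionmodules.util import timestep_embedding

t = torch.arange(4)
emb = timestep_embedding(t, dim=320)
print(emb.dtype, emb.shape)              # torch.float32 torch.Size([4, 320])
```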
@@ -10,24 +10,28 @@ class LitEma(nn.Module):

         self.m_name2s_name = {}
         self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
-        self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
-                             else torch.tensor(-1,dtype=torch.int))
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+                             else torch.tensor(-1, dtype=torch.int))

         for name, p in model.named_parameters():
             if p.requires_grad:
-                #remove as '.'-character is not allowed in buffers
-                s_name = name.replace('.','')
-                self.m_name2s_name.update({name:s_name})
-                self.register_buffer(s_name,p.clone().detach().data)
+                # remove as '.'-character is not allowed in buffers
+                s_name = name.replace('.', '')
+                self.m_name2s_name.update({name: s_name})
+                self.register_buffer(s_name, p.clone().detach().data)

         self.collected_params = []

-    def forward(self,model):
+    def reset_num_updates(self):
+        del self.num_updates
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+    def forward(self, model):
         decay = self.decay

         if self.num_updates >= 0:
             self.num_updates += 1
-            decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

         one_minus_decay = 1.0 - decay

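For reference, the EMA update that LitEma applies per parameter, written out as a plain function; the decay value and tensor sizes below are illustrative.

```
import torch

def ema_step(shadow, param, decay, num_updates):
    # warm-up: the effective decay grows as min(decay, (1 + n) / (10 + n))
    decay_t = min(decay, (1 + num_updates) / (10 + num_updates))
    return shadow - (1.0 - decay_t) * (shadow - param)

shadow, param = torch.zeros(3), torch.ones(3)
for n in range(1, 6):
    shadow = ema_step(shadow, param, decay=0.9999, num_updates=n)
print(shadow)                            # moves quickly toward the live weights early on
```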
@@ -1,15 +1,11 @@
-import types
-
 import torch
 import torch.nn as nn
-from functools import partial
-import clip
-from einops import rearrange, repeat
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
-import kornia
-from transformers.models.clip.modeling_clip import CLIPTextTransformer
-
-from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+import open_clip
+from ldm.util import default, count_params


 class AbstractEncoder(nn.Module):
@@ -20,169 +16,66 @@ class AbstractEncoder(nn.Module):
         raise NotImplementedError


+class IdentityEncoder(AbstractEncoder):
+
+    def encode(self, x):
+        return x
+
+
 class ClassEmbedder(nn.Module):
-    def __init__(self, embed_dim, n_classes=1000, key='class'):
+    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
         super().__init__()
         self.key = key
         self.embedding = nn.Embedding(n_classes, embed_dim)
+        self.n_classes = n_classes
+        self.ucg_rate = ucg_rate

-    def forward(self, batch, key=None):
+    def forward(self, batch, key=None, disable_dropout=False):
         if key is None:
             key = self.key
         # this is for use in crossattn
         c = batch[key][:, None]
+        if self.ucg_rate > 0. and not disable_dropout:
+            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+            c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+            c = c.long()
         c = self.embedding(c)
         return c

+    def get_unconditional_conditioning(self, bs, device="cuda"):
+        uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+        uc = torch.ones((bs,), device=device) * uc_class
+        uc = {self.key: uc}
+        return uc
+

-class TransformerEmbedder(AbstractEncoder):
-    """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+    """Uses the T5 transformer encoder for text"""
+    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
+        self.tokenizer = T5Tokenizer.from_pretrained(version)
+        self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
self.device = device
|
||||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
self.max_length = max_length # TODO: typical value?
|
||||||
attn_layers=Encoder(dim=n_embed, depth=n_layer))
|
if freeze:
|
||||||
|
self.freeze()
|
||||||
def forward(self, tokens):
|
|
||||||
tokens = tokens.to(self.device) # meh
|
|
||||||
z = self.transformer(tokens, return_embeddings=True)
|
|
||||||
return z
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
return self(x)
|
|
||||||
|
|
||||||
|
|
||||||
class BERTTokenizer(AbstractEncoder):
|
|
||||||
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
|
||||||
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
|
||||||
super().__init__()
|
|
||||||
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
|
||||||
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
|
||||||
self.device = device
|
|
||||||
self.vq_interface = vq_interface
|
|
||||||
self.max_length = max_length
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
|
||||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
|
||||||
tokens = batch_encoding["input_ids"].to(self.device)
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def encode(self, text):
|
|
||||||
tokens = self(text)
|
|
||||||
if not self.vq_interface:
|
|
||||||
return tokens
|
|
||||||
return None, None, [None, None, tokens]
|
|
||||||
|
|
||||||
def decode(self, text):
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
class BERTEmbedder(AbstractEncoder):
|
|
||||||
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
|
|
||||||
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
|
|
||||||
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.use_tknz_fn = use_tokenizer
|
|
||||||
if self.use_tknz_fn:
|
|
||||||
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
|
|
||||||
self.device = device
|
|
||||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
|
||||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
|
||||||
emb_dropout=embedding_dropout)
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
if self.use_tknz_fn:
|
|
||||||
tokens = self.tknz_fn(text)#.to(self.device)
|
|
||||||
else:
|
|
||||||
tokens = text
|
|
||||||
z = self.transformer(tokens, return_embeddings=True)
|
|
||||||
return z
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
# output of length 77
|
|
||||||
return self(text)
|
|
||||||
|
|
||||||
|
|
||||||
class SpatialRescaler(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
n_stages=1,
|
|
||||||
method='bilinear',
|
|
||||||
multiplier=0.5,
|
|
||||||
in_channels=3,
|
|
||||||
out_channels=None,
|
|
||||||
bias=False):
|
|
||||||
super().__init__()
|
|
||||||
self.n_stages = n_stages
|
|
||||||
assert self.n_stages >= 0
|
|
||||||
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
|
|
||||||
self.multiplier = multiplier
|
|
||||||
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
|
|
||||||
self.remap_output = out_channels is not None
|
|
||||||
if self.remap_output:
|
|
||||||
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
|
|
||||||
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
|
|
||||||
|
|
||||||
def forward(self,x):
|
|
||||||
for stage in range(self.n_stages):
|
|
||||||
x = self.interpolator(x, scale_factor=self.multiplier)
|
|
||||||
|
|
||||||
|
|
||||||
if self.remap_output:
|
|
||||||
x = self.channel_mapper(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
return self(x)
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextModelZero(CLIPTextModel):
|
|
||||||
config_class = CLIPTextConfig
|
|
||||||
|
|
||||||
def __init__(self, config: CLIPTextConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self.text_model = CLIPTextTransformerZero(config)
|
|
||||||
|
|
||||||
class CLIPTextTransformerZero(CLIPTextTransformer):
|
|
||||||
def _build_causal_attention_mask(self, bsz, seq_len):
|
|
||||||
# lazily create causal attention mask, with full attention between the vision tokens
|
|
||||||
# pytorch uses additive attention mask; fill with -inf
|
|
||||||
mask = torch.empty(bsz, seq_len, seq_len)
|
|
||||||
mask.fill_(float("-inf"))
|
|
||||||
mask.triu_(1) # zero out the lower diagonal
|
|
||||||
mask = mask.unsqueeze(1) # expand mask
|
|
||||||
return mask.half()
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
|
||||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True):
|
|
||||||
super().__init__()
|
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
|
||||||
|
|
||||||
if use_fp16:
|
|
||||||
self.transformer = CLIPTextModelZero.from_pretrained(version)
|
|
||||||
else:
|
|
||||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
|
||||||
|
|
||||||
# print(self.transformer.modules())
|
|
||||||
# print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype))
|
|
||||||
self.device = device
|
|
||||||
self.max_length = max_length
|
|
||||||
self.freeze()
|
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
self.transformer = self.transformer.eval()
|
self.transformer = self.transformer.eval()
|
||||||
|
#self.train = disabled_train
|
||||||
for param in self.parameters():
|
for param in self.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||||
# tokens = batch_encoding["input_ids"].to(self.device)
|
|
||||||
tokens = batch_encoding["input_ids"].to(self.device)
|
tokens = batch_encoding["input_ids"].to(self.device)
|
||||||
# print("token type: {}".format(tokens.dtype))
|
|
||||||
outputs = self.transformer(input_ids=tokens)
|
outputs = self.transformer(input_ids=tokens)
|
||||||
|
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
|
@ -192,17 +85,80 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||||
return self(text)
|
return self(text)
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPTextEmbedder(nn.Module):
|
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||||
"""
|
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||||
Uses the CLIP transformer encoder for text.
|
LAYERS = [
|
||||||
"""
|
"last",
|
||||||
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
|
"pooled",
|
||||||
|
"hidden"
|
||||||
|
]
|
||||||
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
||||||
|
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model, _ = clip.load(version, jit=False, device="cpu")
|
assert layer in self.LAYERS
|
||||||
|
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||||
|
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||||
self.device = device
|
self.device = device
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.n_repeat = n_repeat
|
if freeze:
|
||||||
self.normalize = normalize
|
self.freeze()
|
||||||
|
self.layer = layer
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer == "hidden":
|
||||||
|
assert layer_idx is not None
|
||||||
|
assert 0 <= abs(layer_idx) <= 12
|
||||||
|
|
||||||
|
def freeze(self):
|
||||||
|
self.transformer = self.transformer.eval()
|
||||||
|
#self.train = disabled_train
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||||
|
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||||
|
tokens = batch_encoding["input_ids"].to(self.device)
|
||||||
|
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||||
|
if self.layer == "last":
|
||||||
|
z = outputs.last_hidden_state
|
||||||
|
elif self.layer == "pooled":
|
||||||
|
z = outputs.pooler_output[:, None, :]
|
||||||
|
else:
|
||||||
|
z = outputs.hidden_states[self.layer_idx]
|
||||||
|
return z
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
return self(text)
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||||
|
"""
|
||||||
|
Uses the OpenCLIP transformer encoder for text
|
||||||
|
"""
|
||||||
|
LAYERS = [
|
||||||
|
#"pooled",
|
||||||
|
"last",
|
||||||
|
"penultimate"
|
||||||
|
]
|
||||||
|
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||||
|
freeze=True, layer="last"):
|
||||||
|
super().__init__()
|
||||||
|
assert layer in self.LAYERS
|
||||||
|
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
|
||||||
|
del model.visual
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
self.max_length = max_length
|
||||||
|
if freeze:
|
||||||
|
self.freeze()
|
||||||
|
self.layer = layer
|
||||||
|
if self.layer == "last":
|
||||||
|
self.layer_idx = 0
|
||||||
|
elif self.layer == "penultimate":
|
||||||
|
self.layer_idx = 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
self.model = self.model.eval()
|
self.model = self.model.eval()
|
||||||
|
@ -210,55 +166,48 @@ class FrozenCLIPTextEmbedder(nn.Module):
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
tokens = clip.tokenize(text).to(self.device)
|
tokens = open_clip.tokenize(text)
|
||||||
z = self.model.encode_text(tokens)
|
z = self.encode_with_transformer(tokens.to(self.device))
|
||||||
if self.normalize:
|
|
||||||
z = z / torch.linalg.norm(z, dim=1, keepdim=True)
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
def encode(self, text):
|
def encode_with_transformer(self, text):
|
||||||
z = self(text)
|
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||||
if z.ndim==2:
|
x = x + self.model.positional_embedding
|
||||||
z = z[:, None, :]
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
|
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||||
return z
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
x = self.model.ln_final(x)
|
||||||
|
|
||||||
class FrozenClipImageEmbedder(nn.Module):
|
|
||||||
"""
|
|
||||||
Uses the CLIP image encoder.
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
jit=False,
|
|
||||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
|
||||||
antialias=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
|
||||||
|
|
||||||
self.antialias = antialias
|
|
||||||
|
|
||||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
|
||||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
|
||||||
|
|
||||||
def preprocess(self, x):
|
|
||||||
# normalize to [0,1]
|
|
||||||
x = kornia.geometry.resize(x, (224, 224),
|
|
||||||
interpolation='bicubic',align_corners=True,
|
|
||||||
antialias=self.antialias)
|
|
||||||
x = (x + 1.) / 2.
|
|
||||||
# renormalize according to clip
|
|
||||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
||||||
# x is assumed to be in range [-1,1]
|
for i, r in enumerate(self.model.transformer.resblocks):
|
||||||
return self.model.encode_image(self.preprocess(x))
|
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||||
|
break
|
||||||
|
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
||||||
|
x = checkpoint(r, x, attn_mask)
|
||||||
|
else:
|
||||||
|
x = r(x, attn_mask=attn_mask)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
return self(text)
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||||
|
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
||||||
|
clip_max_length=77, t5_max_length=77):
|
||||||
|
super().__init__()
|
||||||
|
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
||||||
|
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||||
|
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||||
|
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
return self(text)
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
clip_z = self.clip_encoder.encode(text)
|
||||||
|
t5_z = self.t5_encoder.encode(text)
|
||||||
|
return [clip_z, t5_z]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from ldm.util import count_params
|
|
||||||
model = FrozenCLIPEmbedder()
|
|
||||||
count_params(model, verbose=True)
|
|
|
@ -1,50 +0,0 @@
|
||||||
"""
|
|
||||||
Fused Attention
|
|
||||||
===============
|
|
||||||
This is a Triton implementation of the Flash Attention algorithm
|
|
||||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
|
|
||||||
"""
|
|
||||||
Arguments:
|
|
||||||
qkv: (batch*seq, 3, nheads, headdim)
|
|
||||||
batch_size: int.
|
|
||||||
seq_len: int.
|
|
||||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
|
||||||
Return:
|
|
||||||
out: (total, nheads, headdim).
|
|
||||||
"""
|
|
||||||
max_s = seq_len
|
|
||||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
|
|
||||||
device=qkv.device)
|
|
||||||
out = flash_attn_unpadded_qkvpacked_func(
|
|
||||||
qkv, cu_seqlens, max_s, 0.0,
|
|
||||||
softmax_scale=sm_scale, causal=False
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
|
|
||||||
"""
|
|
||||||
Arguments:
|
|
||||||
q: (batch*seq, nheads, headdim)
|
|
||||||
kv: (batch*seq, 2, nheads, headdim)
|
|
||||||
batch_size: int.
|
|
||||||
seq_len: int.
|
|
||||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
|
||||||
Return:
|
|
||||||
out: (total, nheads, headdim).
|
|
||||||
"""
|
|
||||||
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
|
|
||||||
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
|
|
||||||
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
|
|
||||||
return out
|
|
|
@ -25,7 +25,6 @@ import ldm.modules.image_degradation.utils_image as util
|
||||||
# --------------------------------------------
|
# --------------------------------------------
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def modcrop_np(img, sf):
|
def modcrop_np(img, sf):
|
||||||
'''
|
'''
|
||||||
Args:
|
Args:
|
||||||
|
@ -254,7 +253,7 @@ def srmd_degradation(x, k, sf=3):
|
||||||
year={2018}
|
year={2018}
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
||||||
x = bicubic_degradation(x, sf=sf)
|
x = bicubic_degradation(x, sf=sf)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -277,7 +276,7 @@ def dpsr_degradation(x, k, sf=3):
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
x = bicubic_degradation(x, sf=sf)
|
x = bicubic_degradation(x, sf=sf)
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -290,7 +289,7 @@ def classical_degradation(x, k, sf=3):
|
||||||
Return:
|
Return:
|
||||||
downsampled LR image
|
downsampled LR image
|
||||||
'''
|
'''
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||||
st = 0
|
st = 0
|
||||||
return x[st::sf, st::sf, ...]
|
return x[st::sf, st::sf, ...]
|
||||||
|
@ -335,7 +334,7 @@ def add_blur(img, sf=4):
|
||||||
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
|
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
|
||||||
else:
|
else:
|
||||||
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
|
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
|
||||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
@ -497,7 +496,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||||
k_shifted = shift_pixel(k, sf)
|
k_shifted = shift_pixel(k, sf)
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||||
img = np.clip(img, 0.0, 1.0)
|
img = np.clip(img, 0.0, 1.0)
|
||||||
|
|
||||||
|
@ -531,7 +530,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||||
|
|
||||||
|
|
||||||
# todo no isp_model?
|
# todo no isp_model?
|
||||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
|
||||||
"""
|
"""
|
||||||
This is the degradation model of BSRGAN from the paper
|
This is the degradation model of BSRGAN from the paper
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||||
|
@ -589,7 +588,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||||
k_shifted = shift_pixel(k, sf)
|
k_shifted = shift_pixel(k, sf)
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||||
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
image = np.clip(image, 0.0, 1.0)
|
||||||
|
@ -617,6 +616,8 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||||
# add final JPEG compression noise
|
# add final JPEG compression noise
|
||||||
image = add_JPEG_noise(image)
|
image = add_JPEG_noise(image)
|
||||||
image = util.single2uint(image)
|
image = util.single2uint(image)
|
||||||
|
if up:
|
||||||
|
image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
|
||||||
example = {"image": image}
|
example = {"image": image}
|
||||||
return example
|
return example
|
||||||
|
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
|
|
|
@ -1,111 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
|
|
||||||
|
|
||||||
|
|
||||||
class LPIPSWithDiscriminator(nn.Module):
|
|
||||||
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
|
|
||||||
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
|
||||||
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
|
||||||
disc_loss="hinge"):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
assert disc_loss in ["hinge", "vanilla"]
|
|
||||||
self.kl_weight = kl_weight
|
|
||||||
self.pixel_weight = pixelloss_weight
|
|
||||||
self.perceptual_loss = LPIPS().eval()
|
|
||||||
self.perceptual_weight = perceptual_weight
|
|
||||||
# output log variance
|
|
||||||
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
|
||||||
|
|
||||||
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
|
||||||
n_layers=disc_num_layers,
|
|
||||||
use_actnorm=use_actnorm
|
|
||||||
).apply(weights_init)
|
|
||||||
self.discriminator_iter_start = disc_start
|
|
||||||
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
|
||||||
self.disc_factor = disc_factor
|
|
||||||
self.discriminator_weight = disc_weight
|
|
||||||
self.disc_conditional = disc_conditional
|
|
||||||
|
|
||||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
|
||||||
if last_layer is not None:
|
|
||||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
|
||||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
|
||||||
else:
|
|
||||||
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
|
||||||
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
|
||||||
|
|
||||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
|
||||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
|
||||||
d_weight = d_weight * self.discriminator_weight
|
|
||||||
return d_weight
|
|
||||||
|
|
||||||
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
|
|
||||||
global_step, last_layer=None, cond=None, split="train",
|
|
||||||
weights=None):
|
|
||||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
|
||||||
if self.perceptual_weight > 0:
|
|
||||||
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
|
||||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
|
||||||
|
|
||||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
|
||||||
weighted_nll_loss = nll_loss
|
|
||||||
if weights is not None:
|
|
||||||
weighted_nll_loss = weights*nll_loss
|
|
||||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
|
||||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
|
||||||
kl_loss = posteriors.kl()
|
|
||||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
|
||||||
|
|
||||||
# now the GAN part
|
|
||||||
if optimizer_idx == 0:
|
|
||||||
# generator update
|
|
||||||
if cond is None:
|
|
||||||
assert not self.disc_conditional
|
|
||||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
|
||||||
else:
|
|
||||||
assert self.disc_conditional
|
|
||||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
|
||||||
g_loss = -torch.mean(logits_fake)
|
|
||||||
|
|
||||||
if self.disc_factor > 0.0:
|
|
||||||
try:
|
|
||||||
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
|
||||||
except RuntimeError:
|
|
||||||
assert not self.training
|
|
||||||
d_weight = torch.tensor(0.0)
|
|
||||||
else:
|
|
||||||
d_weight = torch.tensor(0.0)
|
|
||||||
|
|
||||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
|
||||||
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
|
|
||||||
|
|
||||||
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
|
|
||||||
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
|
|
||||||
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
|
||||||
"{}/d_weight".format(split): d_weight.detach(),
|
|
||||||
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
|
||||||
"{}/g_loss".format(split): g_loss.detach().mean(),
|
|
||||||
}
|
|
||||||
return loss, log
|
|
||||||
|
|
||||||
if optimizer_idx == 1:
|
|
||||||
# second pass for discriminator update
|
|
||||||
if cond is None:
|
|
||||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
|
||||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
|
||||||
else:
|
|
||||||
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
|
||||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
|
||||||
|
|
||||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
|
||||||
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
|
||||||
|
|
||||||
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
|
||||||
"{}/logits_real".format(split): logits_real.detach().mean(),
|
|
||||||
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
|
||||||
}
|
|
||||||
return d_loss, log
|
|
||||||
|
|
|
@ -1,167 +0,0 @@
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import repeat
|
|
||||||
|
|
||||||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
|
||||||
from taming.modules.losses.lpips import LPIPS
|
|
||||||
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
|
|
||||||
|
|
||||||
|
|
||||||
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
|
|
||||||
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
|
|
||||||
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
|
|
||||||
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
|
|
||||||
loss_real = (weights * loss_real).sum() / weights.sum()
|
|
||||||
loss_fake = (weights * loss_fake).sum() / weights.sum()
|
|
||||||
d_loss = 0.5 * (loss_real + loss_fake)
|
|
||||||
return d_loss
|
|
||||||
|
|
||||||
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
|
||||||
if global_step < threshold:
|
|
||||||
weight = value
|
|
||||||
return weight
|
|
||||||
|
|
||||||
|
|
||||||
def measure_perplexity(predicted_indices, n_embed):
|
|
||||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
|
||||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
|
||||||
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
|
|
||||||
avg_probs = encodings.mean(0)
|
|
||||||
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
|
||||||
cluster_use = torch.sum(avg_probs > 0)
|
|
||||||
return perplexity, cluster_use
|
|
||||||
|
|
||||||
def l1(x, y):
|
|
||||||
return torch.abs(x-y)
|
|
||||||
|
|
||||||
|
|
||||||
def l2(x, y):
|
|
||||||
return torch.pow((x-y), 2)
|
|
||||||
|
|
||||||
|
|
||||||
class VQLPIPSWithDiscriminator(nn.Module):
|
|
||||||
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
|
|
||||||
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
|
||||||
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
|
||||||
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
|
|
||||||
pixel_loss="l1"):
|
|
||||||
super().__init__()
|
|
||||||
assert disc_loss in ["hinge", "vanilla"]
|
|
||||||
assert perceptual_loss in ["lpips", "clips", "dists"]
|
|
||||||
assert pixel_loss in ["l1", "l2"]
|
|
||||||
self.codebook_weight = codebook_weight
|
|
||||||
self.pixel_weight = pixelloss_weight
|
|
||||||
if perceptual_loss == "lpips":
|
|
||||||
print(f"{self.__class__.__name__}: Running with LPIPS.")
|
|
||||||
self.perceptual_loss = LPIPS().eval()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
|
|
||||||
self.perceptual_weight = perceptual_weight
|
|
||||||
|
|
||||||
if pixel_loss == "l1":
|
|
||||||
self.pixel_loss = l1
|
|
||||||
else:
|
|
||||||
self.pixel_loss = l2
|
|
||||||
|
|
||||||
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
|
||||||
n_layers=disc_num_layers,
|
|
||||||
use_actnorm=use_actnorm,
|
|
||||||
ndf=disc_ndf
|
|
||||||
).apply(weights_init)
|
|
||||||
self.discriminator_iter_start = disc_start
|
|
||||||
if disc_loss == "hinge":
|
|
||||||
self.disc_loss = hinge_d_loss
|
|
||||||
elif disc_loss == "vanilla":
|
|
||||||
self.disc_loss = vanilla_d_loss
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
|
|
||||||
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
|
|
||||||
self.disc_factor = disc_factor
|
|
||||||
self.discriminator_weight = disc_weight
|
|
||||||
self.disc_conditional = disc_conditional
|
|
||||||
self.n_classes = n_classes
|
|
||||||
|
|
||||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
|
||||||
if last_layer is not None:
|
|
||||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
|
||||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
|
||||||
else:
|
|
||||||
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
|
||||||
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
|
||||||
|
|
||||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
|
||||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
|
||||||
d_weight = d_weight * self.discriminator_weight
|
|
||||||
return d_weight
|
|
||||||
|
|
||||||
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
|
|
||||||
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
|
|
||||||
if not exists(codebook_loss):
|
|
||||||
codebook_loss = torch.tensor([0.]).to(inputs.device)
|
|
||||||
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
|
||||||
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
|
|
||||||
if self.perceptual_weight > 0:
|
|
||||||
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
|
||||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
|
||||||
else:
|
|
||||||
p_loss = torch.tensor([0.0])
|
|
||||||
|
|
||||||
nll_loss = rec_loss
|
|
||||||
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
|
||||||
nll_loss = torch.mean(nll_loss)
|
|
||||||
|
|
||||||
# now the GAN part
|
|
||||||
if optimizer_idx == 0:
|
|
||||||
# generator update
|
|
||||||
if cond is None:
|
|
||||||
assert not self.disc_conditional
|
|
||||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
|
||||||
else:
|
|
||||||
assert self.disc_conditional
|
|
||||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
|
||||||
g_loss = -torch.mean(logits_fake)
|
|
||||||
|
|
||||||
try:
|
|
||||||
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
|
||||||
except RuntimeError:
|
|
||||||
assert not self.training
|
|
||||||
d_weight = torch.tensor(0.0)
|
|
||||||
|
|
||||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
|
||||||
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
|
|
||||||
|
|
||||||
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
|
||||||
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
|
|
||||||
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
|
||||||
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
|
||||||
"{}/p_loss".format(split): p_loss.detach().mean(),
|
|
||||||
"{}/d_weight".format(split): d_weight.detach(),
|
|
||||||
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
|
||||||
"{}/g_loss".format(split): g_loss.detach().mean(),
|
|
||||||
}
|
|
||||||
if predicted_indices is not None:
|
|
||||||
assert self.n_classes is not None
|
|
||||||
with torch.no_grad():
|
|
||||||
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
|
|
||||||
log[f"{split}/perplexity"] = perplexity
|
|
||||||
log[f"{split}/cluster_usage"] = cluster_usage
|
|
||||||
return loss, log
|
|
||||||
|
|
||||||
if optimizer_idx == 1:
|
|
||||||
# second pass for discriminator update
|
|
||||||
if cond is None:
|
|
||||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
|
||||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
|
||||||
else:
|
|
||||||
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
|
||||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
|
||||||
|
|
||||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
|
||||||
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
|
||||||
|
|
||||||
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
|
||||||
"{}/logits_real".format(split): logits_real.detach().mean(),
|
|
||||||
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
|
||||||
}
|
|
||||||
return d_loss, log
|
|
|
@ -0,0 +1,170 @@
|
||||||
|
# based on https://github.com/isl-org/MiDaS
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision.transforms import Compose
|
||||||
|
|
||||||
|
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
|
||||||
|
from ldm.modules.midas.midas.midas_net import MidasNet
|
||||||
|
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
|
||||||
|
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
|
||||||
|
|
||||||
|
|
||||||
|
ISL_PATHS = {
|
||||||
|
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
|
||||||
|
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
|
||||||
|
"midas_v21": "",
|
||||||
|
"midas_v21_small": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def disabled_train(self, mode=True):
|
||||||
|
"""Overwrite model.train with this function to make sure train/eval mode
|
||||||
|
does not change anymore."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def load_midas_transform(model_type):
|
||||||
|
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
||||||
|
# load transform only
|
||||||
|
if model_type == "dpt_large": # DPT-Large
|
||||||
|
net_w, net_h = 384, 384
|
||||||
|
resize_mode = "minimal"
|
||||||
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
|
||||||
|
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
||||||
|
net_w, net_h = 384, 384
|
||||||
|
resize_mode = "minimal"
|
||||||
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
|
||||||
|
elif model_type == "midas_v21":
|
||||||
|
net_w, net_h = 384, 384
|
||||||
|
resize_mode = "upper_bound"
|
||||||
|
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
|
||||||
|
elif model_type == "midas_v21_small":
|
||||||
|
net_w, net_h = 256, 256
|
||||||
|
resize_mode = "upper_bound"
|
||||||
|
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
||||||
|
|
||||||
|
transform = Compose(
|
||||||
|
[
|
||||||
|
Resize(
|
||||||
|
net_w,
|
||||||
|
net_h,
|
||||||
|
resize_target=None,
|
||||||
|
keep_aspect_ratio=True,
|
||||||
|
ensure_multiple_of=32,
|
||||||
|
resize_method=resize_mode,
|
||||||
|
image_interpolation_method=cv2.INTER_CUBIC,
|
||||||
|
),
|
||||||
|
normalization,
|
||||||
|
PrepareForNet(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return transform
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_type):
|
||||||
|
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
||||||
|
# load network
|
||||||
|
model_path = ISL_PATHS[model_type]
|
||||||
|
if model_type == "dpt_large": # DPT-Large
|
||||||
|
model = DPTDepthModel(
|
||||||
|
path=model_path,
|
||||||
|
backbone="vitl16_384",
|
||||||
|
non_negative=True,
|
||||||
|
)
|
||||||
|
net_w, net_h = 384, 384
|
||||||
|
resize_mode = "minimal"
|
||||||
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
|
||||||
|
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
||||||
|
model = DPTDepthModel(
|
||||||
|
path=model_path,
|
||||||
|
backbone="vitb_rn50_384",
|
||||||
|
non_negative=True,
|
||||||
|
)
|
||||||
|
net_w, net_h = 384, 384
|
||||||
|
resize_mode = "minimal"
|
||||||
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
|
||||||
|
elif model_type == "midas_v21":
|
||||||
|
model = MidasNet(model_path, non_negative=True)
|
||||||
|
net_w, net_h = 384, 384
|
||||||
|
resize_mode = "upper_bound"
|
||||||
|
normalization = NormalizeImage(
|
||||||
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == "midas_v21_small":
|
||||||
|
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
||||||
|
non_negative=True, blocks={'expand': True})
|
||||||
|
net_w, net_h = 256, 256
|
||||||
|
resize_mode = "upper_bound"
|
||||||
|
normalization = NormalizeImage(
|
||||||
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
||||||
|
assert False
|
||||||
|
|
||||||
|
transform = Compose(
|
||||||
|
[
|
||||||
|
Resize(
|
||||||
|
net_w,
|
||||||
|
net_h,
|
||||||
|
resize_target=None,
|
||||||
|
keep_aspect_ratio=True,
|
||||||
|
ensure_multiple_of=32,
|
||||||
|
resize_method=resize_mode,
|
||||||
|
image_interpolation_method=cv2.INTER_CUBIC,
|
||||||
|
),
|
||||||
|
normalization,
|
||||||
|
PrepareForNet(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return model.eval(), transform
|
||||||
|
|
||||||
|
|
||||||
|
class MiDaSInference(nn.Module):
|
||||||
|
MODEL_TYPES_TORCH_HUB = [
|
||||||
|
"DPT_Large",
|
||||||
|
"DPT_Hybrid",
|
||||||
|
"MiDaS_small"
|
||||||
|
]
|
||||||
|
MODEL_TYPES_ISL = [
|
||||||
|
"dpt_large",
|
||||||
|
"dpt_hybrid",
|
||||||
|
"midas_v21",
|
||||||
|
"midas_v21_small",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, model_type):
|
||||||
|
super().__init__()
|
||||||
|
assert (model_type in self.MODEL_TYPES_ISL)
|
||||||
|
model, _ = load_model(model_type)
|
||||||
|
self.model = model
|
||||||
|
self.model.train = disabled_train
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
|
||||||
|
# NOTE: we expect that the correct transform has been called during dataloading.
|
||||||
|
with torch.no_grad():
|
||||||
|
prediction = self.model(x)
|
||||||
|
prediction = torch.nn.functional.interpolate(
|
||||||
|
prediction.unsqueeze(1),
|
||||||
|
size=x.shape[2:],
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
|
||||||
|
return prediction
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(torch.nn.Module):
|
||||||
|
def load(self, path):
|
||||||
|
"""Load model from file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): file path
|
||||||
|
"""
|
||||||
|
parameters = torch.load(path, map_location=torch.device('cpu'))
|
||||||
|
|
||||||
|
if "optimizer" in parameters:
|
||||||
|
parameters = parameters["model"]
|
||||||
|
|
||||||
|
self.load_state_dict(parameters)
|
|
@ -0,0 +1,342 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .vit import (
|
||||||
|
_make_pretrained_vitb_rn50_384,
|
||||||
|
_make_pretrained_vitl16_384,
|
||||||
|
_make_pretrained_vitb16_384,
|
||||||
|
forward_vit,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
||||||
|
if backbone == "vitl16_384":
|
||||||
|
pretrained = _make_pretrained_vitl16_384(
|
||||||
|
use_pretrained, hooks=hooks, use_readout=use_readout
|
||||||
|
)
|
||||||
|
scratch = _make_scratch(
|
||||||
|
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
||||||
|
) # ViT-L/16 - 85.0% Top1 (backbone)
|
||||||
|
elif backbone == "vitb_rn50_384":
|
||||||
|
pretrained = _make_pretrained_vitb_rn50_384(
|
||||||
|
use_pretrained,
|
||||||
|
hooks=hooks,
|
||||||
|
use_vit_only=use_vit_only,
|
||||||
|
use_readout=use_readout,
|
||||||
|
)
|
||||||
|
scratch = _make_scratch(
|
||||||
|
[256, 512, 768, 768], features, groups=groups, expand=expand
|
||||||
|
) # ViT-H/16 - 85.0% Top1 (backbone)
|
||||||
|
elif backbone == "vitb16_384":
|
||||||
|
pretrained = _make_pretrained_vitb16_384(
|
||||||
|
use_pretrained, hooks=hooks, use_readout=use_readout
|
||||||
|
)
|
||||||
|
scratch = _make_scratch(
|
||||||
|
[96, 192, 384, 768], features, groups=groups, expand=expand
|
||||||
|
) # ViT-B/16 - 84.6% Top1 (backbone)
|
||||||
|
elif backbone == "resnext101_wsl":
|
||||||
|
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
||||||
|
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
||||||
|
elif backbone == "efficientnet_lite3":
|
||||||
|
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
||||||
|
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
||||||
|
else:
|
||||||
|
print(f"Backbone '{backbone}' not implemented")
|
||||||
|
assert False
|
||||||
|
|
||||||
|
return pretrained, scratch
|
||||||
|
|
||||||
|
|
||||||
|
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||||
|
scratch = nn.Module()
|
||||||
|
|
||||||
|
out_shape1 = out_shape
|
||||||
|
out_shape2 = out_shape
|
||||||
|
out_shape3 = out_shape
|
||||||
|
out_shape4 = out_shape
|
||||||
|
if expand==True:
|
||||||
|
out_shape1 = out_shape
|
||||||
|
out_shape2 = out_shape*2
|
||||||
|
out_shape3 = out_shape*4
|
||||||
|
out_shape4 = out_shape*8
|
||||||
|
|
||||||
|
scratch.layer1_rn = nn.Conv2d(
|
||||||
|
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||||
|
)
|
||||||
|
scratch.layer2_rn = nn.Conv2d(
|
||||||
|
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||||
|
)
|
||||||
|
scratch.layer3_rn = nn.Conv2d(
|
||||||
|
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||||
|
)
|
||||||
|
scratch.layer4_rn = nn.Conv2d(
|
||||||
|
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||||
|
)
|
||||||
|
|
||||||
|
return scratch
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
||||||
|
efficientnet = torch.hub.load(
|
||||||
|
"rwightman/gen-efficientnet-pytorch",
|
||||||
|
"tf_efficientnet_lite3",
|
||||||
|
pretrained=use_pretrained,
|
||||||
|
exportable=exportable
|
||||||
|
)
|
||||||
|
return _make_efficientnet_backbone(efficientnet)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_efficientnet_backbone(effnet):
|
||||||
|
pretrained = nn.Module()
|
||||||
|
|
||||||
|
pretrained.layer1 = nn.Sequential(
|
||||||
|
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
||||||
|
)
|
||||||
|
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
||||||
|
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
||||||
|
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
||||||
|
|
||||||
|
return pretrained
|
||||||
|
|
||||||
|
|
||||||
|
def _make_resnet_backbone(resnet):
|
||||||
|
pretrained = nn.Module()
|
||||||
|
pretrained.layer1 = nn.Sequential(
|
||||||
|
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.layer2 = resnet.layer2
|
||||||
|
pretrained.layer3 = resnet.layer3
|
||||||
|
pretrained.layer4 = resnet.layer4
|
||||||
|
|
||||||
|
return pretrained
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pretrained_resnext101_wsl(use_pretrained):
|
||||||
|
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
||||||
|
return _make_resnet_backbone(resnet)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Interpolate(nn.Module):
|
||||||
|
"""Interpolation module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scale_factor, mode, align_corners=False):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scale_factor (float): scaling
|
||||||
|
mode (str): interpolation mode
|
||||||
|
"""
|
||||||
|
super(Interpolate, self).__init__()
|
||||||
|
|
||||||
|
self.interp = nn.functional.interpolate
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.mode = mode
|
||||||
|
self.align_corners = align_corners
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tensor): input
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor: interpolated data
|
||||||
|
"""
|
||||||
|
|
||||||
|
x = self.interp(
|
||||||
|
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
||||||
|
)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualConvUnit(nn.Module):
|
||||||
|
"""Residual convolution module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, features):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features (int): number of features
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tensor): input
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor: output
|
||||||
|
"""
|
||||||
|
out = self.relu(x)
|
||||||
|
out = self.conv1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
out = self.conv2(out)
|
||||||
|
|
||||||
|
return out + x
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureFusionBlock(nn.Module):
|
||||||
|
"""Feature fusion block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, features):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features (int): number of features
|
||||||
|
"""
|
||||||
|
super(FeatureFusionBlock, self).__init__()
|
||||||
|
|
||||||
|
self.resConfUnit1 = ResidualConvUnit(features)
|
||||||
|
self.resConfUnit2 = ResidualConvUnit(features)
|
||||||
|
|
||||||
|
def forward(self, *xs):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor: output
|
||||||
|
"""
|
||||||
|
output = xs[0]
|
||||||
|
|
||||||
|
if len(xs) == 2:
|
||||||
|
output += self.resConfUnit1(xs[1])
|
||||||
|
|
||||||
|
output = self.resConfUnit2(output)
|
||||||
|
|
||||||
|
output = nn.functional.interpolate(
|
||||||
|
output, scale_factor=2, mode="bilinear", align_corners=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualConvUnit_custom(nn.Module):
|
||||||
|
"""Residual convolution module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, features, activation, bn):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features (int): number of features
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.bn = bn
|
||||||
|
|
||||||
|
self.groups=1
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bn==True:
|
||||||
|
self.bn1 = nn.BatchNorm2d(features)
|
||||||
|
self.bn2 = nn.BatchNorm2d(features)
|
||||||
|
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
self.skip_add = nn.quantized.FloatFunctional()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tensor): input
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor: output
|
||||||
|
"""
|
||||||
|
|
||||||
|
out = self.activation(x)
|
||||||
|
out = self.conv1(out)
|
||||||
|
if self.bn==True:
|
||||||
|
out = self.bn1(out)
|
||||||
|
|
||||||
|
out = self.activation(out)
|
||||||
|
out = self.conv2(out)
|
||||||
|
if self.bn==True:
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
if self.groups > 1:
|
||||||
|
out = self.conv_merge(out)
|
||||||
|
|
||||||
|
return self.skip_add.add(out, x)
|
||||||
|
|
||||||
|
# return out + x
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureFusionBlock_custom(nn.Module):
|
||||||
|
"""Feature fusion block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features (int): number of features
|
||||||
|
"""
|
||||||
|
super(FeatureFusionBlock_custom, self).__init__()
|
||||||
|
|
||||||
|
self.deconv = deconv
|
||||||
|
self.align_corners = align_corners
|
||||||
|
|
||||||
|
self.groups=1
|
||||||
|
|
||||||
|
self.expand = expand
|
||||||
|
out_features = features
|
||||||
|
if self.expand==True:
|
||||||
|
out_features = features//2
|
||||||
|
|
||||||
|
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
||||||
|
|
||||||
|
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
||||||
|
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
||||||
|
|
||||||
|
self.skip_add = nn.quantized.FloatFunctional()
|
||||||
|
|
||||||
|
def forward(self, *xs):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor: output
|
||||||
|
"""
|
||||||
|
output = xs[0]
|
||||||
|
|
||||||
|
if len(xs) == 2:
|
||||||
|
res = self.resConfUnit1(xs[1])
|
||||||
|
output = self.skip_add.add(output, res)
|
||||||
|
# output += res
|
||||||
|
|
||||||
|
output = self.resConfUnit2(output)
|
||||||
|
|
||||||
|
output = nn.functional.interpolate(
|
||||||
|
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.out_conv(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
|
@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_model import BaseModel
from .blocks import (
    FeatureFusionBlock,
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_vit,
)


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to True if you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last:
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)
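A rough inference sketch for the depth model defined above; the import path, checkpoint name, and input size are assumptions (in this example the model is normally driven from the depth-conditioned sampling scripts), and `path=None` simply skips loading a MiDaS checkpoint:

```python
import torch

# Assumed module path; adjust to wherever dpt_depth.py lives in your checkout.
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(
    path=None,                  # or a path to a dpt_hybrid MiDaS checkpoint file
    backbone="vitb_rn50_384",
    non_negative=True,
)
model.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 384, 384)  # normalized RGB batch at the backbone's training resolution
    depth = model(x)                 # (1, 384, 384) relative inverse-depth map
print(depth.shape)
```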
@ -0,0 +1,76 @@
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
    """Network for monocular depth estimation."""

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            The encoder backbone is fixed to "resnext101_wsl".
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
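The four unrolled refinenet calls above follow the standard RefineNet-style cascade: start from the deepest feature, then fuse each shallower skip connection while doubling the resolution. A compact, purely illustrative equivalent of that decoder loop, assuming the same `scratch` container used by the models in this file:

```python
def run_decoder(scratch, feats):
    """Illustrative only: feats = [layer_1_rn, layer_2_rn, layer_3_rn, layer_4_rn]."""
    refinenets = [scratch.refinenet1, scratch.refinenet2, scratch.refinenet3, scratch.refinenet4]
    path = refinenets[3](feats[3])            # deepest level has no lateral input
    for i in (2, 1, 0):
        path = refinenets[i](path, feats[i])  # fuse skip feature, upsample by 2
    return scratch.output_conv(path)
```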
@ -0,0 +1,128 @@
"""MidasNet_small: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder


class MidasNet_small(BaseModel):
    """Network for monocular depth estimation."""

    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True,
                 channels_last=False, align_corners=True, blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand']:
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last:
            print("self.channels_last = ", self.channels_last)
            x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)


def fuse_model(m):
    prev_previous_type = nn.Identity()
    prev_previous_name = ''
    previous_type = nn.Identity()
    previous_name = ''
    for name, module in m.named_modules():
        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
            # print("FUSED ", prev_previous_name, previous_name, name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
            # print("FUSED ", prev_previous_name, previous_name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
        #    print("FUSED ", previous_name, name)
        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = type(module)
        previous_name = name
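A minimal, untested sketch of how `fuse_model` might be applied before eager-mode quantization or simply for faster float inference. It assumes the definitions above are importable from the path shown and that `timm` can construct (and, with `path=None`, download pretrained weights for) the `efficientnet_lite3` backbone:

```python
import torch

# Assumed module path; adjust to wherever midas_net_custom.py lives in your checkout.
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small, fuse_model

model = MidasNet_small(path=None, non_negative=True, exportable=True)
model.eval()  # Conv + BatchNorm folding requires eval mode

# Fold Conv2d + BatchNorm2d (+ ReLU) patterns in place; fused modules are a
# prerequisite for eager-mode static quantization with torch.quantization.
fuse_model(model)

with torch.no_grad():
    depth = model(torch.randn(1, 3, 256, 256))
print(depth.shape)  # expected: torch.Size([1, 256, 256])
```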
@ -0,0 +1,234 @@
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
||||||
|
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (dict): sample
|
||||||
|
size (tuple): image size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: new size
|
||||||
|
"""
|
||||||
|
shape = list(sample["disparity"].shape)
|
||||||
|
|
||||||
|
if shape[0] >= size[0] and shape[1] >= size[1]:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
scale = [0, 0]
|
||||||
|
scale[0] = size[0] / shape[0]
|
||||||
|
scale[1] = size[1] / shape[1]
|
||||||
|
|
||||||
|
scale = max(scale)
|
||||||
|
|
||||||
|
shape[0] = math.ceil(scale * shape[0])
|
||||||
|
shape[1] = math.ceil(scale * shape[1])
|
||||||
|
|
||||||
|
# resize
|
||||||
|
sample["image"] = cv2.resize(
|
||||||
|
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
||||||
|
)
|
||||||
|
|
||||||
|
sample["disparity"] = cv2.resize(
|
||||||
|
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
||||||
|
)
|
||||||
|
sample["mask"] = cv2.resize(
|
||||||
|
sample["mask"].astype(np.float32),
|
||||||
|
tuple(shape[::-1]),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
)
|
||||||
|
sample["mask"] = sample["mask"].astype(bool)
|
||||||
|
|
||||||
|
return tuple(shape)
|
||||||
|
|
||||||
|
|
||||||
|
class Resize(object):
|
||||||
|
"""Resize sample to given size (width, height).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
resize_target=True,
|
||||||
|
keep_aspect_ratio=False,
|
||||||
|
ensure_multiple_of=1,
|
||||||
|
resize_method="lower_bound",
|
||||||
|
image_interpolation_method=cv2.INTER_AREA,
|
||||||
|
):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width (int): desired output width
|
||||||
|
height (int): desired output height
|
||||||
|
resize_target (bool, optional):
|
||||||
|
True: Resize the full sample (image, mask, target).
|
||||||
|
False: Resize image only.
|
||||||
|
Defaults to True.
|
||||||
|
keep_aspect_ratio (bool, optional):
|
||||||
|
True: Keep the aspect ratio of the input sample.
|
||||||
|
Output sample might not have the given width and height, and
|
||||||
|
resize behaviour depends on the parameter 'resize_method'.
|
||||||
|
Defaults to False.
|
||||||
|
ensure_multiple_of (int, optional):
|
||||||
|
Output width and height is constrained to be multiple of this parameter.
|
||||||
|
Defaults to 1.
|
||||||
|
resize_method (str, optional):
|
||||||
|
"lower_bound": Output will be at least as large as the given size.
|
||||||
|
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
||||||
|
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
||||||
|
Defaults to "lower_bound".
|
||||||
|
"""
|
||||||
|
self.__width = width
|
||||||
|
self.__height = height
|
||||||
|
|
||||||
|
self.__resize_target = resize_target
|
||||||
|
self.__keep_aspect_ratio = keep_aspect_ratio
|
||||||
|
self.__multiple_of = ensure_multiple_of
|
||||||
|
self.__resize_method = resize_method
|
||||||
|
self.__image_interpolation_method = image_interpolation_method
|
||||||
|
|
||||||
|
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
||||||
|
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||||
|
|
||||||
|
if max_val is not None and y > max_val:
|
||||||
|
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||||
|
|
||||||
|
if y < min_val:
|
||||||
|
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
def get_size(self, width, height):
|
||||||
|
# determine new height and width
|
||||||
|
scale_height = self.__height / height
|
||||||
|
scale_width = self.__width / width
|
||||||
|
|
||||||
|
if self.__keep_aspect_ratio:
|
||||||
|
if self.__resize_method == "lower_bound":
|
||||||
|
# scale such that output size is lower bound
|
||||||
|
if scale_width > scale_height:
|
||||||
|
# fit width
|
||||||
|
scale_height = scale_width
|
||||||
|
else:
|
||||||
|
# fit height
|
||||||
|
scale_width = scale_height
|
||||||
|
elif self.__resize_method == "upper_bound":
|
||||||
|
# scale such that output size is upper bound
|
||||||
|
if scale_width < scale_height:
|
||||||
|
# fit width
|
||||||
|
scale_height = scale_width
|
||||||
|
else:
|
||||||
|
# fit height
|
||||||
|
scale_width = scale_height
|
||||||
|
elif self.__resize_method == "minimal":
|
||||||
|
# scale as least as possbile
|
||||||
|
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||||
|
# fit width
|
||||||
|
scale_height = scale_width
|
||||||
|
else:
|
||||||
|
# fit height
|
||||||
|
scale_width = scale_height
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"resize_method {self.__resize_method} not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.__resize_method == "lower_bound":
|
||||||
|
new_height = self.constrain_to_multiple_of(
|
||||||
|
scale_height * height, min_val=self.__height
|
||||||
|
)
|
||||||
|
new_width = self.constrain_to_multiple_of(
|
||||||
|
scale_width * width, min_val=self.__width
|
||||||
|
)
|
||||||
|
elif self.__resize_method == "upper_bound":
|
||||||
|
new_height = self.constrain_to_multiple_of(
|
||||||
|
scale_height * height, max_val=self.__height
|
||||||
|
)
|
||||||
|
new_width = self.constrain_to_multiple_of(
|
||||||
|
scale_width * width, max_val=self.__width
|
||||||
|
)
|
||||||
|
elif self.__resize_method == "minimal":
|
||||||
|
new_height = self.constrain_to_multiple_of(scale_height * height)
|
||||||
|
new_width = self.constrain_to_multiple_of(scale_width * width)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||||
|
|
||||||
|
return (new_width, new_height)
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
width, height = self.get_size(
|
||||||
|
sample["image"].shape[1], sample["image"].shape[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
# resize sample
|
||||||
|
sample["image"] = cv2.resize(
|
||||||
|
sample["image"],
|
||||||
|
(width, height),
|
||||||
|
interpolation=self.__image_interpolation_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.__resize_target:
|
||||||
|
if "disparity" in sample:
|
||||||
|
sample["disparity"] = cv2.resize(
|
||||||
|
sample["disparity"],
|
||||||
|
(width, height),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "depth" in sample:
|
||||||
|
sample["depth"] = cv2.resize(
|
||||||
|
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
||||||
|
)
|
||||||
|
|
||||||
|
sample["mask"] = cv2.resize(
|
||||||
|
sample["mask"].astype(np.float32),
|
||||||
|
(width, height),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
)
|
||||||
|
sample["mask"] = sample["mask"].astype(bool)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeImage(object):
|
||||||
|
"""Normlize image by given mean and std.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mean, std):
|
||||||
|
self.__mean = mean
|
||||||
|
self.__std = std
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareForNet(object):
|
||||||
|
"""Prepare sample for usage as network input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
image = np.transpose(sample["image"], (2, 0, 1))
|
||||||
|
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
||||||
|
|
||||||
|
if "mask" in sample:
|
||||||
|
sample["mask"] = sample["mask"].astype(np.float32)
|
||||||
|
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
||||||
|
|
||||||
|
if "disparity" in sample:
|
||||||
|
disparity = sample["disparity"].astype(np.float32)
|
||||||
|
sample["disparity"] = np.ascontiguousarray(disparity)
|
||||||
|
|
||||||
|
if "depth" in sample:
|
||||||
|
depth = sample["depth"].astype(np.float32)
|
||||||
|
sample["depth"] = np.ascontiguousarray(depth)
|
||||||
|
|
||||||
|
return sample
|
|
@ -0,0 +1,491 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import timm
|
||||||
|
import types
|
||||||
|
import math
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class Slice(nn.Module):
|
||||||
|
def __init__(self, start_index=1):
|
||||||
|
super(Slice, self).__init__()
|
||||||
|
self.start_index = start_index
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x[:, self.start_index :]
|
||||||
|
|
||||||
|
|
||||||
|
class AddReadout(nn.Module):
|
||||||
|
def __init__(self, start_index=1):
|
||||||
|
super(AddReadout, self).__init__()
|
||||||
|
self.start_index = start_index
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.start_index == 2:
|
||||||
|
readout = (x[:, 0] + x[:, 1]) / 2
|
||||||
|
else:
|
||||||
|
readout = x[:, 0]
|
||||||
|
return x[:, self.start_index :] + readout.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectReadout(nn.Module):
|
||||||
|
def __init__(self, in_features, start_index=1):
|
||||||
|
super(ProjectReadout, self).__init__()
|
||||||
|
self.start_index = start_index
|
||||||
|
|
||||||
|
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
||||||
|
features = torch.cat((x[:, self.start_index :], readout), -1)
|
||||||
|
|
||||||
|
return self.project(features)
|
||||||
|
|
||||||
|
|
||||||
|
class Transpose(nn.Module):
|
||||||
|
def __init__(self, dim0, dim1):
|
||||||
|
super(Transpose, self).__init__()
|
||||||
|
self.dim0 = dim0
|
||||||
|
self.dim1 = dim1
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.transpose(self.dim0, self.dim1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def forward_vit(pretrained, x):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
glob = pretrained.model.forward_flex(x)
|
||||||
|
|
||||||
|
layer_1 = pretrained.activations["1"]
|
||||||
|
layer_2 = pretrained.activations["2"]
|
||||||
|
layer_3 = pretrained.activations["3"]
|
||||||
|
layer_4 = pretrained.activations["4"]
|
||||||
|
|
||||||
|
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
||||||
|
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
||||||
|
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
||||||
|
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
||||||
|
|
||||||
|
unflatten = nn.Sequential(
|
||||||
|
nn.Unflatten(
|
||||||
|
2,
|
||||||
|
torch.Size(
|
||||||
|
[
|
||||||
|
h // pretrained.model.patch_size[1],
|
||||||
|
w // pretrained.model.patch_size[0],
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if layer_1.ndim == 3:
|
||||||
|
layer_1 = unflatten(layer_1)
|
||||||
|
if layer_2.ndim == 3:
|
||||||
|
layer_2 = unflatten(layer_2)
|
||||||
|
if layer_3.ndim == 3:
|
||||||
|
layer_3 = unflatten(layer_3)
|
||||||
|
if layer_4.ndim == 3:
|
||||||
|
layer_4 = unflatten(layer_4)
|
||||||
|
|
||||||
|
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
||||||
|
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
||||||
|
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
||||||
|
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
||||||
|
|
||||||
|
return layer_1, layer_2, layer_3, layer_4
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
||||||
|
posemb_tok, posemb_grid = (
|
||||||
|
posemb[:, : self.start_index],
|
||||||
|
posemb[0, self.start_index :],
|
||||||
|
)
|
||||||
|
|
||||||
|
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||||
|
|
||||||
|
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
||||||
|
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
||||||
|
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
||||||
|
|
||||||
|
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||||
|
|
||||||
|
return posemb
|
||||||
|
|
||||||
|
|
||||||
|
def forward_flex(self, x):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
pos_embed = self._resize_pos_embed(
|
||||||
|
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
B = x.shape[0]
|
||||||
|
|
||||||
|
if hasattr(self.patch_embed, "backbone"):
|
||||||
|
x = self.patch_embed.backbone(x)
|
||||||
|
if isinstance(x, (list, tuple)):
|
||||||
|
x = x[-1] # last feature if backbone outputs list/tuple of features
|
||||||
|
|
||||||
|
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
if getattr(self, "dist_token", None) is not None:
|
||||||
|
cls_tokens = self.cls_token.expand(
|
||||||
|
B, -1, -1
|
||||||
|
) # stole cls_tokens impl from Phil Wang, thanks
|
||||||
|
dist_token = self.dist_token.expand(B, -1, -1)
|
||||||
|
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
||||||
|
else:
|
||||||
|
cls_tokens = self.cls_token.expand(
|
||||||
|
B, -1, -1
|
||||||
|
) # stole cls_tokens impl from Phil Wang, thanks
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
|
x = x + pos_embed
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
for blk in self.blocks:
|
||||||
|
x = blk(x)
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
activations = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation(name):
|
||||||
|
def hook(model, input, output):
|
||||||
|
activations[name] = output
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
|
||||||
|
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
||||||
|
if use_readout == "ignore":
|
||||||
|
readout_oper = [Slice(start_index)] * len(features)
|
||||||
|
elif use_readout == "add":
|
||||||
|
readout_oper = [AddReadout(start_index)] * len(features)
|
||||||
|
elif use_readout == "project":
|
||||||
|
readout_oper = [
|
||||||
|
ProjectReadout(vit_features, start_index) for out_feat in features
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
False
|
||||||
|
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
||||||
|
|
||||||
|
return readout_oper
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vit_b16_backbone(
|
||||||
|
model,
|
||||||
|
features=[96, 192, 384, 768],
|
||||||
|
size=[384, 384],
|
||||||
|
hooks=[2, 5, 8, 11],
|
||||||
|
vit_features=768,
|
||||||
|
use_readout="ignore",
|
||||||
|
start_index=1,
|
||||||
|
):
|
||||||
|
pretrained = nn.Module()
|
||||||
|
|
||||||
|
pretrained.model = model
|
||||||
|
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
||||||
|
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
||||||
|
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
||||||
|
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
||||||
|
|
||||||
|
pretrained.activations = activations
|
||||||
|
|
||||||
|
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
||||||
|
|
||||||
|
# 32, 48, 136, 384
|
||||||
|
pretrained.act_postprocess1 = nn.Sequential(
|
||||||
|
readout_oper[0],
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=vit_features,
|
||||||
|
out_channels=features[0],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
),
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
in_channels=features[0],
|
||||||
|
out_channels=features[0],
|
||||||
|
kernel_size=4,
|
||||||
|
stride=4,
|
||||||
|
padding=0,
|
||||||
|
bias=True,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.act_postprocess2 = nn.Sequential(
|
||||||
|
readout_oper[1],
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=vit_features,
|
||||||
|
out_channels=features[1],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
),
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
in_channels=features[1],
|
||||||
|
out_channels=features[1],
|
||||||
|
kernel_size=2,
|
||||||
|
stride=2,
|
||||||
|
padding=0,
|
||||||
|
bias=True,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.act_postprocess3 = nn.Sequential(
|
||||||
|
readout_oper[2],
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=vit_features,
|
||||||
|
out_channels=features[2],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.act_postprocess4 = nn.Sequential(
|
||||||
|
readout_oper[3],
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=vit_features,
|
||||||
|
out_channels=features[3],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=features[3],
|
||||||
|
out_channels=features[3],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.model.start_index = start_index
|
||||||
|
pretrained.model.patch_size = [16, 16]
|
||||||
|
|
||||||
|
# We inject this function into the VisionTransformer instances so that
|
||||||
|
# we can use it with interpolated position embeddings without modifying the library source.
|
||||||
|
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
||||||
|
pretrained.model._resize_pos_embed = types.MethodType(
|
||||||
|
_resize_pos_embed, pretrained.model
|
||||||
|
)
|
||||||
|
|
||||||
|
return pretrained
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
||||||
|
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
||||||
|
|
||||||
|
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
||||||
|
return _make_vit_b16_backbone(
|
||||||
|
model,
|
||||||
|
features=[256, 512, 1024, 1024],
|
||||||
|
hooks=hooks,
|
||||||
|
vit_features=1024,
|
||||||
|
use_readout=use_readout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
||||||
|
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
||||||
|
|
||||||
|
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||||
|
return _make_vit_b16_backbone(
|
||||||
|
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
||||||
|
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
||||||
|
|
||||||
|
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||||
|
return _make_vit_b16_backbone(
|
||||||
|
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
||||||
|
model = timm.create_model(
|
||||||
|
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||||
|
return _make_vit_b16_backbone(
|
||||||
|
model,
|
||||||
|
features=[96, 192, 384, 768],
|
||||||
|
hooks=hooks,
|
||||||
|
use_readout=use_readout,
|
||||||
|
start_index=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vit_b_rn50_backbone(
|
||||||
|
model,
|
||||||
|
features=[256, 512, 768, 768],
|
||||||
|
size=[384, 384],
|
||||||
|
hooks=[0, 1, 8, 11],
|
||||||
|
vit_features=768,
|
||||||
|
use_vit_only=False,
|
||||||
|
use_readout="ignore",
|
||||||
|
start_index=1,
|
||||||
|
):
|
||||||
|
pretrained = nn.Module()
|
||||||
|
|
||||||
|
pretrained.model = model
|
||||||
|
|
||||||
|
if use_vit_only == True:
|
||||||
|
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
||||||
|
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
||||||
|
else:
|
||||||
|
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
||||||
|
get_activation("1")
|
||||||
|
)
|
||||||
|
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
||||||
|
get_activation("2")
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
||||||
|
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
||||||
|
|
||||||
|
pretrained.activations = activations
|
||||||
|
|
||||||
|
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
||||||
|
|
||||||
|
if use_vit_only == True:
|
||||||
|
pretrained.act_postprocess1 = nn.Sequential(
|
||||||
|
readout_oper[0],
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=vit_features,
|
||||||
|
out_channels=features[0],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
),
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
in_channels=features[0],
|
||||||
|
out_channels=features[0],
|
||||||
|
kernel_size=4,
|
||||||
|
stride=4,
|
||||||
|
padding=0,
|
||||||
|
bias=True,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.act_postprocess2 = nn.Sequential(
|
||||||
|
readout_oper[1],
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=vit_features,
|
||||||
|
out_channels=features[1],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
),
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
in_channels=features[1],
|
||||||
|
out_channels=features[1],
|
||||||
|
kernel_size=2,
|
||||||
|
stride=2,
|
||||||
|
padding=0,
|
||||||
|
bias=True,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pretrained.act_postprocess1 = nn.Sequential(
|
||||||
|
nn.Identity(), nn.Identity(), nn.Identity()
|
||||||
|
)
|
||||||
|
pretrained.act_postprocess2 = nn.Sequential(
|
||||||
|
nn.Identity(), nn.Identity(), nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.act_postprocess3 = nn.Sequential(
|
||||||
|
readout_oper[2],
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=vit_features,
|
||||||
|
out_channels=features[2],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.act_postprocess4 = nn.Sequential(
|
||||||
|
readout_oper[3],
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=vit_features,
|
||||||
|
out_channels=features[3],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=features[3],
|
||||||
|
out_channels=features[3],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained.model.start_index = start_index
|
||||||
|
pretrained.model.patch_size = [16, 16]
|
||||||
|
|
||||||
|
# We inject this function into the VisionTransformer instances so that
|
||||||
|
# we can use it with interpolated position embeddings without modifying the library source.
|
||||||
|
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
||||||
|
|
||||||
|
# We inject this function into the VisionTransformer instances so that
|
||||||
|
# we can use it with interpolated position embeddings without modifying the library source.
|
||||||
|
pretrained.model._resize_pos_embed = types.MethodType(
|
||||||
|
_resize_pos_embed, pretrained.model
|
||||||
|
)
|
||||||
|
|
||||||
|
return pretrained
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pretrained_vitb_rn50_384(
|
||||||
|
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
||||||
|
):
|
||||||
|
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
||||||
|
|
||||||
|
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
||||||
|
return _make_vit_b_rn50_backbone(
|
||||||
|
model,
|
||||||
|
features=[256, 512, 768, 768],
|
||||||
|
size=[384, 384],
|
||||||
|
hooks=hooks,
|
||||||
|
use_vit_only=use_vit_only,
|
||||||
|
use_readout=use_readout,
|
||||||
|
)
|
|
@ -0,0 +1,189 @@
"""Utils for monoDepth."""
import sys
import re
import numpy as np
import cv2
import torch


def read_pfm(path):
    """Read pfm file.

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale)
    """
    with open(path, "rb") as file:

        color = None
        width = None
        height = None
        scale = None
        endian = None

        header = file.readline().rstrip()
        if header.decode("ascii") == "PF":
            color = True
        elif header.decode("ascii") == "Pf":
            color = False
        else:
            raise Exception("Not a PFM file: " + path)

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception("Malformed PFM header.")

        scale = float(file.readline().decode("ascii").rstrip())
        if scale < 0:
            # little-endian
            endian = "<"
            scale = -scale
        else:
            # big-endian
            endian = ">"

        data = np.fromfile(file, endian + "f")
        shape = (height, width, 3) if color else (height, width)

        data = np.reshape(data, shape)
        data = np.flipud(data)

        return data, scale


def write_pfm(path, image, scale=1):
    """Write pfm file.

    Args:
        path (str): path to file
        image (array): data
        scale (int, optional): Scale. Defaults to 1.
    """

    with open(path, "wb") as file:
        color = None

        if image.dtype.name != "float32":
            raise Exception("Image dtype must be float32.")

        image = np.flipud(image)

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif (
            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
        ):  # greyscale
            color = False
        else:
            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

        # encode the header bytes for both the color ("PF") and greyscale ("Pf") variants
        file.write(("PF\n" if color else "Pf\n").encode())
        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder

        if endian == "<" or endian == "=" and sys.byteorder == "little":
            scale = -scale

        file.write("%f\n".encode() % scale)

        image.tofile(file)


def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img


def resize_image(img):
    """Resize image and make it fit for network.

    Args:
        img (array): image

    Returns:
        tensor: data ready for network
    """
    height_orig = img.shape[0]
    width_orig = img.shape[1]

    if width_orig > height_orig:
        scale = width_orig / 384
    else:
        scale = height_orig / 384

    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    img_resized = (
        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
    )
    img_resized = img_resized.unsqueeze(0)

    return img_resized


def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")

    depth_resized = cv2.resize(
        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
    )

    return depth_resized


def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
    """
    write_pfm(path + ".pfm", depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + ".png", out.astype("uint8"))
    elif bits == 2:
        cv2.imwrite(path + ".png", out.astype("uint16"))

    return
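Taken together, these helpers form a simple single-image depth pipeline: read and resize the image, run a depth network, resize the prediction back, and write it out. A hedged end-to-end sketch; the import paths and file names are assumptions, normalization transforms are omitted for brevity, and the model could be any of the networks defined above:

```python
import torch

# Assumed module paths; adjust to your checkout.
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
from ldm.modules.midas.utils import read_image, resize_image, resize_depth, write_depth

model = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True).eval()

img = read_image("input.jpg")            # H x W x 3 RGB in [0, 1]
sample = resize_image(img)               # 1 x 3 x H' x W' tensor, long side ~384, multiples of 32

with torch.no_grad():
    prediction = model(sample)           # 1 x H' x W' relative inverse depth

depth = resize_depth(prediction.unsqueeze(1), img.shape[1], img.shape[0])
write_depth("output", depth, bits=2)     # writes output.pfm and output.png
```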
@ -1,641 +0,0 @@
|
||||||
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
|
|
||||||
import torch
|
|
||||||
from torch import nn, einsum
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from functools import partial
|
|
||||||
from inspect import isfunction
|
|
||||||
from collections import namedtuple
|
|
||||||
from einops import rearrange, repeat, reduce
|
|
||||||
|
|
||||||
# constants
|
|
||||||
|
|
||||||
DEFAULT_DIM_HEAD = 64
|
|
||||||
|
|
||||||
Intermediates = namedtuple('Intermediates', [
|
|
||||||
'pre_softmax_attn',
|
|
||||||
'post_softmax_attn'
|
|
||||||
])
|
|
||||||
|
|
||||||
LayerIntermediates = namedtuple('Intermediates', [
|
|
||||||
'hiddens',
|
|
||||||
'attn_intermediates'
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
class AbsolutePositionalEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim, max_seq_len):
|
|
||||||
super().__init__()
|
|
||||||
self.emb = nn.Embedding(max_seq_len, dim)
|
|
||||||
self.init_()
|
|
||||||
|
|
||||||
def init_(self):
|
|
||||||
nn.init.normal_(self.emb.weight, std=0.02)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
n = torch.arange(x.shape[1], device=x.device)
|
|
||||||
return self.emb(n)[None, :, :]
|
|
||||||
|
|
||||||
|
|
||||||
class FixedPositionalEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
|
||||||
self.register_buffer('inv_freq', inv_freq)
|
|
||||||
|
|
||||||
def forward(self, x, seq_dim=1, offset=0):
|
|
||||||
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
|
|
||||||
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
|
|
||||||
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
|
|
||||||
return emb[None, :, :]
|
|
||||||
|
|
||||||
|
|
||||||
# helpers
|
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
if exists(val):
|
|
||||||
return val
|
|
||||||
return d() if isfunction(d) else d
|
|
||||||
|
|
||||||
|
|
||||||
def always(val):
|
|
||||||
def inner(*args, **kwargs):
|
|
||||||
return val
|
|
||||||
return inner
|
|
||||||
|
|
||||||
|
|
||||||
def not_equals(val):
|
|
||||||
def inner(x):
|
|
||||||
return x != val
|
|
||||||
return inner
|
|
||||||
|
|
||||||
|
|
||||||
def equals(val):
|
|
||||||
def inner(x):
|
|
||||||
return x == val
|
|
||||||
return inner
|
|
||||||
|
|
||||||
|
|
||||||
def max_neg_value(tensor):
|
|
||||||
return -torch.finfo(tensor.dtype).max
|
|
||||||
|
|
||||||
|
|
||||||
# keyword argument helpers
|
|
||||||
|
|
||||||
def pick_and_pop(keys, d):
|
|
||||||
values = list(map(lambda key: d.pop(key), keys))
|
|
||||||
return dict(zip(keys, values))
|
|
||||||
|
|
||||||
|
|
||||||
def group_dict_by_key(cond, d):
|
|
||||||
return_val = [dict(), dict()]
|
|
||||||
for key in d.keys():
|
|
||||||
match = bool(cond(key))
|
|
||||||
ind = int(not match)
|
|
||||||
return_val[ind][key] = d[key]
|
|
||||||
return (*return_val,)
|
|
||||||
|
|
||||||
|
|
||||||
def string_begins_with(prefix, str):
|
|
||||||
return str.startswith(prefix)
|
|
||||||
|
|
||||||
|
|
||||||
def group_by_key_prefix(prefix, d):
|
|
||||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
|
||||||
|
|
||||||
|
|
||||||
def groupby_prefix_and_trim(prefix, d):
|
|
||||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
|
||||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
|
||||||
return kwargs_without_prefix, kwargs
|
|
||||||
|
|
||||||
|
|
||||||
# classes
|
|
||||||
class Scale(nn.Module):
|
|
||||||
def __init__(self, value, fn):
|
|
||||||
super().__init__()
|
|
||||||
self.value = value
|
|
||||||
self.fn = fn
|
|
||||||
|
|
||||||
def forward(self, x, **kwargs):
|
|
||||||
x, *rest = self.fn(x, **kwargs)
|
|
||||||
return (x * self.value, *rest)
|
|
||||||
|
|
||||||
|
|
||||||
class Rezero(nn.Module):
|
|
||||||
def __init__(self, fn):
|
|
||||||
super().__init__()
|
|
||||||
self.fn = fn
|
|
||||||
self.g = nn.Parameter(torch.zeros(1))
|
|
||||||
|
|
||||||
def forward(self, x, **kwargs):
|
|
||||||
x, *rest = self.fn(x, **kwargs)
|
|
||||||
return (x * self.g, *rest)
|
|
||||||
|
|
||||||
|
|
||||||
class ScaleNorm(nn.Module):
|
|
||||||
def __init__(self, dim, eps=1e-5):
|
|
||||||
super().__init__()
|
|
||||||
self.scale = dim ** -0.5
|
|
||||||
self.eps = eps
|
|
||||||
self.g = nn.Parameter(torch.ones(1))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
|
||||||
return x / norm.clamp(min=self.eps) * self.g
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim, eps=1e-8):
|
|
||||||
super().__init__()
|
|
||||||
self.scale = dim ** -0.5
|
|
||||||
self.eps = eps
|
|
||||||
self.g = nn.Parameter(torch.ones(dim))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
|
||||||
return x / norm.clamp(min=self.eps) * self.g
|
|
||||||
|
|
||||||
|
|
||||||
class Residual(nn.Module):
|
|
||||||
def forward(self, x, residual):
|
|
||||||
return x + residual
|
|
||||||
|
|
||||||
|
|
||||||
class GRUGating(nn.Module):
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
self.gru = nn.GRUCell(dim, dim)
|
|
||||||
|
|
||||||
def forward(self, x, residual):
|
|
||||||
gated_output = self.gru(
|
|
||||||
rearrange(x, 'b n d -> (b n) d'),
|
|
||||||
rearrange(residual, 'b n d -> (b n) d')
|
|
||||||
)
|
|
||||||
|
|
||||||
return gated_output.reshape_as(x)
|
|
||||||
|
|
||||||
|
|
||||||
# feedforward
|
|
||||||
|
|
||||||
class GEGLU(nn.Module):
|
|
||||||
def __init__(self, dim_in, dim_out):
|
|
||||||
super().__init__()
|
|
||||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
|
||||||
return x * F.gelu(gate)
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
|
||||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = int(dim * mult)
|
|
||||||
dim_out = default(dim_out, dim)
|
|
||||||
project_in = nn.Sequential(
|
|
||||||
nn.Linear(dim, inner_dim),
|
|
||||||
nn.GELU()
|
|
||||||
) if not glu else GEGLU(dim, inner_dim)
|
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
|
||||||
project_in,
|
|
||||||
nn.Dropout(dropout),
|
|
||||||
nn.Linear(inner_dim, dim_out)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.net(x)
|
|
||||||
|
|
||||||
|
|
||||||
# attention.
|
|
||||||
class Attention(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
dim_head=DEFAULT_DIM_HEAD,
|
|
||||||
heads=8,
|
|
||||||
causal=False,
|
|
||||||
mask=None,
|
|
||||||
talking_heads=False,
|
|
||||||
sparse_topk=None,
|
|
||||||
use_entmax15=False,
|
|
||||||
num_mem_kv=0,
|
|
||||||
dropout=0.,
|
|
||||||
on_attn=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
if use_entmax15:
|
|
||||||
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
|
|
||||||
self.scale = dim_head ** -0.5
|
|
||||||
self.heads = heads
|
|
||||||
self.causal = causal
|
|
||||||
self.mask = mask
|
|
||||||
|
|
||||||
inner_dim = dim_head * heads
|
|
||||||
|
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
|
||||||
self.to_k = nn.Linear(dim, inner_dim, bias=False)
|
|
||||||
self.to_v = nn.Linear(dim, inner_dim, bias=False)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
# talking heads
|
|
||||||
self.talking_heads = talking_heads
|
|
||||||
if talking_heads:
|
|
||||||
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
|
||||||
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
|
||||||
|
|
||||||
# explicit topk sparse attention
|
|
||||||
self.sparse_topk = sparse_topk
|
|
||||||
|
|
||||||
# entmax
|
|
||||||
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
|
|
||||||
self.attn_fn = F.softmax
|
|
||||||
|
|
||||||
# add memory key / values
|
|
||||||
self.num_mem_kv = num_mem_kv
|
|
||||||
if num_mem_kv > 0:
|
|
||||||
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
|
||||||
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
|
||||||
|
|
||||||
# attention on attention
|
|
||||||
self.attn_on_attn = on_attn
|
|
||||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
context=None,
|
|
||||||
mask=None,
|
|
||||||
context_mask=None,
|
|
||||||
rel_pos=None,
|
|
||||||
sinusoidal_emb=None,
|
|
||||||
prev_attn=None,
|
|
||||||
mem=None
|
|
||||||
):
|
|
||||||
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
|
|
||||||
kv_input = default(context, x)
|
|
||||||
|
|
||||||
q_input = x
|
|
||||||
k_input = kv_input
|
|
||||||
v_input = kv_input
|
|
||||||
|
|
||||||
if exists(mem):
|
|
||||||
k_input = torch.cat((mem, k_input), dim=-2)
|
|
||||||
v_input = torch.cat((mem, v_input), dim=-2)
|
|
||||||
|
|
||||||
if exists(sinusoidal_emb):
|
|
||||||
# in shortformer, the query would start at a position offset depending on the past cached memory
|
|
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates


class AttentionLayers(nn.Module):

    def __init__(
        self,
        dim,
        depth,
        heads=8,
        causal=False,
        cross_attend=False,
        only_cross=False,
        use_scalenorm=False,
        use_rmsnorm=False,
        use_rezero=False,
        rel_pos_num_buckets=32,
        rel_pos_max_distance=128,
        position_infused_attn=False,
        custom_layers=None,
        sandwich_coef=None,
        par_ratio=None,
        residual_attn=False,
        cross_residual_attn=False,
        macaron=False,
        pre_norm=True,
        gate_residual=False,
        **kwargs
    ):
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])

        self.has_pos_emb = position_infused_attn
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
        self.rotary_pos_emb = always(None)

        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
        self.rel_pos = None

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if isinstance(layer, Attention) and exists(branch_fn):
                layer = branch_fn(layer)

            if gate_residual:
                residual_fn = GRUGating(dim)
            else:
                residual_fn = Residual()

            self.layers.append(nn.ModuleList([
                norm_fn(),
                layer,
                residual_fn
            ]))

    def forward(
        self,
        x,
        context=None,
        mask=None,
        context_mask=None,
        mems=None,
        return_hiddens=False
    ):
        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == 'a':
                hiddens.append(x)
                layer_mem = mems.pop(0)

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
                                   prev_attn=prev_attn, mem=layer_mem)
            elif layer_type == 'c':
                out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
            elif layer_type == 'f':
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates
            )

            return x, intermediates

        return x


class Encoder(AttentionLayers):

    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)


class TransformerWrapper(nn.Module):

    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers,
        emb_dim=None,
        max_mem_len=0.,
        emb_dropout=0.,
        num_memory_tokens=None,
        tie_embedding=False,
        use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
            use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
        self,
        x,
        return_embeddings=False,
        mask=None,
        return_mems=False,
        return_attn=False,
        mems=None,
        **kwargs
    ):
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
            return out, new_mems

        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out
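The `Encoder`/`TransformerWrapper` pair above is the building block the LDM text-encoder wrappers are configured from. A minimal, self-contained usage sketch follows; the sizes are illustrative placeholders, not values taken from this repository's configs:

```python
import torch

# Illustrative sizes only; the real values come from the model yaml configs.
enc = TransformerWrapper(
    num_tokens=30522,                                 # tokenizer vocabulary size
    max_seq_len=77,                                   # text-conditioning context length
    attn_layers=Encoder(dim=512, depth=6, heads=8),
)

tokens = torch.randint(0, 30522, (2, 77))             # (batch, seq_len) token ids
hidden = enc(tokens, return_embeddings=True)          # (2, 77, 512) per-token features
logits = enc(tokens)                                  # (2, 77, 30522) token logits
```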
@@ -1,14 +1,8 @@
 import importlib

 import torch
+from torch import optim
 import numpy as np
-from collections import abc
-from einops import rearrange
-from functools import partial
-
-import multiprocessing as mp
-from threading import Thread
-from queue import Queue

 from inspect import isfunction
 from PIL import Image, ImageDraw, ImageFont

@@ -45,7 +39,7 @@ def ismap(x):


 def isimage(x):
-    if not isinstance(x, torch.Tensor):
+    if not isinstance(x,torch.Tensor):
         return False
     return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


@@ -71,7 +65,7 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
     return total_params


@@ -93,111 +87,111 @@ def get_obj_from_str(string, reload=False):
     return getattr(importlib.import_module(module, package=None), cls)


-def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
-    # create dummy dataset instance
-
-    # run prefetching
-    if idx_to_fn:
-        res = func(data, worker_id=idx)
-    else:
-        res = func(data)
-    Q.put([idx, res])
-    Q.put("Done")
-
-
-def parallel_data_prefetch(
-        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
-):
-    # if target_data_type not in ["ndarray", "list"]:
-    #     raise ValueError(
-    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
-    #     )
-    if isinstance(data, np.ndarray) and target_data_type == "list":
-        raise ValueError("list expected but function got ndarray.")
-    elif isinstance(data, abc.Iterable):
-        if isinstance(data, dict):
-            print(
-                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
-            )
-            data = list(data.values())
-        if target_data_type == "ndarray":
-            data = np.asarray(data)
-        else:
-            data = list(data)
-    else:
-        raise TypeError(
-            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
-        )
-
-    if cpu_intensive:
-        Q = mp.Queue(1000)
-        proc = mp.Process
-    else:
-        Q = Queue(1000)
-        proc = Thread
-    # spawn processes
-    if target_data_type == "ndarray":
-        arguments = [
-            [func, Q, part, i, use_worker_id]
-            for i, part in enumerate(np.array_split(data, n_proc))
-        ]
-    else:
-        step = (
-            int(len(data) / n_proc + 1)
-            if len(data) % n_proc != 0
-            else int(len(data) / n_proc)
-        )
-        arguments = [
-            [func, Q, part, i, use_worker_id]
-            for i, part in enumerate(
-                [data[i: i + step] for i in range(0, len(data), step)]
-            )
-        ]
-    processes = []
-    for i in range(n_proc):
-        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
-        processes += [p]
-
-    # start processes
-    print(f"Start prefetching...")
-    import time
-
-    start = time.time()
-    gather_res = [[] for _ in range(n_proc)]
-    try:
-        for p in processes:
-            p.start()
-
-        k = 0
-        while k < n_proc:
-            # get result
-            res = Q.get()
-            if res == "Done":
-                k += 1
-            else:
-                gather_res[res[0]] = res[1]
-
-    except Exception as e:
-        print("Exception: ", e)
-        for p in processes:
-            p.terminate()
-
-        raise e
-    finally:
-        for p in processes:
-            p.join()
-        print(f"Prefetching complete. [{time.time() - start} sec.]")
-
-    if target_data_type == 'ndarray':
-        if not isinstance(gather_res[0], np.ndarray):
-            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
-
-        # order outputs
-        return np.concatenate(gather_res, axis=0)
-    elif target_data_type == 'list':
-        out = []
-        for r in gather_res:
-            out.extend(r)
-        return out
-    else:
-        return gather_res
+class AdamWwithEMAandWings(optim.Optimizer):
+    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
+                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
+                 ema_power=1., param_names=()):
+        """AdamW that saves EMA versions of the parameters."""
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= ema_decay <= 1.0:
+            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+                        ema_power=ema_power, param_names=param_names)
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            ema_params_with_grad = []
+            state_sums = []
+            max_exp_avg_sqs = []
+            state_steps = []
+            amsgrad = group['amsgrad']
+            beta1, beta2 = group['betas']
+            ema_decay = group['ema_decay']
+            ema_power = group['ema_power']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
+                    raise RuntimeError('AdamW does not support sparse gradients')
+                grads.append(p.grad)
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of parameter values
+                    state['param_exp_avg'] = p.detach().float().clone()
+
+                exp_avgs.append(state['exp_avg'])
+                exp_avg_sqs.append(state['exp_avg_sq'])
+                ema_params_with_grad.append(state['param_exp_avg'])
+
+                if amsgrad:
+                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+                # update the steps for each param group update
+                state['step'] += 1
+                # record the step after step update
+                state_steps.append(state['step'])
+
+            optim._functional.adamw(params_with_grad,
+                                    grads,
+                                    exp_avgs,
+                                    exp_avg_sqs,
+                                    max_exp_avg_sqs,
+                                    state_steps,
+                                    amsgrad=amsgrad,
+                                    beta1=beta1,
+                                    beta2=beta2,
+                                    lr=group['lr'],
+                                    weight_decay=group['weight_decay'],
+                                    eps=group['eps'],
+                                    maximize=False)
+
+            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+        return loss
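A minimal sketch of driving the `AdamWwithEMAandWings` optimizer added above. The module, data, and hyperparameters are placeholders for illustration, and the sketch assumes a PyTorch version that still exposes `torch.optim._functional.adamw`, as the class itself does:

```python
import torch
import torch.nn as nn

net = nn.Linear(16, 4)                                   # toy module, illustration only
opt = AdamWwithEMAandWings(net.parameters(), lr=1e-3, ema_decay=0.9999)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(net(x), y)
loss.backward()
opt.step()          # AdamW update plus refresh of the EMA shadow weights
opt.zero_grad()

# The EMA copies live in the optimizer state under 'param_exp_avg'.
ema_params = [opt.state[p]['param_exp_avg'] for p in net.parameters()]
```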
@@ -1,54 +1,48 @@
-import argparse, os, sys, datetime, glob, importlib, csv
-import numpy as np
+import argparse
+import csv
+import datetime
+import glob
+import importlib
+import os
+import sys
 import time

+import numpy as np
 import torch
 import torchvision
-import lightning.pytorch as pl

-from packaging import version
-from omegaconf import OmegaConf
-from torch.utils.data import random_split, DataLoader, Dataset, Subset
+try:
+    import lightning.pytorch as pl
+except:
+    import pytorch_lightning as pl

 from functools import partial

+from omegaconf import OmegaConf
+from packaging import version
 from PIL import Image
-# from lightning.pytorch.strategies.colossalai import ColossalAIStrategy
-# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
 from prefetch_generator import BackgroundGenerator
+from torch.utils.data import DataLoader, Dataset, Subset, random_split

-from lightning.pytorch import seed_everything
-from lightning.pytorch.trainer import Trainer
-from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
-from lightning.pytorch.utilities.rank_zero import rank_zero_only
-from lightning.pytorch.utilities import rank_zero_info
-from diffusers.models.unet_2d import UNet2DModel
-
-from clip.model import Bottleneck
-from transformers.models.clip.modeling_clip import CLIPTextTransformer
+try:
+    from lightning.pytorch import seed_everything
+    from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
+    from lightning.pytorch.trainer import Trainer
+    from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
+    LIGHTNING_PACK_NAME = "lightning.pytorch."
+except:
+    from pytorch_lightning import seed_everything
+    from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
+    from pytorch_lightning.trainer import Trainer
+    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+    LIGHTNING_PACK_NAME = "pytorch_lightning."

 from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config
-import clip
-from einops import rearrange, repeat
-from transformers import CLIPTokenizer, CLIPTextModel
-import kornia
-
-from ldm.modules.x_transformer import *
-from ldm.modules.encoders.modules import *
-from taming.modules.diffusionmodules.model import ResnetBlock
-from taming.modules.transformer.mingpt import *
-from taming.modules.transformer.permuter import *
-
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import AutoencoderKL
-from ldm.models.autoencoder import *
-from ldm.models.diffusion.ddim import *
-from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.model import *
-from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module
-from ldm.modules.attention import enable_flash_attention
+# from ldm.modules.attention import enable_flash_attentions


 class DataLoaderX(DataLoader):

@@ -56,6 +50,7 @@ class DataLoaderX(DataLoader):

     def __iter__(self):


 def get_parser(**parser_kwargs):

     def str2bool(v):
         if isinstance(v, bool):
             return v

@@ -91,7 +86,7 @@ def get_parser(**parser_kwargs):
         nargs="*",
         metavar="base_config.yaml",
         help="paths to base configs. Loaded from left-to-right. "
         "Parameters can be overwritten or added with command-line options of the form `--key value`.",
         default=list(),
     )
     parser.add_argument(

@@ -111,11 +106,7 @@ def get_parser(**parser_kwargs):
         nargs="?",
         help="disable test",
     )
-    parser.add_argument(
-        "-p",
-        "--project",
-        help="name of new or path to existing project"
-    )
+    parser.add_argument("-p", "--project", help="name of new or path to existing project")
     parser.add_argument(
         "-d",
         "--debug",

@@ -210,8 +201,17 @@ def worker_init_fn(_):


 class DataModuleFromConfig(pl.LightningDataModule):
-    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
-                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
+
+    def __init__(self,
+                 batch_size,
+                 train=None,
+                 validation=None,
+                 test=None,
+                 predict=None,
+                 wrap=False,
+                 num_workers=None,
+                 shuffle_test_loader=False,
+                 use_worker_init_fn=False,
                  shuffle_val_dataloader=False):
         super().__init__()
         self.batch_size = batch_size

@@ -237,9 +237,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
             instantiate_from_config(data_cfg)

     def setup(self, stage=None):
-        self.datasets = dict(
-            (k, instantiate_from_config(self.dataset_configs[k]))
-            for k in self.dataset_configs)
+        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
         if self.wrap:
             for k in self.datasets:
                 self.datasets[k] = WrappedDataset(self.datasets[k])

@@ -250,9 +248,11 @@ class DataModuleFromConfig(pl.LightningDataModule):
             init_fn = worker_init_fn
         else:
             init_fn = None
-        return DataLoaderX(self.datasets["train"], batch_size=self.batch_size,
-                           num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
-                           worker_init_fn=init_fn)
+        return DataLoaderX(self.datasets["train"],
+                           batch_size=self.batch_size,
+                           num_workers=self.num_workers,
+                           shuffle=False if is_iterable_dataset else True,
+                           worker_init_fn=init_fn)

     def _val_dataloader(self, shuffle=False):
         if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:

@@ -260,10 +260,10 @@ class DataModuleFromConfig(pl.LightningDataModule):
         else:
             init_fn = None
         return DataLoaderX(self.datasets["validation"],
                            batch_size=self.batch_size,
                            num_workers=self.num_workers,
                            worker_init_fn=init_fn,
                            shuffle=shuffle)

     def _test_dataloader(self, shuffle=False):
         is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)

@@ -275,19 +275,25 @@ class DataModuleFromConfig(pl.LightningDataModule):
         # do not shuffle dataloader for iterable dataset
         shuffle = shuffle and (not is_iterable_dataset)

-        return DataLoaderX(self.datasets["test"], batch_size=self.batch_size,
-                           num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
+        return DataLoaderX(self.datasets["test"],
+                           batch_size=self.batch_size,
+                           num_workers=self.num_workers,
+                           worker_init_fn=init_fn,
+                           shuffle=shuffle)

     def _predict_dataloader(self, shuffle=False):
         if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
             init_fn = worker_init_fn
         else:
             init_fn = None
-        return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size,
-                           num_workers=self.num_workers, worker_init_fn=init_fn)
+        return DataLoaderX(self.datasets["predict"],
+                           batch_size=self.batch_size,
+                           num_workers=self.num_workers,
+                           worker_init_fn=init_fn)


 class SetupCallback(Callback):

     def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
         super().__init__()
         self.resume = resume

@@ -317,8 +323,7 @@ class SetupCallback(Callback):
             os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
             print("Project config")
             print(OmegaConf.to_yaml(self.config))
-            OmegaConf.save(self.config,
-                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
+            OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

             print("Lightning config")
             print(OmegaConf.to_yaml(self.lightning_config))

@@ -338,8 +343,16 @@ class SetupCallback(Callback):


 class ImageLogger(Callback):
-    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
-                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+
+    def __init__(self,
+                 batch_frequency,
+                 max_images,
+                 clamp=True,
+                 increase_log_steps=True,
+                 rescale=True,
+                 disabled=False,
+                 log_on_batch_idx=False,
+                 log_first_step=False,
                  log_images_kwargs=None):
         super().__init__()
         self.rescale = rescale

@@ -348,7 +361,7 @@ class ImageLogger(Callback):
         self.logger_log_images = {
             pl.loggers.CSVLogger: self._testtube,
         }
-        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
+        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
         if not increase_log_steps:
             self.log_steps = [self.batch_freq]
         self.clamp = clamp

@@ -361,39 +374,30 @@ class ImageLogger(Callback):
     def _testtube(self, pl_module, images, batch_idx, split):
         for k in images:
             grid = torchvision.utils.make_grid(images[k])
             grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

             tag = f"{split}/{k}"
-            pl_module.logger.experiment.add_image(
-                tag, grid,
-                global_step=pl_module.global_step)
+            pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)

     @rank_zero_only
-    def log_local(self, save_dir, split, images,
-                  global_step, current_epoch, batch_idx):
+    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
         root = os.path.join(save_dir, "images", split)
         for k in images:
             grid = torchvision.utils.make_grid(images[k], nrow=4)
             if self.rescale:
                 grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
             grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
             grid = grid.numpy()
             grid = (grid * 255).astype(np.uint8)
-            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
-                k,
-                global_step,
-                current_epoch,
-                batch_idx)
+            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
             path = os.path.join(root, filename)
             os.makedirs(os.path.split(path)[0], exist_ok=True)
             Image.fromarray(grid).save(path)

     def log_img(self, pl_module, batch, batch_idx, split="train"):
         check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
         if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
-                hasattr(pl_module, "log_images") and
-                callable(pl_module.log_images) and
-                self.max_images > 0):
+                hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0):
             logger = type(pl_module.logger)

             is_train = pl_module.training

@@ -411,8 +415,8 @@ class ImageLogger(Callback):
                 if self.clamp:
                     images[k] = torch.clamp(images[k], -1., 1.)

-        self.log_local(pl_module.logger.save_dir, split, images,
-                       pl_module.global_step, pl_module.current_epoch, batch_idx)
+        self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch,
+                       batch_idx)

         logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
         logger_log_images(pl_module, images, pl_module.global_step, split)

@@ -421,8 +425,8 @@ class ImageLogger(Callback):
             pl_module.train()

     def check_frequency(self, check_idx):
-        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
-                check_idx > 0 or self.log_first_step):
+        if ((check_idx % self.batch_freq) == 0 or
+            (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step):
             try:
                 self.log_steps.pop(0)
             except IndexError as e:

@@ -461,7 +465,7 @@ class CUDACallback(Callback):

     def on_train_epoch_end(self, trainer, pl_module):
         torch.cuda.synchronize(trainer.strategy.root_device.index)
-        max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20
+        max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20
         epoch_time = time.time() - self.start_time

         try:

@@ -528,13 +532,9 @@ if __name__ == "__main__":

     opt, unknown = parser.parse_known_args()
     if opt.name and opt.resume:
-        raise ValueError(
-            "-n/--name and -r/--resume cannot be specified both."
-            "If you want to resume training in a new log folder, "
-            "use -n/--name in combination with --resume_from_checkpoint"
-        )
-    if opt.flash:
-        enable_flash_attention()
+        raise ValueError("-n/--name and -r/--resume cannot be specified both."
+                         "If you want to resume training in a new log folder, "
+                         "use -n/--name in combination with --resume_from_checkpoint")
     if opt.resume:
         if not os.path.exists(opt.resume):
             raise ValueError("Cannot find {}".format(opt.resume))

@@ -578,7 +578,7 @@ if __name__ == "__main__":
     lightning_config = config.pop("lightning", OmegaConf.create())
     # merge trainer cli with config
     trainer_config = lightning_config.get("trainer", OmegaConf.create())

     for k in nondefault_trainer_args(opt):
         trainer_config[k] = getattr(opt, k)

@@ -601,7 +601,7 @@ if __name__ == "__main__":
     else:
         config.model["params"].update({"use_fp16": False})
     print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))

     model = instantiate_from_config(config.model)
     # trainer and callbacks
     trainer_kwargs = dict()

@@ -610,7 +610,7 @@ if __name__ == "__main__":
     # default logger configs
     default_logger_cfgs = {
         "wandb": {
-            "target": "lightning.pytorch.loggers.WandbLogger",
+            "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
             "params": {
                 "name": nowname,
                 "save_dir": logdir,

@@ -618,9 +618,9 @@ if __name__ == "__main__":
                 "id": nowname,
             }
         },
-        "tensorboard":{
-            "target": "lightning.pytorch.loggers.TensorBoardLogger",
-            "params":{
+        "tensorboard": {
+            "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
+            "params": {
                 "save_dir": logdir,
                 "name": "diff_tb",
                 "log_graph": True

@@ -640,9 +640,10 @@ if __name__ == "__main__":
     if "strategy" in trainer_config:
         strategy_cfg = trainer_config["strategy"]
         print("Using strategy: {}".format(strategy_cfg["target"]))
+        strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
     else:
         strategy_cfg = {
-            "target": "lightning.pytorch.strategies.DDPStrategy",
+            "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
             "params": {
                 "find_unused_parameters": False
             }

@@ -654,7 +655,7 @@ if __name__ == "__main__":
     # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
     # specify which metric is used to determine best models
     default_modelckpt_cfg = {
-        "target": "lightning.pytorch.callbacks.ModelCheckpoint",
+        "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
         "params": {
             "dirpath": ckptdir,
             "filename": "{epoch:06}",

@@ -670,7 +671,7 @@ if __name__ == "__main__":
     if "modelcheckpoint" in lightning_config:
         modelckpt_cfg = lightning_config.modelcheckpoint
     else:
         modelckpt_cfg = OmegaConf.create()
     modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
     print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
     if version.parse(pl.__version__) < version.parse('1.4.0'):

@@ -702,7 +703,7 @@ if __name__ == "__main__":
             "target": "main.LearningRateMonitor",
             "params": {
                 "logging_interval": "step",
                 # "log_momentum": True
             }
         },
         "cuda_callback": {

@@ -721,17 +722,17 @@ if __name__ == "__main__":
         print(
             'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
         default_metrics_over_trainsteps_ckpt_dict = {
-            'metrics_over_trainsteps_checkpoint':
-                {"target": 'lightning.pytorch.callbacks.ModelCheckpoint',
+            'metrics_over_trainsteps_checkpoint': {
+                "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
                 'params': {
                     "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                     "filename": "{epoch:06}-{step:09}",
                     "verbose": True,
                     'save_top_k': -1,
                     'every_n_train_steps': 10000,
                     'save_weights_only': True
                 }
             }
         }
         default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

@@ -744,7 +745,7 @@ if __name__ == "__main__":
     trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

     trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
     trainer.logdir = logdir  ###

     # data
     data = instantiate_from_config(config.data)

@@ -772,14 +773,13 @@ if __name__ == "__main__":
     if opt.scale_lr:
         model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
         print(
-            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
-                model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
+            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
+            .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
     else:
         model.learning_rate = base_lr
         print("++++ NOT USING LR SCALING ++++")
         print(f"Setting learning rate to {model.learning_rate:.2e}")

     # allow checkpointing via USR1
     def melk(*args, **kwargs):
         # run all checkpoint hooks

@@ -788,13 +788,11 @@ if __name__ == "__main__":
             ckpt_path = os.path.join(ckptdir, "last.ckpt")
             trainer.save_checkpoint(ckpt_path)

     def divein(*args, **kwargs):
         if trainer.global_rank == 0:
-            import pudb;
+            import pudb
             pudb.set_trace()

     import signal

     signal.signal(signal.SIGUSR1, melk)

@@ -803,8 +801,6 @@ if __name__ == "__main__":
     # run
     if opt.train:
         try:
-            for name, m in model.named_parameters():
-                print(name)
             trainer.fit(model, data)
         except Exception:
             melk()
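The `LIGHTNING_PACK_NAME` prefix introduced by the try/except import block above is what lets the logger, strategy, and checkpoint `target` strings resolve against whichever Lightning distribution is installed. A small sketch of the idea; `resolve_target` is a hypothetical helper, not part of `main.py`:

```python
import importlib

def resolve_target(suffix, pack_name="lightning.pytorch."):
    """Turn e.g. 'strategies.DDPStrategy' into the class it names,
    following the same prefix convention as the configs above."""
    module_path, cls_name = (pack_name + suffix).rsplit(".", 1)
    return getattr(importlib.import_module(module_path), cls_name)

# Usage sketch (mirrors how the default strategy config is built):
# ddp_cls = resolve_target("strategies.DDPStrategy", LIGHTNING_PACK_NAME)
```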
@@ -1,22 +1,17 @@
-albumentations==0.4.3
-diffusers
+albumentations==1.3.0
+opencv-python
 pudb==2019.2
-datasets
-invisible-watermark
+prefetch_generator
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
+torchmetrics==0.6
 omegaconf==2.1.1
-multiprocess
-lightning==1.8.1
 test-tube>=0.7.5
 streamlit>=0.73.1
 einops==0.3.0
-torch-fidelity==0.3.0
 transformers==4.19.2
-torchmetrics==0.6.0
-kornia==0.6
-opencv-python==4.6.0.66
-prefetch_generator
--e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
--e git+https://github.com/openai/CLIP.git@main#egg=clip
+webdataset==0.2.5
+open-clip-torch==2.7.0
+gradio==3.11
+datasets
 -e .
@ -1,6 +1,6 @@
|
||||||
"""make variations of input image"""
|
"""make variations of input image"""
|
||||||
|
|
||||||
import argparse, os, sys, glob
|
import argparse, os
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -12,12 +12,16 @@ from einops import rearrange, repeat
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
import time
|
try:
|
||||||
from lightning.pytorch import seed_everything
|
from lightning.pytorch import seed_everything
|
||||||
|
except:
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
from imwatermark import WatermarkEncoder
|
||||||
|
|
||||||
|
|
||||||
|
from scripts.txt2img import put_watermark
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(it, size):
|
def chunk(it, size):
|
||||||
|
@ -49,12 +53,12 @@ def load_img(path):
|
||||||
image = Image.open(path).convert("RGB")
|
image = Image.open(path).convert("RGB")
|
||||||
w, h = image.size
|
w, h = image.size
|
||||||
print(f"loaded input image of size ({w}, {h}) from {path}")
|
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
return 2.*image - 1.
|
return 2. * image - 1.
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -83,18 +87,6 @@ def main():
|
||||||
default="outputs/img2img-samples"
|
default="outputs/img2img-samples"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_grid",
|
|
||||||
action='store_true',
|
|
||||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_save",
|
|
||||||
action='store_true',
|
|
||||||
help="do not save indiviual samples. For speed measurements.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ddim_steps",
|
"--ddim_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -102,11 +94,6 @@ def main():
|
||||||
help="number of ddim sampling steps",
|
help="number of ddim sampling steps",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--plms",
|
|
||||||
action='store_true',
|
|
||||||
help="use plms sampling",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fixed_code",
|
"--fixed_code",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
|
@ -125,6 +112,7 @@ def main():
|
||||||
default=1,
|
default=1,
|
||||||
help="sample this often",
|
help="sample this often",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--C",
|
"--C",
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -137,31 +125,35 @@ def main():
|
||||||
default=8,
|
default=8,
|
||||||
help="downsampling factor, most often 8 or 16",
|
help="downsampling factor, most often 8 or 16",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_samples",
|
"--n_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="how many samples to produce for each given prompt. A.k.a batch size",
|
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_rows",
|
"--n_rows",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help="rows in the grid (default: n_samples)",
|
help="rows in the grid (default: n_samples)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--scale",
|
"--scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=5.0,
|
default=9.0,
|
||||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--strength",
|
"--strength",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.75,
|
default=0.8,
|
||||||
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--from-file",
|
"--from-file",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -170,13 +162,12 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config",
|
"--config",
|
||||||
type=str,
|
type=str,
|
||||||
default="configs/stable-diffusion/v1-inference.yaml",
|
default="configs/stable-diffusion/v2-inference.yaml",
|
||||||
help="path to config which constructs model",
|
help="path to config which constructs model",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ckpt",
|
"--ckpt",
|
||||||
type=str,
|
type=str,
|
||||||
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
|
||||||
help="path to checkpoint of model",
|
help="path to checkpoint of model",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -202,15 +193,16 @@ def main():
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
if opt.plms:
|
sampler = DDIMSampler(model)
|
||||||
raise NotImplementedError("PLMS sampler not (yet) supported")
|
|
||||||
sampler = PLMSSampler(model)
|
|
||||||
else:
|
|
||||||
sampler = DDIMSampler(model)
|
|
||||||
|
|
||||||
os.makedirs(opt.outdir, exist_ok=True)
|
os.makedirs(opt.outdir, exist_ok=True)
|
||||||
outpath = opt.outdir
|
outpath = opt.outdir
|
||||||
|
|
||||||
|
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
||||||
|
wm = "SDV2"
|
||||||
|
wm_encoder = WatermarkEncoder()
|
||||||
|
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
||||||
|
|
||||||
batch_size = opt.n_samples
|
batch_size = opt.n_samples
|
||||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||||
if not opt.from_file:
|
if not opt.from_file:
|
||||||
|
@ -244,7 +236,6 @@ def main():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
with model.ema_scope():
|
with model.ema_scope():
|
||||||
tic = time.time()
|
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
for n in trange(opt.n_iter, desc="Sampling"):
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
for prompts in tqdm(data, desc="data"):
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
@@ -256,37 +247,35 @@ def main():
                         c = model.get_learned_conditioning(prompts)
 
                         # encode (scaled latent)
-                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
+                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
                         # decode it
                         samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
-                                                 unconditional_conditioning=uc,)
+                                                 unconditional_conditioning=uc, )
 
                         x_samples = model.decode_first_stage(samples)
                         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
-                        if not opt.skip_save:
-                            for x_sample in x_samples:
-                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                                Image.fromarray(x_sample.astype(np.uint8)).save(
-                                    os.path.join(sample_path, f"{base_count:05}.png"))
-                                base_count += 1
+                        for x_sample in x_samples:
+                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                            img = Image.fromarray(x_sample.astype(np.uint8))
+                            img = put_watermark(img, wm_encoder)
+                            img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+                            base_count += 1
                         all_samples.append(x_samples)
 
-                if not opt.skip_grid:
-                    # additionally, save as grid
-                    grid = torch.stack(all_samples, 0)
-                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                    grid = make_grid(grid, nrow=n_rows)
+                # additionally, save as grid
+                grid = torch.stack(all_samples, 0)
+                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                grid = make_grid(grid, nrow=n_rows)
 
-                    # to image
-                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                    grid_count += 1
+                # to image
+                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                grid = Image.fromarray(grid.astype(np.uint8))
+                grid = put_watermark(grid, wm_encoder)
+                grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                grid_count += 1
 
-                toc = time.time()
-
-    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
-          f" \nEnjoy.")
+    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
 
 
 if __name__ == "__main__":
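For context on the encode/decode pair in the hunk above: the init latent is first partially re-noised for `t_enc` steps, then denoised back under the prompt. A minimal sketch of how `t_enc` is conventionally derived in img2img-style scripts, assuming the usual `--strength` / `--ddim_steps` options (values here are invented for the example):

```
# Illustration only: how t_enc is typically derived in img2img-style scripts.
# `strength` and `ddim_steps` stand in for the script's --strength / --ddim_steps options.
strength = 0.75      # 0.0 keeps the init image unchanged, 1.0 ignores it entirely
ddim_steps = 50
t_enc = int(strength * ddim_steps)   # number of noising steps applied to the init latent
print(t_enc)  # 37 -> the latent is noised for 37 steps, then denoised back with the prompt
# z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
# samples = sampler.decode(z_enc, c, t_enc, ...)
```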
@@ -1,50 +1,33 @@
-import argparse, os, sys, glob
+import argparse, os
 import cv2
 import torch
 import numpy as np
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
-from imwatermark import WatermarkEncoder
 from itertools import islice
 from einops import rearrange
 from torchvision.utils import make_grid
-import time
-from lightning.pytorch import seed_everything
+try:
+    from lightning.pytorch import seed_everything
+except:
+    from pytorch_lightning import seed_everything
 from torch import autocast
-from contextlib import contextmanager, nullcontext
+from contextlib import nullcontext
+from imwatermark import WatermarkEncoder
 
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.dpm_solver import DPMSolverSampler
 
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-
-
-# load safety model
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+torch.set_grad_enabled(False)
 
 
 def chunk(it, size):
     it = iter(it)
     return iter(lambda: tuple(islice(it, size)), ())
 
 
-def numpy_to_pil(images):
-    """
-    Convert a numpy image or a batch of images to a PIL image.
-    """
-    if images.ndim == 3:
-        images = images[None, ...]
-    images = (images * 255).round().astype("uint8")
-    pil_images = [Image.fromarray(image) for image in images]
-
-    return pil_images
-
-
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
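The `chunk` helper kept in the hunk above groups a flat prompt list into fixed-size batches. A quick standalone illustration (prompt strings invented for the example):

```
from itertools import islice

def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

prompts = ["a cat", "a dog", "a fox", "a bird", "a fish"]
print(list(chunk(prompts, 2)))
# [('a cat', 'a dog'), ('a fox', 'a bird'), ('a fish',)]
```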
@@ -65,43 +48,13 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model
 
 
-def put_watermark(img, wm_encoder=None):
-    if wm_encoder is not None:
-        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        img = wm_encoder.encode(img, 'dwtDct')
-        img = Image.fromarray(img[:, :, ::-1])
-    return img
-
-
-def load_replacement(x):
-    try:
-        hwc = x.shape
-        y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
-        y = (np.array(y)/255.0).astype(x.dtype)
-        assert y.shape == x.shape
-        return y
-    except Exception:
-        return x
-
-
-def check_safety(x_image):
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-    assert x_checked_image.shape[0] == len(has_nsfw_concept)
-    for i in range(len(has_nsfw_concept)):
-        if has_nsfw_concept[i]:
-            x_checked_image[i] = load_replacement(x_checked_image[i])
-    return x_checked_image, has_nsfw_concept
-
-
-def main():
+def parse_args():
     parser = argparse.ArgumentParser()
-
     parser.add_argument(
         "--prompt",
         type=str,
         nargs="?",
-        default="a painting of a virus monster playing guitar",
+        default="a professional photograph of an astronaut riding a triceratops",
         help="the prompt to render"
     )
     parser.add_argument(
@@ -112,17 +65,7 @@ def main():
         default="outputs/txt2img-samples"
     )
     parser.add_argument(
-        "--skip_grid",
-        action='store_true',
-        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
-    )
-    parser.add_argument(
-        "--skip_save",
-        action='store_true',
-        help="do not save individual samples. For speed measurements.",
-    )
-    parser.add_argument(
-        "--ddim_steps",
+        "--steps",
         type=int,
         default=50,
         help="number of ddim sampling steps",
|
@ -133,14 +76,14 @@ def main():
|
||||||
help="use plms sampling",
|
help="use plms sampling",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--laion400m",
|
"--dpm",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="uses the LAION400M model",
|
help="use DPM (2) sampler",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fixed_code",
|
"--fixed_code",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="if enabled, uses the same starting code across samples ",
|
help="if enabled, uses the same starting code across all samples ",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ddim_eta",
|
"--ddim_eta",
|
||||||
|
@ -151,7 +94,7 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_iter",
|
"--n_iter",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=3,
|
||||||
help="sample this often",
|
help="sample this often",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -176,13 +119,13 @@ def main():
|
||||||
"--f",
|
"--f",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="downsampling factor",
|
help="downsampling factor, most often 8 or 16",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_samples",
|
"--n_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=3,
|
||||||
help="how many samples to produce for each given prompt. A.k.a. batch size",
|
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_rows",
|
"--n_rows",
|
||||||
|
@ -193,24 +136,23 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--scale",
|
"--scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=7.5,
|
default=9.0,
|
||||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--from-file",
|
"--from-file",
|
||||||
type=str,
|
type=str,
|
||||||
help="if specified, load prompts from this file",
|
help="if specified, load prompts from this file, separated by newlines",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config",
|
"--config",
|
||||||
type=str,
|
type=str,
|
||||||
default="configs/stable-diffusion/v1-inference.yaml",
|
default="configs/stable-diffusion/v2-inference.yaml",
|
||||||
help="path to config which constructs model",
|
help="path to config which constructs model",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ckpt",
|
"--ckpt",
|
||||||
type=str,
|
type=str,
|
||||||
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
|
||||||
help="path to checkpoint of model",
|
help="path to checkpoint of model",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -226,14 +168,25 @@ def main():
|
||||||
choices=["full", "autocast"],
|
choices=["full", "autocast"],
|
||||||
default="autocast"
|
default="autocast"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="repeat each prompt in file this often",
|
||||||
|
)
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
return opt
|
||||||
|
|
||||||
if opt.laion400m:
|
|
||||||
print("Falling back to LAION 400M model...")
|
|
||||||
opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
|
|
||||||
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
|
|
||||||
opt.outdir = "outputs/txt2img-samples-laion400m"
|
|
||||||
|
|
||||||
|
def put_watermark(img, wm_encoder=None):
|
||||||
|
if wm_encoder is not None:
|
||||||
|
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||||
|
img = wm_encoder.encode(img, 'dwtDct')
|
||||||
|
img = Image.fromarray(img[:, :, ::-1])
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def main(opt):
|
||||||
seed_everything(opt.seed)
|
seed_everything(opt.seed)
|
||||||
|
|
||||||
config = OmegaConf.load(f"{opt.config}")
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
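The `put_watermark` helper added above embeds the string "SDV2" into every saved image with the invisible-watermark library's DWT-DCT method. A minimal sketch of reading that mark back out of a saved file, assuming the imwatermark `WatermarkDecoder` API and a file name chosen only for the example:

```
import cv2
from imwatermark import WatermarkDecoder

# "SDV2" is 4 bytes, so ask the decoder for 32 bits (assumed imwatermark usage).
bgr = cv2.imread("outputs/txt2img-samples/samples/00000.png")  # hypothetical output path
decoder = WatermarkDecoder('bytes', 32)
watermark = decoder.decode(bgr, 'dwtDct')
print(watermark.decode('utf-8'))  # expected to print "SDV2" if the mark survived
```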
|
@ -244,6 +197,8 @@ def main():
|
||||||
|
|
||||||
if opt.plms:
|
if opt.plms:
|
||||||
sampler = PLMSSampler(model)
|
sampler = PLMSSampler(model)
|
||||||
|
elif opt.dpm:
|
||||||
|
sampler = DPMSolverSampler(model)
|
||||||
else:
|
else:
|
||||||
sampler = DDIMSampler(model)
|
sampler = DDIMSampler(model)
|
||||||
|
|
||||||
|
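As a usage note (not part of the diff): DPM-Solver typically reaches comparable quality in noticeably fewer steps than DDIM, so the new `--dpm` flag pairs naturally with a smaller `--steps` value. A purely illustrative, hypothetical helper equivalent to the branch above:

```
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

def select_sampler(opt, model):
    # Mirrors the flag handling in the hunk above; names chosen for the example.
    if opt.plms:
        return PLMSSampler(model)
    if opt.dpm:
        return DPMSolverSampler(model)
    return DDIMSampler(model)
```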
@@ -251,7 +206,7 @@ def main():
     outpath = opt.outdir
 
     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
-    wm = "StableDiffusionV1"
+    wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 
@@ -266,10 +221,12 @@ def main():
         print(f"reading prompts from {opt.from_file}")
         with open(opt.from_file, "r") as f:
             data = f.read().splitlines()
+            data = [p for p in data for i in range(opt.repeat)]
             data = list(chunk(data, batch_size))
 
     sample_path = os.path.join(outpath, "samples")
     os.makedirs(sample_path, exist_ok=True)
+    sample_count = 0
     base_count = len(os.listdir(sample_path))
     grid_count = len(os.listdir(outpath)) - 1
 
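The two lines added above expand each prompt read from `--from-file` exactly `--repeat` times before batching. A small illustration in plain Python (prompt strings and sizes invented for the example):

```
data = ["a red fox", "a blue bird"]
repeat = 2
batch_size = 3

data = [p for p in data for i in range(repeat)]
print(data)
# ['a red fox', 'a red fox', 'a blue bird', 'a blue bird']

# chunk() (defined earlier in the script) then groups them into batches of 3:
# [('a red fox', 'a red fox', 'a blue bird'), ('a blue bird',)]
```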
@@ -277,68 +234,59 @@ def main():
     if opt.fixed_code:
         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
 
-    precision_scope = autocast if opt.precision=="autocast" else nullcontext
-    with torch.no_grad():
-        with precision_scope("cuda"):
-            with model.ema_scope():
-                tic = time.time()
-                all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling"):
-                    for prompts in tqdm(data, desc="data"):
-                        uc = None
-                        if opt.scale != 1.0:
-                            uc = model.get_learned_conditioning(batch_size * [""])
-                        if isinstance(prompts, tuple):
-                            prompts = list(prompts)
-                        c = model.get_learned_conditioning(prompts)
-                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                         conditioning=c,
-                                                         batch_size=opt.n_samples,
-                                                         shape=shape,
-                                                         verbose=False,
-                                                         unconditional_guidance_scale=opt.scale,
-                                                         unconditional_conditioning=uc,
-                                                         eta=opt.ddim_eta,
-                                                         x_T=start_code)
-
-                        x_samples_ddim = model.decode_first_stage(samples_ddim)
-                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                        x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
-
-                        x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
-
-                        x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
-                        if not opt.skip_save:
-                            for x_sample in x_checked_image_torch:
-                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                                img = Image.fromarray(x_sample.astype(np.uint8))
-                                img = put_watermark(img, wm_encoder)
-                                img.save(os.path.join(sample_path, f"{base_count:05}.png"))
-                                base_count += 1
-
-                        if not opt.skip_grid:
-                            all_samples.append(x_checked_image_torch)
-
-                if not opt.skip_grid:
-                    # additionally, save as grid
-                    grid = torch.stack(all_samples, 0)
-                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                    grid = make_grid(grid, nrow=n_rows)
-
-                    # to image
-                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                    img = Image.fromarray(grid.astype(np.uint8))
-                    img = put_watermark(img, wm_encoder)
-                    img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                    grid_count += 1
-
-                toc = time.time()
+    precision_scope = autocast if opt.precision == "autocast" else nullcontext
+    with torch.no_grad(), \
+        precision_scope("cuda"), \
+        model.ema_scope():
+            all_samples = list()
+            for n in trange(opt.n_iter, desc="Sampling"):
+                for prompts in tqdm(data, desc="data"):
+                    uc = None
+                    if opt.scale != 1.0:
+                        uc = model.get_learned_conditioning(batch_size * [""])
+                    if isinstance(prompts, tuple):
+                        prompts = list(prompts)
+                    c = model.get_learned_conditioning(prompts)
+                    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                    samples, _ = sampler.sample(S=opt.steps,
+                                                conditioning=c,
+                                                batch_size=opt.n_samples,
+                                                shape=shape,
+                                                verbose=False,
+                                                unconditional_guidance_scale=opt.scale,
+                                                unconditional_conditioning=uc,
+                                                eta=opt.ddim_eta,
+                                                x_T=start_code)
+
+                    x_samples = model.decode_first_stage(samples)
+                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+                    for x_sample in x_samples:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        img = Image.fromarray(x_sample.astype(np.uint8))
+                        img = put_watermark(img, wm_encoder)
+                        img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+                        base_count += 1
+                        sample_count += 1
+
+                    all_samples.append(x_samples)
+
+            # additionally, save as grid
+            grid = torch.stack(all_samples, 0)
+            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+            grid = make_grid(grid, nrow=n_rows)
+
+            # to image
+            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+            grid = Image.fromarray(grid.astype(np.uint8))
+            grid = put_watermark(grid, wm_encoder)
+            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+            grid_count += 1
 
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f" \nEnjoy.")
 
 
 if __name__ == "__main__":
-    main()
+    opt = parse_args()
+    main(opt)
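For reference, the grid-saving logic above can be exercised in isolation. A small sketch with random tensors standing in for decoded samples (shapes and file name chosen only for the example):

```
import torch
from einops import rearrange
from torchvision.utils import make_grid
from PIL import Image

# two sampling iterations of three 64x64 RGB samples each, values already in [0, 1]
all_samples = [torch.rand(3, 3, 64, 64) for _ in range(2)]

grid = torch.stack(all_samples, 0)                   # (n, b, c, h, w)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')   # flatten to (n*b, c, h, w)
grid = make_grid(grid, nrow=3)                       # 3 images per row, like n_rows
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype('uint8')).save('grid-demo.png')
```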
@@ -1,4 +1,5 @@
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
+# HF_DATASETS_OFFLINE=1
+# TRANSFORMERS_OFFLINE=1
+# DIFFUSERS_OFFLINE=1
 
-python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
+python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml