mirror of https://github.com/hpcaitech/ColossalAI
[example] add stable diffuser (#1825)
parent
b1263d32ba
commit
6e9730d7ab
Binary file not shown.
After Width: | Height: | Size: 3.8 MiB |
|
@ -1,21 +1,21 @@
|
|||
# ColoDiffusion
|
||||
*ColoDiffusion is a Faster Train implementation of the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/)*
|
||||
*[ColoDiffusion](https://github.com/hpcaitech/ColoDiffusion) is a Faster Train implementation of the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/)*
|
||||
|
||||
We take advantage of Colosssal-AI to exploit multiple optimization strategies
|
||||
, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs.
|
||||
|
||||
|
||||
![](./Merged-0001.png)
|
||||
|
||||
![txt2img-stable2](assets/stable-samples/txt2img/merged-0006.png)
|
||||
[Stable Diffusion](#stable-diffusion-v1) is a latent text-to-image diffusion
|
||||
model.
|
||||
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
|
||||
Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
|
||||
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
|
||||
Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
|
||||
this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
|
||||
With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
|
||||
See [this section](#stable-diffusion-v1) below and the [model card](https://huggingface.co/CompVis/stable-diffusion).
|
||||
|
||||
|
||||
|
||||
## Requirements
|
||||
A suitable [conda](https://conda.io/) environment named `ldm` can be created
|
||||
and activated with:
|
||||
|
@ -31,7 +31,7 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
|
|||
conda install pytorch torchvision -c pytorch
|
||||
pip install transformers==4.19.2 diffusers invisible-watermark
|
||||
pip install -e .
|
||||
```
|
||||
```
|
||||
|
||||
### Install ColossalAI
|
||||
|
||||
|
@ -41,38 +41,51 @@ git checkout v0.1.10
|
|||
pip install .
|
||||
```
|
||||
|
||||
### Install colossalai lightning
|
||||
```
|
||||
git clone -b colossalai https://github.com/Fazziekey/lightning.git
|
||||
pip install .
|
||||
```
|
||||
|
||||
## Dataset
|
||||
The DataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/),
|
||||
you should the change the `data.file_path` in the `config/train_colossalai.yaml`
|
||||
|
||||
## Training
|
||||
|
||||
we provide the script `train.sh` to run the training task , and three Stategy in `configs`:`train_colossalai.yaml`, `train_ddp.yaml`, `train_deepspeed.yaml`
|
||||
|
||||
for example, you can run the training from colossalai by
|
||||
```
|
||||
python main.py --logdir /tmp -t --postfix test -b config/train_colossalai.yaml
|
||||
python main.py --logdir /tmp -t --postfix test -b config/train_colossalai.yaml
|
||||
```
|
||||
|
||||
- you can change the `--logdir` the save the log information and the last checkpoint
|
||||
|
||||
### Training config
|
||||
you can change the trainging config in the yaml file
|
||||
|
||||
- accelerator: acceleratortype, default 'gpu'
|
||||
- accelerator: acceleratortype, default 'gpu'
|
||||
- devices: device number used for training, default 4
|
||||
- max_epochs: max training epochs
|
||||
- precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai
|
||||
|
||||
|
||||
## Comments
|
||||
## Comments
|
||||
|
||||
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
|
||||
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
|
||||
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
|
||||
Thanks for open-sourcing!
|
||||
|
||||
- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
|
||||
- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
|
||||
|
||||
- the implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch)
|
||||
- the implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch)
|
||||
|
||||
## BibTeX
|
||||
|
||||
```
|
||||
@misc{rombach2021highresolution,
|
||||
title={High-Resolution Image Synthesis with Latent Diffusion Models},
|
||||
title={High-Resolution Image Synthesis with Latent Diffusion Models},
|
||||
author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
|
||||
year={2021},
|
||||
eprint={2112.10752},
|
||||
|
@ -86,3 +99,5 @@ Thanks for open-sourcing!
|
|||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
# Stable Diffusion v1 Model Card
|
||||
This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion).
|
||||
|
||||
## Model Details
|
||||
- **Developed by:** Robin Rombach, Patrick Esser
|
||||
- **Model type:** Diffusion-based text-to-image generation model
|
||||
- **Language(s):** English
|
||||
- **License:** [Proprietary](LICENSE)
|
||||
- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
|
||||
- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
|
||||
- **Cite as:**
|
||||
|
||||
@InProceedings{Rombach_2022_CVPR,
|
||||
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
||||
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
||||
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
month = {June},
|
||||
year = {2022},
|
||||
pages = {10684-10695}
|
||||
}
|
||||
|
||||
# Uses
|
||||
|
||||
## Direct Use
|
||||
The model is intended for research purposes only. Possible research areas and
|
||||
tasks include
|
||||
|
||||
- Safe deployment of models which have the potential to generate harmful content.
|
||||
- Probing and understanding the limitations and biases of generative models.
|
||||
- Generation of artworks and use in design and other artistic processes.
|
||||
- Applications in educational or creative tools.
|
||||
- Research on generative models.
|
||||
|
||||
Excluded uses are described below.
|
||||
|
||||
### Misuse, Malicious Use, and Out-of-Scope Use
|
||||
_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
|
||||
|
||||
The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
||||
|
||||
#### Out-of-Scope Use
|
||||
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
||||
|
||||
#### Misuse and Malicious Use
|
||||
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
||||
|
||||
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
||||
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
||||
- Impersonating individuals without their consent.
|
||||
- Sexual content without consent of the people who might see it.
|
||||
- Mis- and disinformation
|
||||
- Representations of egregious violence and gore
|
||||
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
||||
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
||||
|
||||
## Limitations and Bias
|
||||
|
||||
### Limitations
|
||||
|
||||
- The model does not achieve perfect photorealism
|
||||
- The model cannot render legible text
|
||||
- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
|
||||
- Faces and people in general may not be generated properly.
|
||||
- The model was trained mainly with English captions and will not work as well in other languages.
|
||||
- The autoencoding part of the model is lossy
|
||||
- The model was trained on a large-scale dataset
|
||||
[LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
|
||||
and is not fit for product use without additional safety mechanisms and
|
||||
considerations.
|
||||
- No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
|
||||
The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
|
||||
|
||||
### Bias
|
||||
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
||||
Stable Diffusion v1 was primarily trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
|
||||
which consists of images that are limited to English descriptions.
|
||||
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
|
||||
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
|
||||
ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
|
||||
Stable Diffusion v1 mirrors and exacerbates biases to such a degree that viewer discretion must be advised irrespective of the input or its intent.
|
||||
|
||||
|
||||
## Training
|
||||
|
||||
**Training Data**
|
||||
The model developers used the following dataset for training the model:
|
||||
|
||||
- LAION-5B and subsets thereof (see next section)
|
||||
|
||||
**Training Procedure**
|
||||
Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
|
||||
|
||||
- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
|
||||
- Text prompts are encoded through a ViT-L/14 text-encoder.
|
||||
- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
|
||||
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
|
||||
|
||||
We currently provide the following checkpoints:
|
||||
|
||||
- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
|
||||
194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
|
||||
- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
|
||||
515k steps at resolution `512x512` on [laion-aesthetics v2 5+](https://laion.ai/blog/laion-aesthetics/) (a subset of laion2B-en with estimated aesthetics score `> 5.0`, and additionally
|
||||
filtered to images with an original size `>= 512x512`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the [LAION-5B](https://laion.ai/blog/laion-5b/) metadata, the aesthetics score is estimated using the [LAION-Aesthetics Predictor V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
|
||||
- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
||||
- `sd-v1-4.ckpt`: Resumed from `sd-v1-2.ckpt`. 225k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
||||
|
||||
- **Hardware:** 32 x 8 x A100 GPUs
|
||||
- **Optimizer:** AdamW
|
||||
- **Gradient Accumulations**: 2
|
||||
- **Batch:** 32 x 8 x 2 x 4 = 2048
|
||||
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
|
||||
|
||||
## Evaluation Results
|
||||
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
|
||||
5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
|
||||
steps show the relative improvements of the checkpoints:
|
||||
|
||||
![pareto](assets/v1-variants-scores.jpg)
|
||||
|
||||
Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
|
||||
|
||||
## Environmental Impact
|
||||
|
||||
**Stable Diffusion v1** **Estimated Emissions**
|
||||
Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
|
||||
|
||||
- **Hardware Type:** A100 PCIe 40GB
|
||||
- **Hours used:** 150000
|
||||
- **Cloud Provider:** AWS
|
||||
- **Compute Region:** US-east
|
||||
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
|
||||
|
||||
## Citation
|
||||
@InProceedings{Rombach_2022_CVPR,
|
||||
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
||||
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
||||
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
month = {June},
|
||||
year = {2022},
|
||||
pages = {10684-10695}
|
||||
}
|
||||
|
||||
*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
|
|
@ -0,0 +1,116 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1.e-4 ]
|
||||
f_min: [ 1.e-10 ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
use_fp16: True
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 64
|
||||
wrap: False
|
||||
train:
|
||||
target: ldm.data.base.Txt2ImgIterableBaseDataset
|
||||
params:
|
||||
file_path: "/data/scratch/diffuser/laion_part0/"
|
||||
world_size: 1
|
||||
rank: 0
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accelerator: 'gpu'
|
||||
devices: 4
|
||||
log_gpu_memory: all
|
||||
max_epochs: 2
|
||||
precision: 16
|
||||
auto_select_gpus: False
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.ColossalAIStrategy
|
||||
params:
|
||||
use_chunk: False
|
||||
enable_distributed_storage: True,
|
||||
placement_policy: cuda
|
||||
force_outputs_fp32: False
|
||||
|
||||
log_every_n_steps: 2
|
||||
logger: True
|
||||
default_root_dir: "/tmp/diff_log/"
|
||||
profiler: pytorch
|
||||
|
||||
logger_config:
|
||||
wandb:
|
||||
target: pytorch_lightning.loggers.WandbLogger
|
||||
params:
|
||||
name: nowname
|
||||
save_dir: "/tmp/diff_log/"
|
||||
offline: opt.debug
|
||||
id: nowname
|
|
@ -0,0 +1,113 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 32
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 100 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1.e-4 ]
|
||||
f_min: [ 1.e-10 ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
use_fp16: True
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 64
|
||||
wrap: False
|
||||
train:
|
||||
target: ldm.data.base.Txt2ImgIterableBaseDataset
|
||||
params:
|
||||
file_path: "/data/scratch/diffuser/laion_part0/"
|
||||
world_size: 1
|
||||
rank: 0
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accelerator: 'gpu'
|
||||
devices: 4
|
||||
log_gpu_memory: all
|
||||
max_epochs: 2
|
||||
precision: 16
|
||||
auto_select_gpus: False
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.DDPStrategy
|
||||
params:
|
||||
find_unused_parameters: False
|
||||
log_every_n_steps: 2
|
||||
# max_steps: 6o
|
||||
logger: True
|
||||
default_root_dir: "/tmp/diff_log/"
|
||||
# profiler: pytorch
|
||||
|
||||
logger_config:
|
||||
wandb:
|
||||
target: pytorch_lightning.loggers.WandbLogger
|
||||
params:
|
||||
name: nowname
|
||||
save_dir: "/tmp/diff_log/"
|
||||
offline: opt.debug
|
||||
id: nowname
|
|
@ -0,0 +1,117 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 32
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1.e-4 ]
|
||||
f_min: [ 1.e-10 ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
use_fp16: True
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
use_fp16: True
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 4
|
||||
wrap: False
|
||||
train:
|
||||
target: ldm.data.base.Txt2ImgIterableBaseDataset
|
||||
params:
|
||||
file_path: "/data/scratch/diffuser/laion_part0/"
|
||||
world_size: 1
|
||||
rank: 0
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accelerator: 'gpu'
|
||||
devices: 4
|
||||
log_gpu_memory: all
|
||||
max_epochs: 2
|
||||
precision: 16
|
||||
auto_select_gpus: False
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.DeepSpeedStrategy
|
||||
params:
|
||||
stage: 2
|
||||
zero_optimization: True
|
||||
offload_optimizer: False
|
||||
offload_parameters: False
|
||||
log_every_n_steps: 2
|
||||
# max_steps: 6o
|
||||
logger: True
|
||||
default_root_dir: "/tmp/diff_log/"
|
||||
profiler: pytorch
|
||||
|
||||
logger_config:
|
||||
wandb:
|
||||
target: pytorch_lightning.loggers.WandbLogger
|
||||
params:
|
||||
name: nowname
|
||||
save_dir: logdir
|
||||
offline: opt.debug
|
||||
id: nowname
|
||||
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 32
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
check_nan_inf: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1.e-4 ]
|
||||
f_min: [ 1.e-10 ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
use_fp16: True
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 32
|
||||
wrap: False
|
||||
train:
|
||||
target: ldm.data.pokemon.PokemonDataset
|
||||
# params:
|
||||
# file_path: "/data/scratch/diffuser/laion_part0/"
|
||||
# world_size: 1
|
||||
# rank: 0
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accelerator: 'gpu'
|
||||
devices: 4
|
||||
log_gpu_memory: all
|
||||
max_epochs: 2
|
||||
precision: 16
|
||||
auto_select_gpus: False
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.ColossalAIStrategy
|
||||
params:
|
||||
use_chunk: False
|
||||
enable_distributed_storage: True,
|
||||
placement_policy: cuda
|
||||
force_outputs_fp32: False
|
||||
initial_scale: 65536
|
||||
min_scale: 1
|
||||
max_scale: 65536
|
||||
# max_scale: 4294967296
|
||||
|
||||
log_every_n_steps: 2
|
||||
logger: True
|
||||
default_root_dir: "/tmp/diff_log/"
|
||||
profiler: pytorch
|
||||
|
||||
logger_config:
|
||||
wandb:
|
||||
target: pytorch_lightning.loggers.WandbLogger
|
||||
params:
|
||||
name: nowname
|
||||
save_dir: "/tmp/diff_log/"
|
||||
offline: opt.debug
|
||||
id: nowname
|
|
@ -0,0 +1,33 @@
|
|||
name: ldm
|
||||
channels:
|
||||
- pytorch
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.12
|
||||
- pip=20.3
|
||||
- cudatoolkit=11.3
|
||||
- pytorch=1.11.0
|
||||
- torchvision=0.12.0
|
||||
- numpy=1.19.2
|
||||
- pip:
|
||||
- albumentations==0.4.3
|
||||
- diffusers
|
||||
- opencv-python==4.6.0.66
|
||||
- pudb==2019.2
|
||||
- invisible-watermark
|
||||
- imageio==2.9.0
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- pytorch-lightning==1.4.2
|
||||
- omegaconf==2.1.1
|
||||
- test-tube>=0.7.5
|
||||
- streamlit>=0.73.1
|
||||
- einops==0.3.0
|
||||
- torch-fidelity==0.3.0
|
||||
- transformers==4.19.2
|
||||
- torchmetrics==0.6.0
|
||||
- kornia==0.6
|
||||
- deepspeed==0.7.4
|
||||
- prefetch_generator
|
||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- -e .
|
|
@ -0,0 +1,98 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler:
|
||||
"""
|
||||
note: use with a base_lr of 1.0
|
||||
"""
|
||||
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.lr_start = lr_start
|
||||
self.lr_min = lr_min
|
||||
self.lr_max = lr_max
|
||||
self.lr_max_decay_steps = max_decay_steps
|
||||
self.last_lr = 0.
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||
if n < self.lr_warm_up_steps:
|
||||
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
||||
t = min(t, 1.0)
|
||||
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
||||
1 + np.cos(t * np.pi))
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n,**kwargs)
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler2:
|
||||
"""
|
||||
supports repeated iterations, configurable via lists
|
||||
note: use with a base_lr of 1.0.
|
||||
"""
|
||||
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
||||
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.f_start = f_start
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.cycle_lengths = cycle_lengths
|
||||
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
||||
self.last_f = 0.
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def find_in_interval(self, n):
|
||||
interval = 0
|
||||
for cl in self.cum_cycles[1:]:
|
||||
if n <= cl:
|
||||
return interval
|
||||
interval += 1
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}")
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
||||
t = min(t, 1.0)
|
||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
1 + np.cos(t * np.pi))
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}")
|
||||
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
||||
self.last_f = f
|
||||
return f
|
||||
|
|
@ -0,0 +1,544 @@
|
|||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_,_,ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
if self.global_step <= 4:
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
# https://github.com/pytorch/pytorch/issues/37142
|
||||
# try not to fool the heuristics
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train",
|
||||
predicted_indices=ind)
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log(f"val{suffix}/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor*self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr_d, betas=(0.5, 0.9))
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
from_pretrained: str=None
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
from diffusers.modeling_utils import load_state_dict
|
||||
if from_pretrained is not None:
|
||||
state_dict = load_state_dict(from_pretrained)
|
||||
self._load_pretrained_model(state_dict)
|
||||
|
||||
def _state_key_mapping(self, state_dict: dict):
|
||||
import re
|
||||
res_dict = {}
|
||||
key_list = state_dict.keys()
|
||||
key_str = " ".join(key_list)
|
||||
up_block_pattern = re.compile('upsamplers')
|
||||
p1 = re.compile('mid.block_[0-9]')
|
||||
p2 = re.compile('decoder.up.[0-9]')
|
||||
up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1)
|
||||
for key_, val_ in state_dict.items():
|
||||
key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\
|
||||
.replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\
|
||||
.replace('mid.attentions.0.key', 'mid.attn_1.k')\
|
||||
.replace('mid.attentions.0.query', 'mid.attn_1.q') \
|
||||
.replace('mid.attentions.0.value', 'mid.attn_1.v') \
|
||||
.replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \
|
||||
.replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\
|
||||
.replace('upsamplers.0', 'upsample')\
|
||||
.replace('downsamplers.0', 'downsample')\
|
||||
.replace('conv_shortcut', 'nin_shortcut')\
|
||||
.replace('conv_norm_out', 'norm_out')
|
||||
|
||||
mid_list = re.findall(p1, key_)
|
||||
if len(mid_list) != 0:
|
||||
mid_str = mid_list[0]
|
||||
mid_id = int(mid_str[-1]) + 1
|
||||
key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id))
|
||||
|
||||
up_list = re.findall(p2, key_)
|
||||
if len(up_list) != 0:
|
||||
up_str = up_list[0]
|
||||
up_id = up_blocks_count - 1 -int(up_str[-1])
|
||||
key_ = key_.replace(up_str, up_str[:-1] + str(up_id))
|
||||
res_dict[key_] = val_
|
||||
return res_dict
|
||||
|
||||
def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
|
||||
state_dict = self._state_key_mapping(state_dict)
|
||||
model_state_dict = self.state_dict()
|
||||
loaded_keys = [k for k in state_dict.keys()]
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
original_loaded_keys = loaded_keys
|
||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key in loaded_keys:
|
||||
model_key = checkpoint_key
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
if state_dict is not None:
|
||||
# Whole checkpoint
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
original_loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
)
|
||||
error_msgs = self._load_state_dict_into_model(state_dict)
|
||||
return missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
||||
|
||||
def _load_state_dict_into_model(self, state_dict):
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
state_dict = state_dict.copy()
|
||||
error_msgs = []
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module: torch.nn.Module, prefix=""):
|
||||
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
||||
module._load_from_state_dict(*args)
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(self)
|
||||
|
||||
return error_msgs
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
|
||||
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
|
||||
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
log["inputs"] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
|
@ -0,0 +1,267 @@
|
|||
import os
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn import functional as F
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from copy import deepcopy
|
||||
from einops import rearrange
|
||||
from glob import glob
|
||||
from natsort import natsorted
|
||||
|
||||
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
||||
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
||||
|
||||
__models__ = {
|
||||
'class_label': EncoderUNetModel,
|
||||
'segmentation': UNetModel
|
||||
}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
|
||||
def __init__(self,
|
||||
diffusion_path,
|
||||
num_classes,
|
||||
ckpt_path=None,
|
||||
pool='attention',
|
||||
label_key=None,
|
||||
diffusion_ckpt_path=None,
|
||||
scheduler_config=None,
|
||||
weight_decay=1.e-2,
|
||||
log_steps=10,
|
||||
monitor='val/loss',
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_classes = num_classes
|
||||
# get latest config of diffusion model
|
||||
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
|
||||
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
||||
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
||||
self.load_diffusion()
|
||||
|
||||
self.monitor = monitor
|
||||
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
||||
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
||||
self.log_steps = log_steps
|
||||
|
||||
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
|
||||
else self.diffusion_model.cond_stage_key
|
||||
|
||||
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
|
||||
|
||||
if self.label_key not in __models__:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.load_classifier(ckpt_path, pool)
|
||||
|
||||
self.scheduler_config = scheduler_config
|
||||
self.use_scheduler = self.scheduler_config is not None
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||
sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
if len(unexpected) > 0:
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def load_diffusion(self):
|
||||
model = instantiate_from_config(self.diffusion_config)
|
||||
self.diffusion_model = model.eval()
|
||||
self.diffusion_model.train = disabled_train
|
||||
for param in self.diffusion_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def load_classifier(self, ckpt_path, pool):
|
||||
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
||||
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
|
||||
model_config.out_channels = self.num_classes
|
||||
if self.label_key == 'class_label':
|
||||
model_config.pool = pool
|
||||
|
||||
self.model = __models__[self.label_key](**model_config)
|
||||
if ckpt_path is not None:
|
||||
print('#####################################################################')
|
||||
print(f'load from ckpt "{ckpt_path}"')
|
||||
print('#####################################################################')
|
||||
self.init_from_ckpt(ckpt_path)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_x_noisy(self, x, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x))
|
||||
continuous_sqrt_alpha_cumprod = None
|
||||
if self.diffusion_model.use_continuous_noise:
|
||||
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
||||
# todo: make sure t+1 is correct here
|
||||
|
||||
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
|
||||
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
|
||||
|
||||
def forward(self, x_noisy, t, *args, **kwargs):
|
||||
return self.model(x_noisy, t)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = rearrange(x, 'b h w c -> b c h w')
|
||||
x = x.to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def get_conditioning(self, batch, k=None):
|
||||
if k is None:
|
||||
k = self.label_key
|
||||
assert k is not None, 'Needs to provide label key'
|
||||
|
||||
targets = batch[k].to(self.device)
|
||||
|
||||
if self.label_key == 'segmentation':
|
||||
targets = rearrange(targets, 'b h w c -> b c h w')
|
||||
for down in range(self.numd):
|
||||
h, w = targets.shape[-2:]
|
||||
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
|
||||
|
||||
# targets = rearrange(targets,'b c h w -> b h w c')
|
||||
|
||||
return targets
|
||||
|
||||
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
||||
_, top_ks = torch.topk(logits, k, dim=1)
|
||||
if reduction == "mean":
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
||||
elif reduction == "none":
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
# save some memory
|
||||
self.diffusion_model.model.to('cpu')
|
||||
|
||||
@torch.no_grad()
|
||||
def write_logs(self, loss, logits, targets):
|
||||
log_prefix = 'train' if self.training else 'val'
|
||||
log = {}
|
||||
log[f"{log_prefix}/loss"] = loss.mean()
|
||||
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
||||
logits, targets, k=1, reduction="mean"
|
||||
)
|
||||
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
||||
logits, targets, k=5, reduction="mean"
|
||||
)
|
||||
|
||||
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
|
||||
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
||||
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
|
||||
lr = self.optimizers().param_groups[0]['lr']
|
||||
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
||||
|
||||
def shared_step(self, batch, t=None):
|
||||
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
|
||||
targets = self.get_conditioning(batch)
|
||||
if targets.dim() == 4:
|
||||
targets = targets.argmax(dim=1)
|
||||
if t is None:
|
||||
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
|
||||
else:
|
||||
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
|
||||
x_noisy = self.get_x_noisy(x, t)
|
||||
logits = self(x_noisy, t)
|
||||
|
||||
loss = F.cross_entropy(logits, targets, reduction='none')
|
||||
|
||||
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
||||
|
||||
loss = loss.mean()
|
||||
return loss, logits, x_noisy, targets
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss, *_ = self.shared_step(batch)
|
||||
return loss
|
||||
|
||||
def reset_noise_accs(self):
|
||||
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
|
||||
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
|
||||
|
||||
def on_validation_start(self):
|
||||
self.reset_noise_accs()
|
||||
|
||||
@torch.no_grad()
|
||||
def validation_step(self, batch, batch_idx):
|
||||
loss, *_ = self.shared_step(batch)
|
||||
|
||||
for t in self.noisy_acc:
|
||||
_, logits, _, targets = self.shared_step(batch, t)
|
||||
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
|
||||
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
|
||||
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
|
||||
|
||||
if self.use_scheduler:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
}]
|
||||
return [optimizer], scheduler
|
||||
|
||||
return optimizer
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, *args, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
||||
log['inputs'] = x
|
||||
|
||||
y = self.get_conditioning(batch)
|
||||
|
||||
if self.label_key == 'class_label':
|
||||
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
||||
log['labels'] = y
|
||||
|
||||
if ismap(y):
|
||||
log['labels'] = self.diffusion_model.to_rgb(y)
|
||||
|
||||
for step in range(self.log_steps):
|
||||
current_time = step * self.log_time_interval
|
||||
|
||||
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
||||
|
||||
log[f'inputs@t{current_time}'] = x_noisy
|
||||
|
||||
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
|
||||
pred = rearrange(pred, 'b h w c -> b c h w')
|
||||
|
||||
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
|
||||
|
||||
for key in log:
|
||||
log[key] = log[key][:N]
|
||||
|
||||
return log
|
|
@ -0,0 +1,240 @@
|
|||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
||||
extract_into_tensor
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
img, pred_x0 = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
return x_dec
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,236 @@
|
|||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
|
@ -0,0 +1,203 @@
|
|||
import importlib
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import abc
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
|
||||
import multiprocessing as mp
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
|
||||
from inspect import isfunction
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
|
||||
# create dummy dataset instance
|
||||
|
||||
# run prefetching
|
||||
if idx_to_fn:
|
||||
res = func(data, worker_id=idx)
|
||||
else:
|
||||
res = func(data)
|
||||
Q.put([idx, res])
|
||||
Q.put("Done")
|
||||
|
||||
|
||||
def parallel_data_prefetch(
|
||||
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
|
||||
):
|
||||
# if target_data_type not in ["ndarray", "list"]:
|
||||
# raise ValueError(
|
||||
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
|
||||
# )
|
||||
if isinstance(data, np.ndarray) and target_data_type == "list":
|
||||
raise ValueError("list expected but function got ndarray.")
|
||||
elif isinstance(data, abc.Iterable):
|
||||
if isinstance(data, dict):
|
||||
print(
|
||||
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
)
|
||||
data = list(data.values())
|
||||
if target_data_type == "ndarray":
|
||||
data = np.asarray(data)
|
||||
else:
|
||||
data = list(data)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
|
||||
)
|
||||
|
||||
if cpu_intensive:
|
||||
Q = mp.Queue(1000)
|
||||
proc = mp.Process
|
||||
else:
|
||||
Q = Queue(1000)
|
||||
proc = Thread
|
||||
# spawn processes
|
||||
if target_data_type == "ndarray":
|
||||
arguments = [
|
||||
[func, Q, part, i, use_worker_id]
|
||||
for i, part in enumerate(np.array_split(data, n_proc))
|
||||
]
|
||||
else:
|
||||
step = (
|
||||
int(len(data) / n_proc + 1)
|
||||
if len(data) % n_proc != 0
|
||||
else int(len(data) / n_proc)
|
||||
)
|
||||
arguments = [
|
||||
[func, Q, part, i, use_worker_id]
|
||||
for i, part in enumerate(
|
||||
[data[i: i + step] for i in range(0, len(data), step)]
|
||||
)
|
||||
]
|
||||
processes = []
|
||||
for i in range(n_proc):
|
||||
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
|
||||
processes += [p]
|
||||
|
||||
# start processes
|
||||
print(f"Start prefetching...")
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
gather_res = [[] for _ in range(n_proc)]
|
||||
try:
|
||||
for p in processes:
|
||||
p.start()
|
||||
|
||||
k = 0
|
||||
while k < n_proc:
|
||||
# get result
|
||||
res = Q.get()
|
||||
if res == "Done":
|
||||
k += 1
|
||||
else:
|
||||
gather_res[res[0]] = res[1]
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e)
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
|
||||
raise e
|
||||
finally:
|
||||
for p in processes:
|
||||
p.join()
|
||||
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
|
||||
if target_data_type == 'ndarray':
|
||||
if not isinstance(gather_res[0], np.ndarray):
|
||||
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
|
||||
|
||||
# order outputs
|
||||
return np.concatenate(gather_res, axis=0)
|
||||
elif target_data_type == 'list':
|
||||
out = []
|
||||
for r in gather_res:
|
||||
out.extend(r)
|
||||
return out
|
||||
else:
|
||||
return gather_res
|
|
@ -0,0 +1,830 @@
|
|||
import argparse, os, sys, datetime, glob, importlib, csv
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import torchvision
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from packaging import version
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.data import random_split, DataLoader, Dataset, Subset
|
||||
from functools import partial
|
||||
from PIL import Image
|
||||
# from pytorch_lightning.strategies.colossalai import ColossalAIStrategy
|
||||
# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from prefetch_generator import BackgroundGenerator
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from diffusers.models.unet_2d import UNet2DModel
|
||||
|
||||
from clip.model import Bottleneck
|
||||
from transformers.models.clip.modeling_clip import CLIPTextTransformer
|
||||
|
||||
from ldm.data.base import Txt2ImgIterableBaseDataset
|
||||
from ldm.util import instantiate_from_config
|
||||
import clip
|
||||
from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
|
||||
from ldm.modules.x_transformer import *
|
||||
from ldm.modules.encoders.modules import *
|
||||
from taming.modules.diffusionmodules.model import ResnetBlock
|
||||
from taming.modules.transformer.mingpt import *
|
||||
from taming.modules.transformer.permuter import *
|
||||
|
||||
|
||||
from ldm.modules.ema import LitEma
|
||||
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
||||
from ldm.models.autoencoder import AutoencoderKL
|
||||
from ldm.models.autoencoder import *
|
||||
from ldm.models.diffusion.ddim import *
|
||||
from ldm.modules.diffusionmodules.openaimodel import *
|
||||
from ldm.modules.diffusionmodules.model import *
|
||||
from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module
|
||||
from ldm.modules.attention import enable_flash_attention
|
||||
|
||||
class DataLoaderX(DataLoader):
|
||||
|
||||
def __iter__(self):
|
||||
return BackgroundGenerator(super().__iter__())
|
||||
|
||||
|
||||
def get_parser(**parser_kwargs):
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--name",
|
||||
type=str,
|
||||
const=True,
|
||||
default="",
|
||||
nargs="?",
|
||||
help="postfix for logdir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--resume",
|
||||
type=str,
|
||||
const=True,
|
||||
default="",
|
||||
nargs="?",
|
||||
help="resume from logdir or checkpoint in logdir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--base",
|
||||
nargs="*",
|
||||
metavar="base_config.yaml",
|
||||
help="paths to base configs. Loaded from left-to-right. "
|
||||
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
||||
default=list(),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--train",
|
||||
type=str2bool,
|
||||
const=True,
|
||||
default=False,
|
||||
nargs="?",
|
||||
help="train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-test",
|
||||
type=str2bool,
|
||||
const=True,
|
||||
default=False,
|
||||
nargs="?",
|
||||
help="disable test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--project",
|
||||
help="name of new or path to existing project"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--debug",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="enable post-mortem debugging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--seed",
|
||||
type=int,
|
||||
default=23,
|
||||
help="seed for seed_everything",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--postfix",
|
||||
type=str,
|
||||
default="",
|
||||
help="post-postfix for default name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--logdir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help="directory for logging dat shit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=True,
|
||||
help="scale base-lr by ngpu * batch_size * n_accumulate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fp16",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=True,
|
||||
help="whether to use fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash",
|
||||
type=str2bool,
|
||||
const=True,
|
||||
default=False,
|
||||
nargs="?",
|
||||
help="whether to use flash attention",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def nondefault_trainer_args(opt):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args([])
|
||||
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
|
||||
|
||||
|
||||
class WrappedDataset(Dataset):
|
||||
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
self.data = dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
||||
|
||||
def worker_init_fn(_):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
|
||||
dataset = worker_info.dataset
|
||||
worker_id = worker_info.id
|
||||
|
||||
if isinstance(dataset, Txt2ImgIterableBaseDataset):
|
||||
split_size = dataset.num_records // worker_info.num_workers
|
||||
# reset num_records to the true number to retain reliable length information
|
||||
dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
|
||||
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
|
||||
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
||||
else:
|
||||
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||
|
||||
|
||||
class DataModuleFromConfig(pl.LightningDataModule):
|
||||
def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
|
||||
wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
|
||||
shuffle_val_dataloader=False):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
self.dataset_configs = dict()
|
||||
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
||||
self.use_worker_init_fn = use_worker_init_fn
|
||||
if train is not None:
|
||||
self.dataset_configs["train"] = train
|
||||
self.train_dataloader = self._train_dataloader
|
||||
if validation is not None:
|
||||
self.dataset_configs["validation"] = validation
|
||||
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
|
||||
if test is not None:
|
||||
self.dataset_configs["test"] = test
|
||||
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
|
||||
if predict is not None:
|
||||
self.dataset_configs["predict"] = predict
|
||||
self.predict_dataloader = self._predict_dataloader
|
||||
self.wrap = wrap
|
||||
|
||||
def prepare_data(self):
|
||||
for data_cfg in self.dataset_configs.values():
|
||||
instantiate_from_config(data_cfg)
|
||||
|
||||
def setup(self, stage=None):
|
||||
self.datasets = dict(
|
||||
(k, instantiate_from_config(self.dataset_configs[k]))
|
||||
for k in self.dataset_configs)
|
||||
if self.wrap:
|
||||
for k in self.datasets:
|
||||
self.datasets[k] = WrappedDataset(self.datasets[k])
|
||||
|
||||
def _train_dataloader(self):
|
||||
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoaderX(self.datasets["train"], batch_size=self.batch_size,
|
||||
num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
|
||||
worker_init_fn=init_fn)
|
||||
|
||||
def _val_dataloader(self, shuffle=False):
|
||||
if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoaderX(self.datasets["validation"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
shuffle=shuffle)
|
||||
|
||||
def _test_dataloader(self, shuffle=False):
|
||||
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
|
||||
# do not shuffle dataloader for iterable dataset
|
||||
shuffle = shuffle and (not is_iterable_dataset)
|
||||
|
||||
return DataLoaderX(self.datasets["test"], batch_size=self.batch_size,
|
||||
num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
|
||||
|
||||
def _predict_dataloader(self, shuffle=False):
|
||||
if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size,
|
||||
num_workers=self.num_workers, worker_init_fn=init_fn)
|
||||
|
||||
|
||||
class SetupCallback(Callback):
|
||||
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
|
||||
super().__init__()
|
||||
self.resume = resume
|
||||
self.now = now
|
||||
self.logdir = logdir
|
||||
self.ckptdir = ckptdir
|
||||
self.cfgdir = cfgdir
|
||||
self.config = config
|
||||
self.lightning_config = lightning_config
|
||||
|
||||
def on_keyboard_interrupt(self, trainer, pl_module):
|
||||
if trainer.global_rank == 0:
|
||||
print("Summoning checkpoint.")
|
||||
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
# def on_pretrain_routine_start(self, trainer, pl_module):
|
||||
def on_fit_start(self, trainer, pl_module):
|
||||
if trainer.global_rank == 0:
|
||||
# Create logdirs and save configs
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
os.makedirs(self.ckptdir, exist_ok=True)
|
||||
os.makedirs(self.cfgdir, exist_ok=True)
|
||||
|
||||
if "callbacks" in self.lightning_config:
|
||||
if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
|
||||
os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
|
||||
print("Project config")
|
||||
print(OmegaConf.to_yaml(self.config))
|
||||
OmegaConf.save(self.config,
|
||||
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
|
||||
|
||||
print("Lightning config")
|
||||
print(OmegaConf.to_yaml(self.lightning_config))
|
||||
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
||||
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
||||
|
||||
else:
|
||||
# ModelCheckpoint callback created log directory --- remove it
|
||||
if not self.resume and os.path.exists(self.logdir):
|
||||
dst, name = os.path.split(self.logdir)
|
||||
dst = os.path.join(dst, "child_runs", name)
|
||||
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
||||
try:
|
||||
os.rename(self.logdir, dst)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
class ImageLogger(Callback):
|
||||
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
|
||||
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
|
||||
log_images_kwargs=None):
|
||||
super().__init__()
|
||||
self.rescale = rescale
|
||||
self.batch_freq = batch_frequency
|
||||
self.max_images = max_images
|
||||
self.logger_log_images = {
|
||||
pl.loggers.CSVLogger: self._testtube,
|
||||
}
|
||||
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
||||
if not increase_log_steps:
|
||||
self.log_steps = [self.batch_freq]
|
||||
self.clamp = clamp
|
||||
self.disabled = disabled
|
||||
self.log_on_batch_idx = log_on_batch_idx
|
||||
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
||||
self.log_first_step = log_first_step
|
||||
|
||||
@rank_zero_only
|
||||
def _testtube(self, pl_module, images, batch_idx, split):
|
||||
for k in images:
|
||||
grid = torchvision.utils.make_grid(images[k])
|
||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||
|
||||
tag = f"{split}/{k}"
|
||||
pl_module.logger.experiment.add_image(
|
||||
tag, grid,
|
||||
global_step=pl_module.global_step)
|
||||
|
||||
@rank_zero_only
|
||||
def log_local(self, save_dir, split, images,
|
||||
global_step, current_epoch, batch_idx):
|
||||
root = os.path.join(save_dir, "images", split)
|
||||
for k in images:
|
||||
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
||||
if self.rescale:
|
||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
||||
grid = grid.numpy()
|
||||
grid = (grid * 255).astype(np.uint8)
|
||||
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
|
||||
k,
|
||||
global_step,
|
||||
current_epoch,
|
||||
batch_idx)
|
||||
path = os.path.join(root, filename)
|
||||
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
||||
Image.fromarray(grid).save(path)
|
||||
|
||||
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
||||
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
||||
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
|
||||
hasattr(pl_module, "log_images") and
|
||||
callable(pl_module.log_images) and
|
||||
self.max_images > 0):
|
||||
logger = type(pl_module.logger)
|
||||
|
||||
is_train = pl_module.training
|
||||
if is_train:
|
||||
pl_module.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
|
||||
|
||||
for k in images:
|
||||
N = min(images[k].shape[0], self.max_images)
|
||||
images[k] = images[k][:N]
|
||||
if isinstance(images[k], torch.Tensor):
|
||||
images[k] = images[k].detach().cpu()
|
||||
if self.clamp:
|
||||
images[k] = torch.clamp(images[k], -1., 1.)
|
||||
|
||||
self.log_local(pl_module.logger.save_dir, split, images,
|
||||
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
||||
|
||||
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
|
||||
logger_log_images(pl_module, images, pl_module.global_step, split)
|
||||
|
||||
if is_train:
|
||||
pl_module.train()
|
||||
|
||||
def check_frequency(self, check_idx):
|
||||
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
|
||||
check_idx > 0 or self.log_first_step):
|
||||
try:
|
||||
self.log_steps.pop(0)
|
||||
except IndexError as e:
|
||||
print(e)
|
||||
pass
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
# if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
||||
# self.log_img(pl_module, batch, batch_idx, split="train")
|
||||
pass
|
||||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if not self.disabled and pl_module.global_step > 0:
|
||||
self.log_img(pl_module, batch, batch_idx, split="val")
|
||||
if hasattr(pl_module, 'calibrate_grad_norm'):
|
||||
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
|
||||
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
||||
|
||||
|
||||
class CUDACallback(Callback):
|
||||
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
||||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
rank_zero_info("Training is starting")
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
rank_zero_info("Training is ending")
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
# Reset the memory use counter
|
||||
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
|
||||
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
||||
max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20
|
||||
epoch_time = time.time() - self.start_time
|
||||
|
||||
try:
|
||||
max_memory = trainer.strategy.reduce(max_memory)
|
||||
epoch_time = trainer.strategy.reduce(epoch_time)
|
||||
|
||||
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
||||
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# custom parser to specify config files, train, test and debug mode,
|
||||
# postfix, resume.
|
||||
# `--key value` arguments are interpreted as arguments to the trainer.
|
||||
# `nested.key=value` arguments are interpreted as config parameters.
|
||||
# configs are merged from left-to-right followed by command line parameters.
|
||||
|
||||
# model:
|
||||
# base_learning_rate: float
|
||||
# target: path to lightning module
|
||||
# params:
|
||||
# key: value
|
||||
# data:
|
||||
# target: main.DataModuleFromConfig
|
||||
# params:
|
||||
# batch_size: int
|
||||
# wrap: bool
|
||||
# train:
|
||||
# target: path to train dataset
|
||||
# params:
|
||||
# key: value
|
||||
# validation:
|
||||
# target: path to validation dataset
|
||||
# params:
|
||||
# key: value
|
||||
# test:
|
||||
# target: path to test dataset
|
||||
# params:
|
||||
# key: value
|
||||
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
||||
# trainer:
|
||||
# additional arguments to trainer
|
||||
# logger:
|
||||
# logger to instantiate
|
||||
# modelcheckpoint:
|
||||
# modelcheckpoint to instantiate
|
||||
# callbacks:
|
||||
# callback1:
|
||||
# target: importpath
|
||||
# params:
|
||||
# key: value
|
||||
|
||||
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
||||
|
||||
# add cwd for convenience and to make classes in this file available when
|
||||
# running as `python main.py`
|
||||
# (in particular `main.DataModuleFromConfig`)
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
parser = get_parser()
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
|
||||
opt, unknown = parser.parse_known_args()
|
||||
if opt.name and opt.resume:
|
||||
raise ValueError(
|
||||
"-n/--name and -r/--resume cannot be specified both."
|
||||
"If you want to resume training in a new log folder, "
|
||||
"use -n/--name in combination with --resume_from_checkpoint"
|
||||
)
|
||||
if opt.flash:
|
||||
enable_flash_attention()
|
||||
if opt.resume:
|
||||
if not os.path.exists(opt.resume):
|
||||
raise ValueError("Cannot find {}".format(opt.resume))
|
||||
if os.path.isfile(opt.resume):
|
||||
paths = opt.resume.split("/")
|
||||
# idx = len(paths)-paths[::-1].index("logs")+1
|
||||
# logdir = "/".join(paths[:idx])
|
||||
logdir = "/".join(paths[:-2])
|
||||
ckpt = opt.resume
|
||||
else:
|
||||
assert os.path.isdir(opt.resume), opt.resume
|
||||
logdir = opt.resume.rstrip("/")
|
||||
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
||||
|
||||
opt.resume_from_checkpoint = ckpt
|
||||
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
||||
opt.base = base_configs + opt.base
|
||||
_tmp = logdir.split("/")
|
||||
nowname = _tmp[-1]
|
||||
else:
|
||||
if opt.name:
|
||||
name = "_" + opt.name
|
||||
elif opt.base:
|
||||
cfg_fname = os.path.split(opt.base[0])[-1]
|
||||
cfg_name = os.path.splitext(cfg_fname)[0]
|
||||
name = "_" + cfg_name
|
||||
else:
|
||||
name = ""
|
||||
nowname = now + name + opt.postfix
|
||||
logdir = os.path.join(opt.logdir, nowname)
|
||||
|
||||
ckptdir = os.path.join(logdir, "checkpoints")
|
||||
cfgdir = os.path.join(logdir, "configs")
|
||||
seed_everything(opt.seed)
|
||||
|
||||
try:
|
||||
# init and save configs
|
||||
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
||||
cli = OmegaConf.from_dotlist(unknown)
|
||||
config = OmegaConf.merge(*configs, cli)
|
||||
lightning_config = config.pop("lightning", OmegaConf.create())
|
||||
# merge trainer cli with config
|
||||
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
||||
|
||||
for k in nondefault_trainer_args(opt):
|
||||
trainer_config[k] = getattr(opt, k)
|
||||
|
||||
print(trainer_config)
|
||||
if not trainer_config["accelerator"] == "gpu":
|
||||
del trainer_config["accelerator"]
|
||||
cpu = True
|
||||
print("Running on CPU")
|
||||
else:
|
||||
cpu = False
|
||||
print("Running on GPU")
|
||||
trainer_opt = argparse.Namespace(**trainer_config)
|
||||
lightning_config.trainer = trainer_config
|
||||
|
||||
# model
|
||||
use_fp16 = trainer_config.get("precision", 32) == 16
|
||||
if use_fp16:
|
||||
config.model["params"].update({"use_fp16": True})
|
||||
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
|
||||
else:
|
||||
config.model["params"].update({"use_fp16": False})
|
||||
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
|
||||
|
||||
model = instantiate_from_config(config.model)
|
||||
# trainer and callbacks
|
||||
trainer_kwargs = dict()
|
||||
|
||||
# config the logger
|
||||
# default logger configs
|
||||
default_logger_cfgs = {
|
||||
"wandb": {
|
||||
"target": "pytorch_lightning.loggers.WandbLogger",
|
||||
"params": {
|
||||
"name": nowname,
|
||||
"save_dir": logdir,
|
||||
"offline": opt.debug,
|
||||
"id": nowname,
|
||||
}
|
||||
},
|
||||
"tensorboard":{
|
||||
"target": "pytorch_lightning.loggers.TensorBoardLogger",
|
||||
"params":{
|
||||
"save_dir": logdir,
|
||||
"name": "diff_tb",
|
||||
"log_graph": True
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default_logger_cfg = default_logger_cfgs["tensorboard"]
|
||||
if "logger" in lightning_config:
|
||||
logger_cfg = lightning_config.logger
|
||||
else:
|
||||
logger_cfg = default_logger_cfg
|
||||
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
||||
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
||||
|
||||
# config the strategy, defualt is ddp
|
||||
if "strategy" in trainer_config:
|
||||
strategy_cfg = trainer_config["strategy"]
|
||||
print("Using strategy: {}".format(strategy_cfg["target"]))
|
||||
else:
|
||||
strategy_cfg = {
|
||||
"target": "pytorch_lightning.strategies.DDPStrategy",
|
||||
"params": {
|
||||
"find_unused_parameters": False
|
||||
}
|
||||
}
|
||||
print("Using strategy: DDPStrategy")
|
||||
|
||||
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
|
||||
|
||||
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
||||
# specify which metric is used to determine best models
|
||||
default_modelckpt_cfg = {
|
||||
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
||||
"params": {
|
||||
"dirpath": ckptdir,
|
||||
"filename": "{epoch:06}",
|
||||
"verbose": True,
|
||||
"save_last": True,
|
||||
}
|
||||
}
|
||||
if hasattr(model, "monitor"):
|
||||
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
||||
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
||||
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
||||
|
||||
if "modelcheckpoint" in lightning_config:
|
||||
modelckpt_cfg = lightning_config.modelcheckpoint
|
||||
else:
|
||||
modelckpt_cfg = OmegaConf.create()
|
||||
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
||||
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
|
||||
if version.parse(pl.__version__) < version.parse('1.4.0'):
|
||||
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
||||
|
||||
# add callback which sets up log directory
|
||||
default_callbacks_cfg = {
|
||||
"setup_callback": {
|
||||
"target": "main.SetupCallback",
|
||||
"params": {
|
||||
"resume": opt.resume,
|
||||
"now": now,
|
||||
"logdir": logdir,
|
||||
"ckptdir": ckptdir,
|
||||
"cfgdir": cfgdir,
|
||||
"config": config,
|
||||
"lightning_config": lightning_config,
|
||||
}
|
||||
},
|
||||
"image_logger": {
|
||||
"target": "main.ImageLogger",
|
||||
"params": {
|
||||
"batch_frequency": 750,
|
||||
"max_images": 4,
|
||||
"clamp": True
|
||||
}
|
||||
},
|
||||
"learning_rate_logger": {
|
||||
"target": "main.LearningRateMonitor",
|
||||
"params": {
|
||||
"logging_interval": "step",
|
||||
# "log_momentum": True
|
||||
}
|
||||
},
|
||||
"cuda_callback": {
|
||||
"target": "main.CUDACallback"
|
||||
},
|
||||
}
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
|
||||
|
||||
if "callbacks" in lightning_config:
|
||||
callbacks_cfg = lightning_config.callbacks
|
||||
else:
|
||||
callbacks_cfg = OmegaConf.create()
|
||||
|
||||
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
|
||||
print(
|
||||
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
|
||||
default_metrics_over_trainsteps_ckpt_dict = {
|
||||
'metrics_over_trainsteps_checkpoint':
|
||||
{"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
|
||||
'params': {
|
||||
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
|
||||
"filename": "{epoch:06}-{step:09}",
|
||||
"verbose": True,
|
||||
'save_top_k': -1,
|
||||
'every_n_train_steps': 10000,
|
||||
'save_weights_only': True
|
||||
}
|
||||
}
|
||||
}
|
||||
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
|
||||
|
||||
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
||||
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
|
||||
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
|
||||
elif 'ignore_keys_callback' in callbacks_cfg:
|
||||
del callbacks_cfg['ignore_keys_callback']
|
||||
|
||||
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
||||
|
||||
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
||||
trainer.logdir = logdir ###
|
||||
|
||||
# data
|
||||
data = instantiate_from_config(config.data)
|
||||
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
||||
# calling these ourselves should not be necessary but it is.
|
||||
# lightning still takes care of proper multiprocessing though
|
||||
data.prepare_data()
|
||||
data.setup()
|
||||
print("#### Data #####")
|
||||
for k in data.datasets:
|
||||
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
|
||||
|
||||
# configure learning rate
|
||||
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
||||
if not cpu:
|
||||
ngpu = trainer_config["devices"]
|
||||
else:
|
||||
ngpu = 1
|
||||
if 'accumulate_grad_batches' in lightning_config.trainer:
|
||||
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
||||
else:
|
||||
accumulate_grad_batches = 1
|
||||
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
||||
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
||||
if opt.scale_lr:
|
||||
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
||||
print(
|
||||
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
||||
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
|
||||
else:
|
||||
model.learning_rate = base_lr
|
||||
print("++++ NOT USING LR SCALING ++++")
|
||||
print(f"Setting learning rate to {model.learning_rate:.2e}")
|
||||
|
||||
|
||||
# allow checkpointing via USR1
|
||||
def melk(*args, **kwargs):
|
||||
# run all checkpoint hooks
|
||||
if trainer.global_rank == 0:
|
||||
print("Summoning checkpoint.")
|
||||
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
|
||||
def divein(*args, **kwargs):
|
||||
if trainer.global_rank == 0:
|
||||
import pudb;
|
||||
pudb.set_trace()
|
||||
|
||||
|
||||
import signal
|
||||
|
||||
signal.signal(signal.SIGUSR1, melk)
|
||||
signal.signal(signal.SIGUSR2, divein)
|
||||
|
||||
# run
|
||||
if opt.train:
|
||||
try:
|
||||
for name, m in model.named_parameters():
|
||||
print(name)
|
||||
trainer.fit(model, data)
|
||||
except Exception:
|
||||
melk()
|
||||
raise
|
||||
# if not opt.no_test and not trainer.interrupted:
|
||||
# trainer.test(model, data)
|
||||
except Exception:
|
||||
if opt.debug and trainer.global_rank == 0:
|
||||
try:
|
||||
import pudb as debugger
|
||||
except ImportError:
|
||||
import pdb as debugger
|
||||
debugger.post_mortem()
|
||||
raise
|
||||
finally:
|
||||
# move newly created debug project to debug_runs
|
||||
if opt.debug and not opt.resume and trainer.global_rank == 0:
|
||||
dst, name = os.path.split(logdir)
|
||||
dst = os.path.join(dst, "debug_runs", name)
|
||||
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
||||
os.rename(logdir, dst)
|
||||
if trainer.global_rank == 0:
|
||||
print(trainer.profiler.summary())
|
|
@ -0,0 +1,21 @@
|
|||
albumentations==0.4.3
|
||||
diffusers
|
||||
opencv-python==4.1.2.30
|
||||
pudb==2019.2
|
||||
invisible-watermark
|
||||
imageio==2.9.0
|
||||
imageio-ffmpeg==0.4.2
|
||||
omegaconf==2.1.1
|
||||
test-tube>=0.7.5
|
||||
streamlit>=0.73.1
|
||||
einops==0.3.0
|
||||
torch-fidelity==0.3.0
|
||||
transformers==4.19.2
|
||||
torchmetrics==0.6.0
|
||||
kornia==0.6
|
||||
deepspeed==0.7.4
|
||||
opencv-python==4.6.0.66
|
||||
prefetch_generator
|
||||
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
-e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
-e .
|
|
@ -0,0 +1,41 @@
|
|||
#!/bin/bash
|
||||
wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
|
||||
wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
|
||||
wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
|
||||
wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
|
||||
wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
|
||||
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
|
||||
wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
|
||||
wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
|
||||
wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
|
||||
|
||||
|
||||
|
||||
cd models/first_stage_models/kl-f4
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../kl-f8
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../kl-f16
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../kl-f32
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f4
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f4-noattn
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f8
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f8-n256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f16
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../..
|
|
@ -0,0 +1,49 @@
|
|||
#!/bin/bash
|
||||
wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
|
||||
wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
|
||||
wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
|
||||
wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
|
||||
wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
|
||||
wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
|
||||
wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
|
||||
wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
|
||||
wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
|
||||
wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
|
||||
wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
|
||||
|
||||
|
||||
|
||||
cd models/ldm/celeba256
|
||||
unzip -o celeba-256.zip
|
||||
|
||||
cd ../ffhq256
|
||||
unzip -o ffhq-256.zip
|
||||
|
||||
cd ../lsun_churches256
|
||||
unzip -o lsun_churches-256.zip
|
||||
|
||||
cd ../lsun_beds256
|
||||
unzip -o lsun_beds-256.zip
|
||||
|
||||
cd ../text2img256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../cin256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../semantic_synthesis512
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../semantic_synthesis256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../bsr_sr
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../layout2img-openimages256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../inpainting_big
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../..
|
|
@ -0,0 +1,293 @@
|
|||
"""make variations of input image"""
|
||||
|
||||
import argparse, os, sys, glob
|
||||
import PIL
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
from torch import autocast
|
||||
from contextlib import nullcontext
|
||||
import time
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_img(path):
|
||||
image = Image.open(path).convert("RGB")
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.*image - 1.
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--init-img",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="path to the input image"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
default="outputs/img2img-samples"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_save",
|
||||
action='store_true',
|
||||
help="do not save indiviual samples. For speed measurements.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--plms",
|
||||
action='store_true',
|
||||
help="use plms sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fixed_code",
|
||||
action='store_true',
|
||||
help="if enabled, uses the same starting code across all samples ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iter",
|
||||
type=int,
|
||||
default=1,
|
||||
help="sample this often",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--C",
|
||||
type=int,
|
||||
default=4,
|
||||
help="latent channels",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--f",
|
||||
type=int,
|
||||
default=8,
|
||||
help="downsampling factor, most often 8 or 16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_samples",
|
||||
type=int,
|
||||
default=2,
|
||||
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
help="if specified, load prompts from this file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="configs/stable-diffusion/v1-inference.yaml",
|
||||
help="path to config which constructs model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||
help="path to checkpoint of model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
help="evaluate at this precision",
|
||||
choices=["full", "autocast"],
|
||||
default="autocast"
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
seed_everything(opt.seed)
|
||||
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
|
||||
if opt.plms:
|
||||
raise NotImplementedError("PLMS sampler not (yet) supported")
|
||||
sampler = PLMSSampler(model)
|
||||
else:
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = list(chunk(data, batch_size))
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
assert os.path.isfile(opt.init_img)
|
||||
init_image = load_img(opt.init_img).to(device)
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
|
||||
|
||||
assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(opt.strength * opt.ddim_steps)
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with model.ema_scope():
|
||||
tic = time.time()
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
|
||||
# decode it
|
||||
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,)
|
||||
|
||||
x_samples = model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if not opt.skip_save:
|
||||
for x_sample in x_samples:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||
base_count += 1
|
||||
all_samples.append(x_samples)
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
toc = time.time()
|
||||
|
||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
||||
f" \nEnjoy.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,98 @@
|
|||
import argparse, os, sys, glob
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
from main import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
|
||||
def make_batch(image, mask, device):
|
||||
image = np.array(Image.open(image).convert("RGB"))
|
||||
image = image.astype(np.float32)/255.0
|
||||
image = image[None].transpose(0,3,1,2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask = np.array(Image.open(mask).convert("L"))
|
||||
mask = mask.astype(np.float32)/255.0
|
||||
mask = mask[None,None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = (1-mask)*image
|
||||
|
||||
batch = {"image": image, "mask": mask, "masked_image": masked_image}
|
||||
for k in batch:
|
||||
batch[k] = batch[k].to(device=device)
|
||||
batch[k] = batch[k]*2.0-1.0
|
||||
return batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--indir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
|
||||
images = [x.replace("_mask.png", ".png") for x in masks]
|
||||
print(f"Found {len(masks)} inputs.")
|
||||
|
||||
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
|
||||
model = instantiate_from_config(config.model)
|
||||
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
|
||||
strict=False)
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
with torch.no_grad():
|
||||
with model.ema_scope():
|
||||
for image, mask in tqdm(zip(images, masks)):
|
||||
outpath = os.path.join(opt.outdir, os.path.split(image)[1])
|
||||
batch = make_batch(image, mask, device=device)
|
||||
|
||||
# encode masked image and concat downsampled mask
|
||||
c = model.cond_stage_model.encode(batch["masked_image"])
|
||||
cc = torch.nn.functional.interpolate(batch["mask"],
|
||||
size=c.shape[-2:])
|
||||
c = torch.cat((c, cc), dim=1)
|
||||
|
||||
shape = (c.shape[1]-1,)+c.shape[2:]
|
||||
samples_ddim, _ = sampler.sample(S=opt.steps,
|
||||
conditioning=c,
|
||||
batch_size=c.shape[0],
|
||||
shape=shape,
|
||||
verbose=False)
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
|
||||
image = torch.clamp((batch["image"]+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
mask = torch.clamp((batch["mask"]+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
|
||||
inpainted = (1-mask)*image+mask*predicted_image
|
||||
inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
|
||||
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
|
|
@ -0,0 +1,398 @@
|
|||
import argparse, os, sys, glob
|
||||
import clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
import scann
|
||||
import time
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from ldm.util import instantiate_from_config, parallel_data_prefetch
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
|
||||
|
||||
DATABASES = [
|
||||
"openimages",
|
||||
"artbench-art_nouveau",
|
||||
"artbench-baroque",
|
||||
"artbench-expressionism",
|
||||
"artbench-impressionism",
|
||||
"artbench-post_impressionism",
|
||||
"artbench-realism",
|
||||
"artbench-romanticism",
|
||||
"artbench-renaissance",
|
||||
"artbench-surrealism",
|
||||
"artbench-ukiyo_e",
|
||||
]
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class Searcher(object):
|
||||
def __init__(self, database, retriever_version='ViT-L/14'):
|
||||
assert database in DATABASES
|
||||
# self.database = self.load_database(database)
|
||||
self.database_name = database
|
||||
self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
|
||||
self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
|
||||
self.retriever = self.load_retriever(version=retriever_version)
|
||||
self.database = {'embedding': [],
|
||||
'img_id': [],
|
||||
'patch_coords': []}
|
||||
self.load_database()
|
||||
self.load_searcher()
|
||||
|
||||
def train_searcher(self, k,
|
||||
metric='dot_product',
|
||||
searcher_savedir=None):
|
||||
|
||||
print('Start training searcher')
|
||||
searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
|
||||
np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
|
||||
k, metric)
|
||||
self.searcher = searcher.score_brute_force().build()
|
||||
print('Finish training searcher')
|
||||
|
||||
if searcher_savedir is not None:
|
||||
print(f'Save trained searcher under "{searcher_savedir}"')
|
||||
os.makedirs(searcher_savedir, exist_ok=True)
|
||||
self.searcher.serialize(searcher_savedir)
|
||||
|
||||
def load_single_file(self, saved_embeddings):
|
||||
compressed = np.load(saved_embeddings)
|
||||
self.database = {key: compressed[key] for key in compressed.files}
|
||||
print('Finished loading of clip embeddings.')
|
||||
|
||||
def load_multi_files(self, data_archive):
|
||||
out_data = {key: [] for key in self.database}
|
||||
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
|
||||
for key in d.files:
|
||||
out_data[key].append(d[key])
|
||||
|
||||
return out_data
|
||||
|
||||
def load_database(self):
|
||||
|
||||
print(f'Load saved patch embedding from "{self.database_path}"')
|
||||
file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
|
||||
|
||||
if len(file_content) == 1:
|
||||
self.load_single_file(file_content[0])
|
||||
elif len(file_content) > 1:
|
||||
data = [np.load(f) for f in file_content]
|
||||
prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
|
||||
n_proc=min(len(data), cpu_count()), target_data_type='dict')
|
||||
|
||||
self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
|
||||
self.database}
|
||||
else:
|
||||
raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
|
||||
|
||||
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
|
||||
|
||||
def load_retriever(self, version='ViT-L/14', ):
|
||||
model = FrozenClipImageEmbedder(model=version)
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def load_searcher(self):
|
||||
print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
|
||||
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
|
||||
print('Finished loading searcher.')
|
||||
|
||||
def search(self, x, k):
|
||||
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
|
||||
self.train_searcher(k) # quickly fit searcher on the fly for small databases
|
||||
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.detach().cpu().numpy()
|
||||
if len(x.shape) == 3:
|
||||
x = x[:, 0]
|
||||
query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
|
||||
|
||||
start = time.time()
|
||||
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
|
||||
end = time.time()
|
||||
|
||||
out_embeddings = self.database['embedding'][nns]
|
||||
out_img_ids = self.database['img_id'][nns]
|
||||
out_pc = self.database['patch_coords'][nns]
|
||||
|
||||
out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
|
||||
'img_ids': out_img_ids,
|
||||
'patch_coords': out_pc,
|
||||
'queries': x,
|
||||
'exec_time': end - start,
|
||||
'nns': nns,
|
||||
'q_embeddings': query_embeddings}
|
||||
|
||||
return out
|
||||
|
||||
def __call__(self, x, n):
|
||||
return self.search(x, n)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
|
||||
# TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
default="outputs/txt2img-samples"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_repeat",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of repeats in CLIP latent space",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--plms",
|
||||
action='store_true',
|
||||
help="use plms sampling",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iter",
|
||||
type=int,
|
||||
default=1,
|
||||
help="sample this often",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--H",
|
||||
type=int,
|
||||
default=768,
|
||||
help="image height, in pixel space",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--W",
|
||||
type=int,
|
||||
default=768,
|
||||
help="image width, in pixel space",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_samples",
|
||||
type=int,
|
||||
default=3,
|
||||
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
help="if specified, load prompts from this file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="configs/retrieval-augmented-diffusion/768x768.yaml",
|
||||
help="path to config which constructs model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="models/rdm/rdm768x768/model.ckpt",
|
||||
help="path to checkpoint of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--clip_type",
|
||||
type=str,
|
||||
default="ViT-L/14",
|
||||
help="which CLIP model to use for retrieval and NN encoding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--database",
|
||||
type=str,
|
||||
default='artbench-surrealism',
|
||||
choices=DATABASES,
|
||||
help="The database used for the search, only applied when --use_neighbors=True",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_neighbors",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Include neighbors in addition to text prompt for conditioning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--knn",
|
||||
default=10,
|
||||
type=int,
|
||||
help="The number of included neighbors, only applied when --use_neighbors=True",
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
|
||||
clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
|
||||
|
||||
if opt.plms:
|
||||
sampler = PLMSSampler(model)
|
||||
else:
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = list(chunk(data, batch_size))
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
print(f"sampling scale for cfg is {opt.scale:.2f}")
|
||||
|
||||
searcher = None
|
||||
if opt.use_neighbors:
|
||||
searcher = Searcher(opt.database)
|
||||
|
||||
with torch.no_grad():
|
||||
with model.ema_scope():
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
all_samples = list()
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
print("sampling prompts:", prompts)
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = clip_text_encoder.encode(prompts)
|
||||
uc = None
|
||||
if searcher is not None:
|
||||
nn_dict = searcher(c, opt.knn)
|
||||
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
|
||||
if opt.scale != 1.0:
|
||||
uc = torch.zeros_like(c)
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
|
||||
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||
conditioning=c,
|
||||
batch_size=c.shape[0],
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt.ddim_eta,
|
||||
)
|
||||
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
for x_sample in x_samples_ddim:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||
base_count += 1
|
||||
all_samples.append(x_samples_ddim)
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|
|
@ -0,0 +1,313 @@
|
|||
import argparse, os, sys, glob, datetime, yaml
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
rescale = lambda x: (x + 1.) / 2.
|
||||
|
||||
def custom_to_pil(x):
|
||||
x = x.detach().cpu()
|
||||
x = torch.clamp(x, -1., 1.)
|
||||
x = (x + 1.) / 2.
|
||||
x = x.permute(1, 2, 0).numpy()
|
||||
x = (255 * x).astype(np.uint8)
|
||||
x = Image.fromarray(x)
|
||||
if not x.mode == "RGB":
|
||||
x = x.convert("RGB")
|
||||
return x
|
||||
|
||||
|
||||
def custom_to_np(x):
|
||||
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
|
||||
sample = x.detach().cpu()
|
||||
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
|
||||
sample = sample.permute(0, 2, 3, 1)
|
||||
sample = sample.contiguous()
|
||||
return sample
|
||||
|
||||
|
||||
def logs2pil(logs, keys=["sample"]):
|
||||
imgs = dict()
|
||||
for k in logs:
|
||||
try:
|
||||
if len(logs[k].shape) == 4:
|
||||
img = custom_to_pil(logs[k][0, ...])
|
||||
elif len(logs[k].shape) == 3:
|
||||
img = custom_to_pil(logs[k])
|
||||
else:
|
||||
print(f"Unknown format for key {k}. ")
|
||||
img = None
|
||||
except:
|
||||
img = None
|
||||
imgs[k] = img
|
||||
return imgs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convsample(model, shape, return_intermediates=True,
|
||||
verbose=True,
|
||||
make_prog_row=False):
|
||||
|
||||
|
||||
if not make_prog_row:
|
||||
return model.p_sample_loop(None, shape,
|
||||
return_intermediates=return_intermediates, verbose=verbose)
|
||||
else:
|
||||
return model.progressive_denoising(
|
||||
None, shape, verbose=True
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convsample_ddim(model, steps, shape, eta=1.0
|
||||
):
|
||||
ddim = DDIMSampler(model)
|
||||
bs = shape[0]
|
||||
shape = shape[1:]
|
||||
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
|
||||
|
||||
|
||||
log = dict()
|
||||
|
||||
shape = [batch_size,
|
||||
model.model.diffusion_model.in_channels,
|
||||
model.model.diffusion_model.image_size,
|
||||
model.model.diffusion_model.image_size]
|
||||
|
||||
with model.ema_scope("Plotting"):
|
||||
t0 = time.time()
|
||||
if vanilla:
|
||||
sample, progrow = convsample(model, shape,
|
||||
make_prog_row=True)
|
||||
else:
|
||||
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
|
||||
eta=eta)
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
x_sample = model.decode_first_stage(sample)
|
||||
|
||||
log["sample"] = x_sample
|
||||
log["time"] = t1 - t0
|
||||
log['throughput'] = sample.shape[0] / (t1 - t0)
|
||||
print(f'Throughput for this batch: {log["throughput"]}')
|
||||
return log
|
||||
|
||||
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
|
||||
if vanilla:
|
||||
print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
|
||||
else:
|
||||
print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
|
||||
|
||||
|
||||
tstart = time.time()
|
||||
n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
|
||||
# path = logdir
|
||||
if model.cond_stage_model is None:
|
||||
all_images = []
|
||||
|
||||
print(f"Running unconditional sampling for {n_samples} samples")
|
||||
for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
|
||||
logs = make_convolutional_sample(model, batch_size=batch_size,
|
||||
vanilla=vanilla, custom_steps=custom_steps,
|
||||
eta=eta)
|
||||
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
|
||||
all_images.extend([custom_to_np(logs["sample"])])
|
||||
if n_saved >= n_samples:
|
||||
print(f'Finish after generating {n_saved} samples')
|
||||
break
|
||||
all_img = np.concatenate(all_images, axis=0)
|
||||
all_img = all_img[:n_samples]
|
||||
shape_str = "x".join([str(x) for x in all_img.shape])
|
||||
nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
|
||||
np.savez(nppath, all_img)
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Currently only sampling for unconditional models supported.')
|
||||
|
||||
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
|
||||
|
||||
|
||||
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
|
||||
for k in logs:
|
||||
if k == key:
|
||||
batch = logs[key]
|
||||
if np_path is None:
|
||||
for x in batch:
|
||||
img = custom_to_pil(x)
|
||||
imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
|
||||
img.save(imgpath)
|
||||
n_saved += 1
|
||||
else:
|
||||
npbatch = custom_to_np(batch)
|
||||
shape_str = "x".join([str(x) for x in npbatch.shape])
|
||||
nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
|
||||
np.savez(nppath, npbatch)
|
||||
n_saved += npbatch.shape[0]
|
||||
return n_saved
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--resume",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="load from logdir or checkpoint in logdir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--n_samples",
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="number of samples to draw",
|
||||
default=50000
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--eta",
|
||||
type=float,
|
||||
nargs="?",
|
||||
help="eta for ddim sampling (0.0 yields deterministic sampling)",
|
||||
default=1.0
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--vanilla_sample",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="vanilla sampling (default option is DDIM sampling)?",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--logdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="extra logdir",
|
||||
default="none"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--custom_steps",
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="number of steps for ddim and fastdpm sampling",
|
||||
default=50
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="the bs",
|
||||
default=10
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def load_model_from_config(config, sd):
|
||||
model = instantiate_from_config(config)
|
||||
model.load_state_dict(sd,strict=False)
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_model(config, ckpt, gpu, eval_mode):
|
||||
if ckpt:
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
global_step = pl_sd["global_step"]
|
||||
else:
|
||||
pl_sd = {"state_dict": None}
|
||||
global_step = None
|
||||
model = load_model_from_config(config.model,
|
||||
pl_sd["state_dict"])
|
||||
|
||||
return model, global_step
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
sys.path.append(os.getcwd())
|
||||
command = " ".join(sys.argv)
|
||||
|
||||
parser = get_parser()
|
||||
opt, unknown = parser.parse_known_args()
|
||||
ckpt = None
|
||||
|
||||
if not os.path.exists(opt.resume):
|
||||
raise ValueError("Cannot find {}".format(opt.resume))
|
||||
if os.path.isfile(opt.resume):
|
||||
# paths = opt.resume.split("/")
|
||||
try:
|
||||
logdir = '/'.join(opt.resume.split('/')[:-1])
|
||||
# idx = len(paths)-paths[::-1].index("logs")+1
|
||||
print(f'Logdir is {logdir}')
|
||||
except ValueError:
|
||||
paths = opt.resume.split("/")
|
||||
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
||||
logdir = "/".join(paths[:idx])
|
||||
ckpt = opt.resume
|
||||
else:
|
||||
assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
|
||||
logdir = opt.resume.rstrip("/")
|
||||
ckpt = os.path.join(logdir, "model.ckpt")
|
||||
|
||||
base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
|
||||
opt.base = base_configs
|
||||
|
||||
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
||||
cli = OmegaConf.from_dotlist(unknown)
|
||||
config = OmegaConf.merge(*configs, cli)
|
||||
|
||||
gpu = True
|
||||
eval_mode = True
|
||||
|
||||
if opt.logdir != "none":
|
||||
locallog = logdir.split(os.sep)[-1]
|
||||
if locallog == "": locallog = logdir.split(os.sep)[-2]
|
||||
print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
|
||||
logdir = os.path.join(opt.logdir, locallog)
|
||||
|
||||
print(config)
|
||||
|
||||
model, global_step = load_model(config, ckpt, gpu, eval_mode)
|
||||
print(f"global step: {global_step}")
|
||||
print(75 * "=")
|
||||
print("logging to:")
|
||||
logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
|
||||
imglogdir = os.path.join(logdir, "img")
|
||||
numpylogdir = os.path.join(logdir, "numpy")
|
||||
|
||||
os.makedirs(imglogdir)
|
||||
os.makedirs(numpylogdir)
|
||||
print(logdir)
|
||||
print(75 * "=")
|
||||
|
||||
# write config out
|
||||
sampling_file = os.path.join(logdir, "sampling_config.yaml")
|
||||
sampling_conf = vars(opt)
|
||||
|
||||
with open(sampling_file, 'w') as f:
|
||||
yaml.dump(sampling_conf, f, default_flow_style=False)
|
||||
print(sampling_conf)
|
||||
|
||||
|
||||
run(model, imglogdir, eta=opt.eta,
|
||||
vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
|
||||
batch_size=opt.batch_size, nplog=numpylogdir)
|
||||
|
||||
print("done.")
|
|
@ -0,0 +1,147 @@
|
|||
import os, sys
|
||||
import numpy as np
|
||||
import scann
|
||||
import argparse
|
||||
import glob
|
||||
from multiprocessing import cpu_count
|
||||
from tqdm import tqdm
|
||||
|
||||
from ldm.util import parallel_data_prefetch
|
||||
|
||||
|
||||
def search_bruteforce(searcher):
|
||||
return searcher.score_brute_force().build()
|
||||
|
||||
|
||||
def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
|
||||
partioning_trainsize, num_leaves, num_leaves_to_search):
|
||||
return searcher.tree(num_leaves=num_leaves,
|
||||
num_leaves_to_search=num_leaves_to_search,
|
||||
training_sample_size=partioning_trainsize). \
|
||||
score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
|
||||
|
||||
|
||||
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
|
||||
return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
|
||||
reorder_k).build()
|
||||
|
||||
def load_datapool(dpath):
|
||||
|
||||
|
||||
def load_single_file(saved_embeddings):
|
||||
compressed = np.load(saved_embeddings)
|
||||
database = {key: compressed[key] for key in compressed.files}
|
||||
return database
|
||||
|
||||
def load_multi_files(data_archive):
|
||||
database = {key: [] for key in data_archive[0].files}
|
||||
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
|
||||
for key in d.files:
|
||||
database[key].append(d[key])
|
||||
|
||||
return database
|
||||
|
||||
print(f'Load saved patch embedding from "{dpath}"')
|
||||
file_content = glob.glob(os.path.join(dpath, '*.npz'))
|
||||
|
||||
if len(file_content) == 1:
|
||||
data_pool = load_single_file(file_content[0])
|
||||
elif len(file_content) > 1:
|
||||
data = [np.load(f) for f in file_content]
|
||||
prefetched_data = parallel_data_prefetch(load_multi_files, data,
|
||||
n_proc=min(len(data), cpu_count()), target_data_type='dict')
|
||||
|
||||
data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
|
||||
else:
|
||||
raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')
|
||||
|
||||
print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
|
||||
return data_pool
|
||||
|
||||
|
||||
def train_searcher(opt,
|
||||
metric='dot_product',
|
||||
partioning_trainsize=None,
|
||||
reorder_k=None,
|
||||
# todo tune
|
||||
aiq_thld=0.2,
|
||||
dims_per_block=2,
|
||||
num_leaves=None,
|
||||
num_leaves_to_search=None,):
|
||||
|
||||
data_pool = load_datapool(opt.database)
|
||||
k = opt.knn
|
||||
|
||||
if not reorder_k:
|
||||
reorder_k = 2 * k
|
||||
|
||||
# normalize
|
||||
# embeddings =
|
||||
searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
|
||||
pool_size = data_pool['embedding'].shape[0]
|
||||
|
||||
print(*(['#'] * 100))
|
||||
print('Initializing scaNN searcher with the following values:')
|
||||
print(f'k: {k}')
|
||||
print(f'metric: {metric}')
|
||||
print(f'reorder_k: {reorder_k}')
|
||||
print(f'anisotropic_quantization_threshold: {aiq_thld}')
|
||||
print(f'dims_per_block: {dims_per_block}')
|
||||
print(*(['#'] * 100))
|
||||
print('Start training searcher....')
|
||||
print(f'N samples in pool is {pool_size}')
|
||||
|
||||
# this reflects the recommended design choices proposed at
|
||||
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
|
||||
if pool_size < 2e4:
|
||||
print('Using brute force search.')
|
||||
searcher = search_bruteforce(searcher)
|
||||
elif 2e4 <= pool_size and pool_size < 1e5:
|
||||
print('Using asymmetric hashing search and reordering.')
|
||||
searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
|
||||
else:
|
||||
print('Using using partioning, asymmetric hashing search and reordering.')
|
||||
|
||||
if not partioning_trainsize:
|
||||
partioning_trainsize = data_pool['embedding'].shape[0] // 10
|
||||
if not num_leaves:
|
||||
num_leaves = int(np.sqrt(pool_size))
|
||||
|
||||
if not num_leaves_to_search:
|
||||
num_leaves_to_search = max(num_leaves // 20, 1)
|
||||
|
||||
print('Partitioning params:')
|
||||
print(f'num_leaves: {num_leaves}')
|
||||
print(f'num_leaves_to_search: {num_leaves_to_search}')
|
||||
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
|
||||
searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
|
||||
partioning_trainsize, num_leaves, num_leaves_to_search)
|
||||
|
||||
print('Finish training searcher')
|
||||
searcher_savedir = opt.target_path
|
||||
os.makedirs(searcher_savedir, exist_ok=True)
|
||||
searcher.serialize(searcher_savedir)
|
||||
print(f'Saved trained searcher under "{searcher_savedir}"')
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.path.append(os.getcwd())
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--database',
|
||||
'-d',
|
||||
default='data/rdm/retrieval_databases/openimages',
|
||||
type=str,
|
||||
help='path to folder containing the clip feature of the database')
|
||||
parser.add_argument('--target_path',
|
||||
'-t',
|
||||
default='data/rdm/searchers/openimages',
|
||||
type=str,
|
||||
help='path to the target folder where the searcher shall be stored.')
|
||||
parser.add_argument('--knn',
|
||||
'-k',
|
||||
default=20,
|
||||
type=int,
|
||||
help='number of nearest neighbors, for which the searcher shall be optimized')
|
||||
|
||||
opt, _ = parser.parse_known_args()
|
||||
|
||||
train_searcher(opt,)
|
|
@ -0,0 +1,344 @@
|
|||
import argparse, os, sys, glob
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from imwatermark import WatermarkEncoder
|
||||
from itertools import islice
|
||||
from einops import rearrange
|
||||
from torchvision.utils import make_grid
|
||||
import time
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
# load safety model
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def put_watermark(img, wm_encoder=None):
|
||||
if wm_encoder is not None:
|
||||
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
img = wm_encoder.encode(img, 'dwtDct')
|
||||
img = Image.fromarray(img[:, :, ::-1])
|
||||
return img
|
||||
|
||||
|
||||
def load_replacement(x):
|
||||
try:
|
||||
hwc = x.shape
|
||||
y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
|
||||
y = (np.array(y)/255.0).astype(x.dtype)
|
||||
assert y.shape == x.shape
|
||||
return y
|
||||
except Exception:
|
||||
return x
|
||||
|
||||
|
||||
def check_safety(x_image):
|
||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
assert x_checked_image.shape[0] == len(has_nsfw_concept)
|
||||
for i in range(len(has_nsfw_concept)):
|
||||
if has_nsfw_concept[i]:
|
||||
x_checked_image[i] = load_replacement(x_checked_image[i])
|
||||
return x_checked_image, has_nsfw_concept
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
default="outputs/txt2img-samples"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_save",
|
||||
action='store_true',
|
||||
help="do not save individual samples. For speed measurements.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plms",
|
||||
action='store_true',
|
||||
help="use plms sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--laion400m",
|
||||
action='store_true',
|
||||
help="uses the LAION400M model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fixed_code",
|
||||
action='store_true',
|
||||
help="if enabled, uses the same starting code across samples ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iter",
|
||||
type=int,
|
||||
default=2,
|
||||
help="sample this often",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--H",
|
||||
type=int,
|
||||
default=512,
|
||||
help="image height, in pixel space",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--W",
|
||||
type=int,
|
||||
default=512,
|
||||
help="image width, in pixel space",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--C",
|
||||
type=int,
|
||||
default=4,
|
||||
help="latent channels",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--f",
|
||||
type=int,
|
||||
default=8,
|
||||
help="downsampling factor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_samples",
|
||||
type=int,
|
||||
default=3,
|
||||
help="how many samples to produce for each given prompt. A.k.a. batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
help="if specified, load prompts from this file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="configs/stable-diffusion/v1-inference.yaml",
|
||||
help="path to config which constructs model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||
help="path to checkpoint of model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
help="evaluate at this precision",
|
||||
choices=["full", "autocast"],
|
||||
default="autocast"
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
if opt.laion400m:
|
||||
print("Falling back to LAION 400M model...")
|
||||
opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
|
||||
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
|
||||
opt.outdir = "outputs/txt2img-samples-laion400m"
|
||||
|
||||
seed_everything(opt.seed)
|
||||
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
|
||||
if opt.plms:
|
||||
sampler = PLMSSampler(model)
|
||||
else:
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
|
||||
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
||||
wm = "StableDiffusionV1"
|
||||
wm_encoder = WatermarkEncoder()
|
||||
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = list(chunk(data, batch_size))
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
start_code = None
|
||||
if opt.fixed_code:
|
||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||
|
||||
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with model.ema_scope():
|
||||
tic = time.time()
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||
conditioning=c,
|
||||
batch_size=opt.n_samples,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt.ddim_eta,
|
||||
x_T=start_code)
|
||||
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
|
||||
|
||||
x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
||||
|
||||
if not opt.skip_save:
|
||||
for x_sample in x_checked_image_torch:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
img = Image.fromarray(x_sample.astype(np.uint8))
|
||||
img = put_watermark(img, wm_encoder)
|
||||
img.save(os.path.join(sample_path, f"{base_count:05}.png"))
|
||||
base_count += 1
|
||||
|
||||
if not opt.skip_grid:
|
||||
all_samples.append(x_checked_image_torch)
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
img = Image.fromarray(grid.astype(np.uint8))
|
||||
img = put_watermark(img, wm_encoder)
|
||||
img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
toc = time.time()
|
||||
|
||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
||||
f" \nEnjoy.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,13 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='latent-diffusion',
|
||||
version='0.0.1',
|
||||
description='',
|
||||
packages=find_packages(),
|
||||
install_requires=[
|
||||
'torch',
|
||||
'numpy',
|
||||
'tqdm',
|
||||
],
|
||||
)
|
|
@ -0,0 +1,4 @@
|
|||
HF_DATASETS_OFFLINE=1
|
||||
TRANSFORMERS_OFFLINE=1
|
||||
|
||||
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
|
Loading…
Reference in New Issue