mirror of https://github.com/hpcaitech/ColossalAI
commit 79079a9d0c
@@ -53,27 +53,33 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment

```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install -e .
```

#### Step 2: Install Lightning

Install a Lightning version newer than 2022.01.04. We suggest you install Lightning from source.

##### From source

```
git clone https://github.com/Lightning-AI/lightning.git
pip install -r requirements.txt
python setup.py install
```

##### From pip

```
pip install pytorch-lightning
```

#### Step 3: Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website

##### From pip

For example, you can install v0.2.0 from our official website.

```
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```
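After installing from pip, a quick sanity check (a minimal sketch; the exact version strings will depend on your install) confirms that the key packages import cleanly:

```
python -c "import torch; print(torch.__version__, torch.version.cuda)"
python -c "import pytorch_lightning; print(pytorch_lightning.__version__)"
python -c "import colossalai; print(colossalai.__version__)"
```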
##### From source
@@ -133,10 +139,9 @@ It is important for you to configure your volume mapping in order to get the best training experience.

3. **Optional**: if you encounter any problem stating that shared memory is insufficient inside the container, please add `-v /dev/shm:/dev/shm` to your `docker run` command.
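For example, a `docker run` invocation with the host's shared memory mapped in might look like the following (the image name is illustrative, not prescribed by this guide):

```
docker run -it --rm --gpus all -v /dev/shm:/dev/shm hpcaitech/colossalai:latest /bin/bash
```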
## Download the pretrained model checkpoint

### stable-diffusion-v2-base (Recommended)

```
wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt
```

@@ -144,8 +149,6 @@ wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt

### stable-diffusion-v1-4

Our default model config uses the weights from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style).

```
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
```
@@ -153,8 +156,6 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4

### stable-diffusion-v1-5 from runway

If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from runwayml:

```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
```
@@ -171,11 +172,16 @@ We provide the script `train_colossalai.sh` to run the training task with colossalai,

and you can also use `train_ddp.sh` to run the training task with DDP for comparison.

In `train_colossalai.sh` the main command is:

```
python main.py --logdir /tmp/ --train --base configs/train_colossalai.yaml --ckpt 512-base-ema.ckpt
```

- You can change `--logdir` to decide where to save the log information and the last checkpoint; an example combining these flags follows below.
- You will find your ckpt in `logdir/checkpoints` or `logdir/diff_tb/version_0/checkpoints`.
- You will find your train config yaml in `logdir/configs`.
- You can add `--ckpt` if you want to load a pretrained model, for example `512-base-ema.ckpt`.
- You can change `--base` to specify the path of the config yaml.
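For instance, a run that loads the v2-base checkpoint and writes logs to a custom directory might look like this (the log path is illustrative):

```
python main.py \
    --logdir ./logs/sd2-base \
    --train \
    --base configs/train_colossalai.yaml \
    --ckpt 512-base-ema.ckpt
```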
### Training config

@@ -186,7 +192,8 @@ You can change the training config in the yaml file

- precision: the precision type used in training, default 16 (fp16); you must use fp16 if you want to apply colossalai
- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai)
## Finetune Example

### Training on Teyvat Datasets

We provide a finetuning example on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is created with BLIP-generated captions. The training script invokes `main.py` as shown below.
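The accompanying train script launches `main.py` with the Teyvat config and a pretrained checkpoint, along these lines:

```
python main.py --logdir /tmp --train \
    --base configs/Teyvat/train_colossalai_teyvat.yaml \
    --ckpt diffuser_root_dir/512-base-ema.ckpt
```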
@@ -201,8 +208,8 @@ you can get your training last.ckpt and train config.yaml in your `--logdir`, and

```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
    --outdir ./output \
    --ckpt /path/to/logdir/checkpoints/last.ckpt \
    --config /path/to/logdir/configs/project.yaml
```
@@ -6,6 +6,7 @@ model:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    ckpt: None # use ckpt path
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image

@@ -16,7 +17,7 @@ model:
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
File diff suppressed because it is too large
@@ -1,10 +1,11 @@
# pytorch_diffusion + derived encoder decoder
import math
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

try:
    from lightning.pytorch.utilities import rank_zero_info

@@ -53,15 +54,12 @@ def Normalize(in_channels, num_groups=32):


class Upsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

@@ -71,16 +69,13 @@ class Upsample(nn.Module):


class Downsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:

@@ -93,8 +88,8 @@ class Downsample(nn.Module):


class ResnetBlock(nn.Module):

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels

@@ -102,34 +97,17 @@ class ResnetBlock(nn.Module):
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x

@@ -155,31 +133,16 @@ class ResnetBlock(nn.Module):


class AttnBlock(nn.Module):

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x

@@ -207,38 +170,24 @@ class AttnBlock(nn.Module):

        return x + h_


class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):

@@ -253,27 +202,20 @@ class MemoryEfficientAttnBlock(nn.Module):
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C).
            contiguous(),
            (q, k, v),
        )
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C))
        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x + out


class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):

    def forward(self, x, context=None, mask=None):
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')

@@ -283,10 +225,10 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear",
                         "none"], f'attn_type {attn_type} unknown'
    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)

@@ -303,11 +245,24 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):


class Model(nn.Module):

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 use_timestep=True,
                 use_linear_attn=False,
                 attn_type="vanilla"):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)

@@ -320,18 +275,12 @@ class Model(nn.Module):
        # timestep embedding
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList([
            torch.nn.Linear(self.ch, self.temb_ch),
            torch.nn.Linear(self.temb_ch, self.temb_ch),
        ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)

@@ -342,7 +291,8 @@ class Model(nn.Module):
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(in_channels=block_in,
                                out_channels=block_out,
                                temb_channels=self.temb_ch,
                                dropout=dropout))

@@ -379,7 +329,8 @@ class Model(nn.Module):
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(in_channels=block_in + skip_in,
                                out_channels=block_out,
                                temb_channels=self.temb_ch,
                                dropout=dropout))

@@ -396,11 +347,7 @@ class Model(nn.Module):

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None, context=None):
        #assert x.shape[2] == x.shape[3] == self.resolution

@@ -437,8 +384,7 @@ class Model(nn.Module):
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:

@@ -455,12 +401,26 @@ class Model(nn.Module):


class Encoder(nn.Module):

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)

@@ -469,11 +429,7 @@ class Encoder(nn.Module):
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)

@@ -485,7 +441,8 @@ class Encoder(nn.Module):
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(in_channels=block_in,
                                out_channels=block_out,
                                temb_channels=self.temb_ch,
                                dropout=dropout))

@@ -549,12 +506,27 @@ class Encoder(nn.Module):


class Decoder(nn.Module):

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type="vanilla",
                 **ignorekwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)

@@ -569,15 +541,10 @@ class Decoder(nn.Module):
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()

@@ -598,7 +565,8 @@ class Decoder(nn.Module):
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(in_channels=block_in,
                                out_channels=block_out,
                                temb_channels=self.temb_ch,
                                dropout=dropout))

@@ -615,11 +583,7 @@ class Decoder(nn.Module):

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]

@@ -658,27 +622,20 @@ class Decoder(nn.Module):


class SimpleDecoder(nn.Module):

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([
            nn.Conv2d(in_channels, in_channels, 1),
            ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
            ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
            ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
            nn.Conv2d(2 * in_channels, in_channels, 1),
            Upsample(in_channels, with_conv=True)
        ])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):

@@ -694,8 +651,8 @@ class SimpleDecoder(nn.Module):


class UpsampleDecoder(nn.Module):

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0

@@ -709,7 +666,8 @@ class UpsampleDecoder(nn.Module):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(in_channels=block_in,
                                out_channels=block_out,
                                temb_channels=self.temb_ch,
                                dropout=dropout))

@@ -721,11 +679,7 @@ class UpsampleDecoder(nn.Module):

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # upsampling

@@ -742,26 +696,24 @@ class UpsampleDecoder(nn.Module):


class LatentRescaler(nn.Module):

    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.res_block1 = nn.ModuleList([
            ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
            for _ in range(depth)
        ])
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList([
            ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
            for _ in range(depth)
        ])

        self.conv_out = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
        )

@@ -770,7 +722,9 @@ class LatentRescaler(nn.Module):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(x,
                                            size=(int(round(x.shape[2] * self.factor)),
                                                  int(round(x.shape[3] * self.factor))))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)

@@ -779,17 +733,37 @@ class LatentRescaler(nn.Module):


class MergedRescaleEncoder(nn.Module):

    def __init__(self,
                 in_channels,
                 ch,
                 resolution,
                 out_ch,
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 ch_mult=(1, 2, 4, 8),
                 rescale_factor=1.0,
                 rescale_module_depth=1):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(in_channels=in_channels,
                               num_res_blocks=num_res_blocks,
                               ch=ch,
                               ch_mult=ch_mult,
                               z_channels=intermediate_chn,
                               double_z=False,
                               resolution=resolution,
                               attn_resolutions=attn_resolutions,
                               dropout=dropout,
                               resamp_with_conv=resamp_with_conv,
                               out_ch=None)
        self.rescaler = LatentRescaler(factor=rescale_factor,
                                       in_channels=intermediate_chn,
                                       mid_channels=intermediate_chn,
                                       out_channels=out_ch,
                                       depth=rescale_module_depth)

    def forward(self, x):
        x = self.encoder(x)

@@ -798,15 +772,36 @@ class MergedRescaleEncoder(nn.Module):


class MergedRescaleDecoder(nn.Module):

    def __init__(self,
                 z_channels,
                 out_ch,
                 resolution,
                 num_res_blocks,
                 attn_resolutions,
                 ch,
                 ch_mult=(1, 2, 4, 8),
                 dropout=0.0,
                 resamp_with_conv=True,
                 rescale_factor=1.0,
                 rescale_module_depth=1):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(out_ch=out_ch,
                               z_channels=tmp_chn,
                               attn_resolutions=attn_resolutions,
                               dropout=dropout,
                               resamp_with_conv=resamp_with_conv,
                               in_channels=None,
                               num_res_blocks=num_res_blocks,
                               ch_mult=ch_mult,
                               resolution=resolution,
                               ch=ch)
        self.rescaler = LatentRescaler(factor=rescale_factor,
                                       in_channels=z_channels,
                                       mid_channels=tmp_chn,
                                       out_channels=tmp_chn,
                                       depth=rescale_module_depth)

    def forward(self, x):
        x = self.rescaler(x)

@@ -815,16 +810,26 @@ class MergedRescaleDecoder(nn.Module):


class Upsampler(nn.Module):

    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1. + (out_size % in_size)
        rank_zero_info(
            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
        )
        self.rescaler = LatentRescaler(factor=factor_up,
                                       in_channels=in_channels,
                                       mid_channels=2 * in_channels,
                                       out_channels=in_channels)
        self.decoder = Decoder(out_ch=out_channels,
                               resolution=out_size,
                               z_channels=in_channels,
                               num_res_blocks=2,
                               attn_resolutions=[],
                               in_channels=None,
                               ch=in_channels,
                               ch_mult=[ch_mult for _ in range(num_blocks)])

    def forward(self, x):

@@ -834,20 +839,18 @@ class Upsampler(nn.Module):


class Resize(nn.Module):

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            rank_zero_info(
                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError()
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, scale_factor=1.0):
        if scale_factor == 1.0:
@@ -106,7 +106,20 @@ def get_parser(**parser_kwargs):
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project",
    )
    parser.add_argument(
        "-c",
        "--ckpt",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="load pretrained checkpoint from stable AI",
    )
    parser.add_argument(
        "-d",
        "--debug",

@@ -145,22 +158,7 @@ def get_parser(**parser_kwargs):
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    return parser

@@ -341,6 +339,12 @@ class SetupCallback(Callback):
        except FileNotFoundError:
            pass

    # def on_fit_end(self, trainer, pl_module):
    #     if trainer.global_rank == 0:
    #         ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
    #         rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
    #         trainer.save_checkpoint(ckpt_path)


class ImageLogger(Callback):

@@ -536,6 +540,7 @@ if __name__ == "__main__":
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint")
    if opt.resume:
        rank_zero_info("Resuming from {}".format(opt.resume))
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):

@@ -543,13 +548,13 @@ if __name__ == "__main__":
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            rank_zero_info("logdir: {}".format(logdir))
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")

@@ -558,6 +563,7 @@ if __name__ == "__main__":
    if opt.name:
        name = "_" + opt.name
    elif opt.base:
        rank_zero_info("Using base config {}".format(opt.base))
        cfg_fname = os.path.split(opt.base[0])[-1]
        cfg_name = os.path.splitext(cfg_fname)[0]
        name = "_" + cfg_name

@@ -566,6 +572,9 @@ if __name__ == "__main__":
    nowname = now + name + opt.postfix
    logdir = os.path.join(opt.logdir, nowname)

    if opt.ckpt:
        ckpt = opt.ckpt

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

@@ -582,14 +591,11 @@ if __name__ == "__main__":
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)

        if not trainer_config["accelerator"] == "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

@@ -597,10 +603,12 @@ if __name__ == "__main__":
        use_fp16 = trainer_config.get("precision", 32) == 16
        if use_fp16:
            config.model["params"].update({"use_fp16": True})
        else:
            config.model["params"].update({"use_fp16": False})

        if ckpt is not None:
            config.model["params"].update({"ckpt": ckpt})
            rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))

        model = instantiate_from_config(config.model)
        # trainer and callbacks

@@ -639,7 +647,6 @@ if __name__ == "__main__":
        # configure the strategy, default is ddp
        if "strategy" in trainer_config:
            strategy_cfg = trainer_config["strategy"]
            strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
        else:
            strategy_cfg = {

@@ -648,7 +655,6 @@ if __name__ == "__main__":
                    "find_unused_parameters": False
                }
            }

        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)

@@ -664,7 +670,6 @@ if __name__ == "__main__":
            }
        }
        if hasattr(model, "monitor"):
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

@@ -673,7 +678,6 @@ if __name__ == "__main__":
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

@@ -710,8 +714,6 @@ if __name__ == "__main__":
                "target": "main.CUDACallback"
            },
        }

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks

@@ -737,15 +739,11 @@ if __name__ == "__main__":
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir

        # data
        data = instantiate_from_config(config.data)

@@ -754,9 +752,9 @@ if __name__ == "__main__":
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        for k in data.datasets:
            rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate

@@ -768,17 +766,17 @@ if __name__ == "__main__":
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            rank_zero_info(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
                .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            rank_zero_info("++++ NOT USING LR SCALING ++++")
            rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")

        # allow checkpointing via USR1
        def melk(*args, **kwargs):
@@ -1,5 +1,5 @@
python scripts/txt2img.py --prompt "Teyvat, Medium Female, a woman in a blue outfit holding a sword" --plms \
    --outdir ./output \
    --ckpt checkpoints/last.ckpt \
    --config configs/2023-02-02T18-06-14-project.yaml \
    --n_samples 4
@@ -2,4 +2,4 @@ HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1

python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt