Mirror of https://github.com/hpcaitech/ColossalAI, commit 79079a9d0c
@ -53,27 +53,33 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la

```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install -e .
```
#### Step 2: Install Lightning

Install a Lightning version newer than 2022.01.04. We suggest installing Lightning from source.
##### From Source

```
git clone https://github.com/Lightning-AI/lightning.git
cd lightning
pip install -r requirements.txt
python setup.py install
```
##### From pip

```
pip install pytorch-lightning
```
#### Step 3: Install [Colossal-AI](https://colossalai.org/download/) from our official website

##### From pip

- For example, you can install v0.1.12 from our official website.
+ For example, you can install v0.2.0 from our official website.

```
- pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
+ pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```
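Once all three steps are done, a quick import check catches version mismatches early. This is a minimal sketch, not part of the original README; the expected version strings assume you installed exactly the pins above:

```
# hypothetical sanity check, not part of the repo
import torch
import colossalai

print(torch.__version__, torch.cuda.is_available())  # expect 1.12.x and True
print(colossalai.__version__)                         # expect 0.2.0 with the pip command above
```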
##### From source
@ -133,10 +139,9 @@ It is important for you to configure your volume mapping in order to get the bes

3. **Optional**: if you encounter any problem stating that shared memory is insufficient inside the container, please add `-v /dev/shm:/dev/shm` to your `docker run` command.
## Download the pretrained model checkpoint

- ### stable-diffusion-v2-base
+ ### stable-diffusion-v2-base (Recommended)

```
wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt
```
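A multi-gigabyte checkpoint is worth sanity-checking after download. A small sketch, assuming the standard LDM checkpoint layout with weights nested under `state_dict`:

```
# hypothetical post-download check, not part of the repo
import torch

sd = torch.load("512-base-ema.ckpt", map_location="cpu")
state = sd.get("state_dict", sd)      # LDM checkpoints usually nest weights here
print(len(state), "tensors loaded")   # a truncated download typically fails to unpickle at all
```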
@ -144,8 +149,6 @@ wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512

### stable-diffusion-v1-4

Our default model config uses the weights from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style)

```
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
```
@ -153,8 +156,6 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4

### stable-diffusion-v1-5 from RunwayML

If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from RunwayML:

```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
```
@ -171,11 +172,16 @@ We provide the script `train_colossalai.sh` to run the training task with coloss

and you can also use `train_ddp.sh` to run the training task with DDP for comparison.

In `train_colossalai.sh` the main command is:

```
- python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
+ python main.py --logdir /tmp/ --train --base configs/train_colossalai.yaml --ckpt 512-base-ema.ckpt
```

- You can change the `--logdir` to decide where to save the log information and the last checkpoint.
- You will find your ckpt in `logdir/checkpoints` or `logdir/diff_tb/version_0/checkpoints`
- You will find your train config yaml in `logdir/configs`
- You can add `--ckpt` if you want to load a pretrained model, for example `512-base-ema.ckpt`
- You can change `--base` to specify the path of the config yaml
### Training config

@ -186,7 +192,8 @@ You can change the training config in the yaml file

- precision: the precision type used in training, default 16 (fp16); you must use fp16 if you want to apply ColossalAI
- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai)
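main.py reads this YAML with OmegaConf, so the same fields can also be tweaked programmatically before launching. A minimal sketch; the `lightning.trainer.precision` key path and both file names are assumptions based on this repo's config layout:

```
# hypothetical config tweak; paths and key names are assumptions
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/train_colossalai.yaml")
cfg.lightning.trainer.precision = 16  # fp16 is required for the ColossalAI strategy
OmegaConf.save(cfg, "configs/train_colossalai_fp16.yaml")
```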
- ## Finetune Example (Work In Progress)
+ ## Finetune Example

### Training on Teyvat Datasets

We provide a finetuning example on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which was created with BLIP-generated captions.
@ -201,8 +208,8 @@ you can get your training last.ckpt and train config.yaml in your `--logdir`, an

```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
    --outdir ./output \
-    --config path/to/logdir/checkpoints/last.ckpt \
-    --ckpt /path/to/logdir/configs/project.yaml \
+    --ckpt path/to/logdir/checkpoints/last.ckpt \
+    --config /path/to/logdir/configs/project.yaml \
```
@ -6,6 +6,7 @@ model:
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
+    ckpt: None # use ckpt path
     log_every_t: 200
     timesteps: 1000
     first_stage_key: image
@ -16,7 +17,7 @@ model:
     conditioning_key: crossattn
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
-    use_ema: False # we set this to false because this is an inference only config
+    use_ema: False

     scheduler_config: # 10000 warmup steps
       target: ldm.lr_scheduler.LambdaLinearScheduler
(File diff suppressed because it is too large.)
@ -1,10 +1,11 @@
 # pytorch_diffusion + derived encoder decoder
 import math
+from typing import Any, Optional
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 from einops import rearrange
-from typing import Optional, Any

 try:
     from lightning.pytorch.utilities import rank_zero_info
@ -38,14 +39,14 @@ def get_timestep_embedding(timesteps, embedding_dim):
     emb = emb.to(device=timesteps.device)
     emb = timesteps.float()[:, None] * emb[None, :]
     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:    # zero pad
-        emb = torch.nn.functional.pad(emb, (0,1,0,0))
+    if embedding_dim % 2 == 1:    # zero pad
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
     return emb


 def nonlinearity(x):
     # swish
-    return x*torch.sigmoid(x)
+    return x * torch.sigmoid(x)


 def Normalize(in_channels, num_groups=32):
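As an aside, the sinusoidal timestep embedding above is easy to exercise in isolation. A hedged sketch; the import path assumes the file sits at `ldm/modules/diffusionmodules/model.py` as in upstream stable-diffusion:

```
# illustrative only; the import path is an assumption
import torch
from ldm.modules.diffusionmodules.model import get_timestep_embedding

t = torch.randint(0, 1000, (8,))                   # a batch of 8 diffusion timesteps
emb = get_timestep_embedding(t, embedding_dim=64)
print(emb.shape)                                   # torch.Size([8, 64]); odd dims get one zero-pad column
```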
@ -53,15 +54,12 @@ def Normalize(in_channels, num_groups=32):


 class Upsample(nn.Module):
+
     def __init__(self, in_channels, with_conv):
         super().__init__()
         self.with_conv = with_conv
         if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=3,
-                                        stride=1,
-                                        padding=1)
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

     def forward(self, x):
         x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@ -71,20 +69,17 @@ class Upsample(nn.Module):


 class Downsample(nn.Module):
+
     def __init__(self, in_channels, with_conv):
         super().__init__()
         self.with_conv = with_conv
         if self.with_conv:
             # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=3,
-                                        stride=2,
-                                        padding=0)
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

     def forward(self, x):
         if self.with_conv:
-            pad = (0,1,0,1)
+            pad = (0, 1, 0, 1)
             x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
             x = self.conv(x)
         else:
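The "no asymmetric padding" comment above is the whole point of the manual `(0, 1, 0, 1)` pad: padding only the right and bottom edges, then applying a stride-2 convolution with `padding=0`, halves the spatial size. A small illustration with arbitrary shapes:

```
# illustrative only; shapes chosen arbitrarily
import torch

x = torch.randn(1, 64, 32, 32)
x = torch.nn.functional.pad(x, (0, 1, 0, 1))  # pad right and bottom only -> 33x33
conv = torch.nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0)
print(conv(x).shape)                          # torch.Size([1, 64, 16, 16])
```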
@ -93,8 +88,8 @@ class Downsample(nn.Module):


 class ResnetBlock(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
-                 dropout, temb_channels=512):
+
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
         super().__init__()
         self.in_channels = in_channels
         out_channels = in_channels if out_channels is None else out_channels
@ -102,34 +97,17 @@ class ResnetBlock(nn.Module):
         self.use_conv_shortcut = conv_shortcut

         self.norm1 = Normalize(in_channels)
-        self.conv1 = torch.nn.Conv2d(in_channels,
-                                     out_channels,
-                                     kernel_size=3,
-                                     stride=1,
-                                     padding=1)
+        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         if temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels,
-                                             out_channels)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
         self.norm2 = Normalize(out_channels)
         self.dropout = torch.nn.Dropout(dropout)
-        self.conv2 = torch.nn.Conv2d(out_channels,
-                                     out_channels,
-                                     kernel_size=3,
-                                     stride=1,
-                                     padding=1)
+        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels,
-                                                     out_channels,
-                                                     kernel_size=3,
-                                                     stride=1,
-                                                     padding=1)
+                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels,
-                                                    out_channels,
-                                                    kernel_size=1,
-                                                    stride=1,
-                                                    padding=0)
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x, temb):
         h = x
@ -138,7 +116,7 @@ class ResnetBlock(nn.Module):
         h = self.conv1(h)

         if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

         h = self.norm2(h)
         h = nonlinearity(h)
@ -151,35 +129,20 @@ class ResnetBlock(nn.Module):
         else:
             x = self.nin_shortcut(x)

-        return x+h
+        return x + h


 class AttnBlock(nn.Module):
+
     def __init__(self, in_channels):
         super().__init__()
         self.in_channels = in_channels

         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=1,
-                                        stride=1,
-                                        padding=0)
+        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x):
         h_ = x
@ -189,23 +152,24 @@ class AttnBlock(nn.Module):
         v = self.v(h_)

         # compute attention
-        b,c,h,w = q.shape
-        q = q.reshape(b,c,h*w)
-        q = q.permute(0,2,1)     # b,hw,c
-        k = k.reshape(b,c,h*w)   # b,c,hw
-        w_ = torch.bmm(q,k)      # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h * w)
+        q = q.permute(0, 2, 1)    # b,hw,c
+        k = k.reshape(b, c, h * w)    # b,c,hw
+        w_ = torch.bmm(q, k)    # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
         w_ = w_ * (int(c)**(-0.5))
         w_ = torch.nn.functional.softmax(w_, dim=2)

         # attend to values
-        v = v.reshape(b,c,h*w)
-        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
-        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
-        h_ = h_.reshape(b,c,h,w)
+        v = v.reshape(b, c, h * w)
+        w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v, w_)    # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b, c, h, w)

         h_ = self.proj_out(h_)

-        return x+h_
+        return x + h_
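The `bmm` bookkeeping above is ordinary single-head scaled dot-product attention over the `h*w` spatial positions, exactly as the index comments claim. A quick numerical check of that claim (shapes are illustrative):

```
# illustrative equivalence check, not part of the repo
import torch

b, c, h, w = 2, 16, 8, 8
q, k, v = (torch.randn(b, c, h, w) for _ in range(3))

# the bmm formulation used in AttnBlock.forward
qf = q.reshape(b, c, h * w).permute(0, 2, 1)               # b, hw, c
kf = k.reshape(b, c, h * w)                                # b, c, hw
w_ = torch.softmax(torch.bmm(qf, kf) * c**-0.5, dim=2)
out = torch.bmm(v.reshape(b, c, h * w), w_.permute(0, 2, 1)).reshape(b, c, h, w)

# the same computation written as one einsum-style attention
attn = torch.softmax(torch.einsum('bci,bcj->bij', q.flatten(2), k.flatten(2)) * c**-0.5, dim=2)
out2 = torch.einsum('bij,bcj->bci', attn, v.flatten(2)).reshape(b, c, h, w)
print(torch.allclose(out, out2, atol=1e-5))                # True
```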


 class MemoryEfficientAttnBlock(nn.Module):
     """
@ -213,32 +177,17 @@ class MemoryEfficientAttnBlock(nn.Module):
     see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
     Note: this is a single-head self-attention operation
     """

     #
     def __init__(self, in_channels):
         super().__init__()
         self.in_channels = in_channels

         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=1,
-                                        stride=1,
-                                        padding=0)
+        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.attention_op: Optional[Any] = None

     def forward(self, x):
@ -253,27 +202,20 @@ class MemoryEfficientAttnBlock(nn.Module):
         q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(B, t.shape[1], 1, C)
-            .permute(0, 2, 1, 3)
-            .reshape(B * 1, t.shape[1], C)
-            .contiguous(),
+            lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C).
+            contiguous(),
             (q, k, v),
         )
         out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

-        out = (
-            out.unsqueeze(0)
-            .reshape(B, 1, out.shape[1], C)
-            .permute(0, 2, 1, 3)
-            .reshape(B, out.shape[1], C)
-        )
+        out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C))
         out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
         out = self.proj_out(out)
-        return x+out
+        return x + out


 class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+
     def forward(self, x, context=None, mask=None):
         b, c, h, w = x.shape
         x = rearrange(x, 'b c h w -> b (h w) c')
@ -283,10 +225,10 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):


 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
-    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear",
+                         "none"], f'attn_type {attn_type} unknown'
     if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
         return AttnBlock(in_channels)
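As an aside between these two hunks, a hedged usage sketch of the `make_attn` factory: the returned block is residual and shape-preserving, and with xformers installed a `vanilla` request is upgraded to `vanilla-xformers` automatically. The import path is an assumption:

```
# illustrative only; the import path is an assumption
import torch
from ldm.modules.diffusionmodules.model import make_attn

attn = make_attn(64, attn_type="vanilla")  # plain AttnBlock here; xformers would upgrade it
x = torch.randn(1, 64, 32, 32)
print(attn(x).shape)                       # torch.Size([1, 64, 32, 32])
```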
@ -303,13 +245,26 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):


 class Model(nn.Module):
-    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
-                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
-                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+
+    def __init__(self,
+                 *,
+                 ch,
+                 out_ch,
+                 ch_mult=(1, 2, 4, 8),
+                 num_res_blocks,
+                 attn_resolutions,
+                 dropout=0.0,
+                 resamp_with_conv=True,
+                 in_channels,
+                 resolution,
+                 use_timestep=True,
+                 use_linear_attn=False,
+                 attn_type="vanilla"):
         super().__init__()
-        if use_linear_attn: attn_type = "linear"
+        if use_linear_attn:
+            attn_type = "linear"
         self.ch = ch
-        self.temb_ch = self.ch*4
+        self.temb_ch = self.ch * 4
         self.num_resolutions = len(ch_mult)
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
@ -320,39 +275,34 @@ class Model(nn.Module):
         # timestep embedding
         self.temb = nn.Module()
         self.temb.dense = nn.ModuleList([
-            torch.nn.Linear(self.ch,
-                            self.temb_ch),
-            torch.nn.Linear(self.temb_ch,
-                            self.temb_ch),
+            torch.nn.Linear(self.ch, self.temb_ch),
+            torch.nn.Linear(self.temb_ch, self.temb_ch),
         ])

         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
-                                       self.ch,
-                                       kernel_size=3,
-                                       stride=1,
-                                       padding=1)
+        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

         curr_res = resolution
-        in_ch_mult = (1,)+tuple(ch_mult)
+        in_ch_mult = (1,) + tuple(ch_mult)
         self.down = nn.ModuleList()
         for i_level in range(self.num_resolutions):
             block = nn.ModuleList()
             attn = nn.ModuleList()
-            block_in = ch*in_ch_mult[i_level]
-            block_out = ch*ch_mult[i_level]
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
             for i_block in range(self.num_res_blocks):
-                block.append(ResnetBlock(in_channels=block_in,
-                                         out_channels=block_out,
-                                         temb_channels=self.temb_ch,
-                                         dropout=dropout))
+                block.append(
+                    ResnetBlock(in_channels=block_in,
+                                out_channels=block_out,
+                                temb_channels=self.temb_ch,
+                                dropout=dropout))
                 block_in = block_out
                 if curr_res in attn_resolutions:
                     attn.append(make_attn(block_in, attn_type=attn_type))
             down = nn.Module()
             down.block = block
             down.attn = attn
-            if i_level != self.num_resolutions-1:
+            if i_level != self.num_resolutions - 1:
                 down.downsample = Downsample(block_in, resamp_with_conv)
                 curr_res = curr_res // 2
             self.down.append(down)
@ -374,15 +324,16 @@ class Model(nn.Module):
         for i_level in reversed(range(self.num_resolutions)):
             block = nn.ModuleList()
             attn = nn.ModuleList()
-            block_out = ch*ch_mult[i_level]
-            skip_in = ch*ch_mult[i_level]
-            for i_block in range(self.num_res_blocks+1):
+            block_out = ch * ch_mult[i_level]
+            skip_in = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
                 if i_block == self.num_res_blocks:
-                    skip_in = ch*in_ch_mult[i_level]
-                block.append(ResnetBlock(in_channels=block_in+skip_in,
-                                         out_channels=block_out,
-                                         temb_channels=self.temb_ch,
-                                         dropout=dropout))
+                    skip_in = ch * in_ch_mult[i_level]
+                block.append(
+                    ResnetBlock(in_channels=block_in + skip_in,
+                                out_channels=block_out,
+                                temb_channels=self.temb_ch,
+                                dropout=dropout))
                 block_in = block_out
                 if curr_res in attn_resolutions:
                     attn.append(make_attn(block_in, attn_type=attn_type))
@ -392,15 +343,11 @@ class Model(nn.Module):
             if i_level != 0:
                 up.upsample = Upsample(block_in, resamp_with_conv)
                 curr_res = curr_res * 2
-            self.up.insert(0, up) # prepend to get consistent order
+            self.up.insert(0, up)    # prepend to get consistent order

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
-                                        out_ch,
-                                        kernel_size=3,
-                                        stride=1,
-                                        padding=1)
+        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

     def forward(self, x, t=None, context=None):
         #assert x.shape[2] == x.shape[3] == self.resolution
@ -425,7 +372,7 @@ class Model(nn.Module):
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
                 hs.append(h)
-            if i_level != self.num_resolutions-1:
+            if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
@ -436,9 +383,8 @@ class Model(nn.Module):
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks+1):
-                h = self.up[i_level].block[i_block](
-                    torch.cat([h, hs.pop()], dim=1), temb)
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                 if len(self.up[i_level].attn) > 0:
                     h = self.up[i_level].attn[i_block](h)
             if i_level != 0:
@ -455,12 +401,26 @@ class Model(nn.Module):


 class Encoder(nn.Module):
-    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
-                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
-                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+
+    def __init__(self,
+                 *,
+                 ch,
+                 out_ch,
+                 ch_mult=(1, 2, 4, 8),
+                 num_res_blocks,
+                 attn_resolutions,
+                 dropout=0.0,
+                 resamp_with_conv=True,
+                 in_channels,
+                 resolution,
+                 z_channels,
+                 double_z=True,
+                 use_linear_attn=False,
+                 attn_type="vanilla",
                  **ignore_kwargs):
         super().__init__()
-        if use_linear_attn: attn_type = "linear"
+        if use_linear_attn:
+            attn_type = "linear"
         self.ch = ch
         self.temb_ch = 0
         self.num_resolutions = len(ch_mult)
@ -469,33 +429,30 @@ class Encoder(nn.Module):
         self.in_channels = in_channels

         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
-                                       self.ch,
-                                       kernel_size=3,
-                                       stride=1,
-                                       padding=1)
+        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

         curr_res = resolution
-        in_ch_mult = (1,)+tuple(ch_mult)
+        in_ch_mult = (1,) + tuple(ch_mult)
         self.in_ch_mult = in_ch_mult
         self.down = nn.ModuleList()
         for i_level in range(self.num_resolutions):
             block = nn.ModuleList()
             attn = nn.ModuleList()
-            block_in = ch*in_ch_mult[i_level]
-            block_out = ch*ch_mult[i_level]
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
             for i_block in range(self.num_res_blocks):
-                block.append(ResnetBlock(in_channels=block_in,
-                                         out_channels=block_out,
-                                         temb_channels=self.temb_ch,
-                                         dropout=dropout))
+                block.append(
+                    ResnetBlock(in_channels=block_in,
+                                out_channels=block_out,
+                                temb_channels=self.temb_ch,
+                                dropout=dropout))
                 block_in = block_out
                 if curr_res in attn_resolutions:
                     attn.append(make_attn(block_in, attn_type=attn_type))
             down = nn.Module()
             down.block = block
             down.attn = attn
-            if i_level != self.num_resolutions-1:
+            if i_level != self.num_resolutions - 1:
                 down.downsample = Downsample(block_in, resamp_with_conv)
                 curr_res = curr_res // 2
             self.down.append(down)
@ -515,7 +472,7 @@ class Encoder(nn.Module):
         # end
         self.norm_out = Normalize(block_in)
         self.conv_out = torch.nn.Conv2d(block_in,
-                                        2*z_channels if double_z else z_channels,
+                                        2 * z_channels if double_z else z_channels,
                                         kernel_size=3,
                                         stride=1,
                                         padding=1)
@ -532,7 +489,7 @@ class Encoder(nn.Module):
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
                 hs.append(h)
-            if i_level != self.num_resolutions-1:
+            if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
@ -549,12 +506,27 @@ class Encoder(nn.Module):


 class Decoder(nn.Module):
-    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
-                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
-                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
-                 attn_type="vanilla", **ignorekwargs):
+
+    def __init__(self,
+                 *,
+                 ch,
+                 out_ch,
+                 ch_mult=(1, 2, 4, 8),
+                 num_res_blocks,
+                 attn_resolutions,
+                 dropout=0.0,
+                 resamp_with_conv=True,
+                 in_channels,
+                 resolution,
+                 z_channels,
+                 give_pre_end=False,
+                 tanh_out=False,
+                 use_linear_attn=False,
+                 attn_type="vanilla",
+                 **ignorekwargs):
         super().__init__()
-        if use_linear_attn: attn_type = "linear"
+        if use_linear_attn:
+            attn_type = "linear"
         self.ch = ch
         self.temb_ch = 0
         self.num_resolutions = len(ch_mult)
@ -565,19 +537,14 @@ class Decoder(nn.Module):
         self.tanh_out = tanh_out

         # compute in_ch_mult, block_in and curr_res at lowest res
-        in_ch_mult = (1,)+tuple(ch_mult)
-        block_in = ch*ch_mult[self.num_resolutions-1]
-        curr_res = resolution // 2**(self.num_resolutions-1)
-        self.z_shape = (1,z_channels,curr_res,curr_res)
-        rank_zero_info("Working with z of shape {} = {} dimensions.".format(
-            self.z_shape, np.prod(self.z_shape)))
+        in_ch_mult = (1,) + tuple(ch_mult)
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2**(self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+        rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

         # z to block_in
-        self.conv_in = torch.nn.Conv2d(z_channels,
-                                       block_in,
-                                       kernel_size=3,
-                                       stride=1,
-                                       padding=1)
+        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

         # middle
         self.mid = nn.Module()
@ -596,12 +563,13 @@ class Decoder(nn.Module):
         for i_level in reversed(range(self.num_resolutions)):
             block = nn.ModuleList()
             attn = nn.ModuleList()
-            block_out = ch*ch_mult[i_level]
-            for i_block in range(self.num_res_blocks+1):
-                block.append(ResnetBlock(in_channels=block_in,
-                                         out_channels=block_out,
-                                         temb_channels=self.temb_ch,
-                                         dropout=dropout))
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(
+                    ResnetBlock(in_channels=block_in,
+                                out_channels=block_out,
+                                temb_channels=self.temb_ch,
+                                dropout=dropout))
                 block_in = block_out
                 if curr_res in attn_resolutions:
                     attn.append(make_attn(block_in, attn_type=attn_type))
@ -611,15 +579,11 @@ class Decoder(nn.Module):
             if i_level != 0:
                 up.upsample = Upsample(block_in, resamp_with_conv)
                 curr_res = curr_res * 2
-            self.up.insert(0, up) # prepend to get consistent order
+            self.up.insert(0, up)    # prepend to get consistent order

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
-                                        out_ch,
-                                        kernel_size=3,
-                                        stride=1,
-                                        padding=1)
+        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

     def forward(self, z):
         #assert z.shape[1:] == self.z_shape[1:]
@ -638,7 +602,7 @@ class Decoder(nn.Module):
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks+1):
+            for i_block in range(self.num_res_blocks + 1):
                 h = self.up[i_level].block[i_block](h, temb)
                 if len(self.up[i_level].attn) > 0:
                     h = self.up[i_level].attn[i_block](h)
@ -658,31 +622,24 @@ class Decoder(nn.Module):


 class SimpleDecoder(nn.Module):
+
     def __init__(self, in_channels, out_channels, *args, **kwargs):
         super().__init__()
-        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
-                                    ResnetBlock(in_channels=in_channels,
-                                                out_channels=2 * in_channels,
-                                                temb_channels=0, dropout=0.0),
-                                    ResnetBlock(in_channels=2 * in_channels,
-                                                out_channels=4 * in_channels,
-                                                temb_channels=0, dropout=0.0),
-                                    ResnetBlock(in_channels=4 * in_channels,
-                                                out_channels=2 * in_channels,
-                                                temb_channels=0, dropout=0.0),
-                                    nn.Conv2d(2*in_channels, in_channels, 1),
-                                    Upsample(in_channels, with_conv=True)])
+        self.model = nn.ModuleList([
+            nn.Conv2d(in_channels, in_channels, 1),
+            ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
+            ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
+            ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
+            nn.Conv2d(2 * in_channels, in_channels, 1),
+            Upsample(in_channels, with_conv=True)
+        ])
         # end
         self.norm_out = Normalize(in_channels)
-        self.conv_out = torch.nn.Conv2d(in_channels,
-                                        out_channels,
-                                        kernel_size=3,
-                                        stride=1,
-                                        padding=1)
+        self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

     def forward(self, x):
         for i, layer in enumerate(self.model):
-            if i in [1,2,3]:
+            if i in [1, 2, 3]:
                 x = layer(x, None)
             else:
                 x = layer(x)
@ -694,25 +651,26 @@ class SimpleDecoder(nn.Module):


 class UpsampleDecoder(nn.Module):
-    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
-                 ch_mult=(2,2), dropout=0.0):
+
+    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
         super().__init__()
         # upsampling
         self.temb_ch = 0
         self.num_resolutions = len(ch_mult)
         self.num_res_blocks = num_res_blocks
         block_in = in_channels
-        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        curr_res = resolution // 2**(self.num_resolutions - 1)
         self.res_blocks = nn.ModuleList()
         self.upsample_blocks = nn.ModuleList()
         for i_level in range(self.num_resolutions):
             res_block = []
             block_out = ch * ch_mult[i_level]
             for i_block in range(self.num_res_blocks + 1):
-                res_block.append(ResnetBlock(in_channels=block_in,
-                                             out_channels=block_out,
-                                             temb_channels=self.temb_ch,
-                                             dropout=dropout))
+                res_block.append(
+                    ResnetBlock(in_channels=block_in,
+                                out_channels=block_out,
+                                temb_channels=self.temb_ch,
+                                dropout=dropout))
                 block_in = block_out
             self.res_blocks.append(nn.ModuleList(res_block))
             if i_level != self.num_resolutions - 1:
@ -721,11 +679,7 @@ class UpsampleDecoder(nn.Module):

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
-                                        out_channels,
-                                        kernel_size=3,
-                                        stride=1,
-                                        padding=1)
+        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

     def forward(self, x):
         # upsampling
@ -742,35 +696,35 @@ class UpsampleDecoder(nn.Module):


 class LatentRescaler(nn.Module):
+
     def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
         super().__init__()
         # residual block, interpolate, residual block
         self.factor = factor
-        self.conv_in = nn.Conv2d(in_channels,
-                                 mid_channels,
-                                 kernel_size=3,
-                                 stride=1,
-                                 padding=1)
-        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
-                                                     out_channels=mid_channels,
-                                                     temb_channels=0,
-                                                     dropout=0.0) for _ in range(depth)])
+        self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
+        self.res_block1 = nn.ModuleList([
+            ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
+            for _ in range(depth)
+        ])
         self.attn = AttnBlock(mid_channels)
-        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
-                                                     out_channels=mid_channels,
-                                                     temb_channels=0,
-                                                     dropout=0.0) for _ in range(depth)])
+        self.res_block2 = nn.ModuleList([
+            ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
+            for _ in range(depth)
+        ])

-        self.conv_out = nn.Conv2d(mid_channels,
-                                  out_channels,
-                                  kernel_size=1,
-                                  )
+        self.conv_out = nn.Conv2d(
+            mid_channels,
+            out_channels,
+            kernel_size=1,
+        )

     def forward(self, x):
         x = self.conv_in(x)
         for block in self.res_block1:
             x = block(x, None)
-        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+        x = torch.nn.functional.interpolate(x,
+                                            size=(int(round(x.shape[2] * self.factor)),
+                                                  int(round(x.shape[3] * self.factor))))
         x = self.attn(x)
         for block in self.res_block2:
             x = block(x, None)
@ -779,17 +733,37 @@ class LatentRescaler(nn.Module):


 class MergedRescaleEncoder(nn.Module):
-    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
-                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
-                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+
+    def __init__(self,
+                 in_channels,
+                 ch,
+                 resolution,
+                 out_ch,
+                 num_res_blocks,
+                 attn_resolutions,
+                 dropout=0.0,
+                 resamp_with_conv=True,
+                 ch_mult=(1, 2, 4, 8),
+                 rescale_factor=1.0,
+                 rescale_module_depth=1):
         super().__init__()
         intermediate_chn = ch * ch_mult[-1]
-        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
-                               z_channels=intermediate_chn, double_z=False, resolution=resolution,
-                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+        self.encoder = Encoder(in_channels=in_channels,
+                               num_res_blocks=num_res_blocks,
+                               ch=ch,
+                               ch_mult=ch_mult,
+                               z_channels=intermediate_chn,
+                               double_z=False,
+                               resolution=resolution,
+                               attn_resolutions=attn_resolutions,
+                               dropout=dropout,
+                               resamp_with_conv=resamp_with_conv,
                                out_ch=None)
-        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
-                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+        self.rescaler = LatentRescaler(factor=rescale_factor,
+                                       in_channels=intermediate_chn,
+                                       mid_channels=intermediate_chn,
+                                       out_channels=out_ch,
+                                       depth=rescale_module_depth)

     def forward(self, x):
         x = self.encoder(x)
@ -798,15 +772,36 @@ class MergedRescaleEncoder(nn.Module):


 class MergedRescaleDecoder(nn.Module):
-    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
-                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+
+    def __init__(self,
+                 z_channels,
+                 out_ch,
+                 resolution,
+                 num_res_blocks,
+                 attn_resolutions,
+                 ch,
+                 ch_mult=(1, 2, 4, 8),
+                 dropout=0.0,
+                 resamp_with_conv=True,
+                 rescale_factor=1.0,
+                 rescale_module_depth=1):
         super().__init__()
-        tmp_chn = z_channels*ch_mult[-1]
-        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
-                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
-                               ch_mult=ch_mult, resolution=resolution, ch=ch)
-        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
-                                       out_channels=tmp_chn, depth=rescale_module_depth)
+        tmp_chn = z_channels * ch_mult[-1]
+        self.decoder = Decoder(out_ch=out_ch,
+                               z_channels=tmp_chn,
+                               attn_resolutions=attn_resolutions,
+                               dropout=dropout,
+                               resamp_with_conv=resamp_with_conv,
+                               in_channels=None,
+                               num_res_blocks=num_res_blocks,
+                               ch_mult=ch_mult,
+                               resolution=resolution,
+                               ch=ch)
+        self.rescaler = LatentRescaler(factor=rescale_factor,
+                                       in_channels=z_channels,
+                                       mid_channels=tmp_chn,
+                                       out_channels=tmp_chn,
+                                       depth=rescale_module_depth)

     def forward(self, x):
         x = self.rescaler(x)
@ -815,16 +810,26 @@ class MergedRescaleDecoder(nn.Module):


 class Upsampler(nn.Module):
+
     def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
         super().__init__()
         assert out_size >= in_size
-        num_blocks = int(np.log2(out_size//in_size))+1
-        factor_up = 1.+ (out_size % in_size)
-        rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
-        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+        num_blocks = int(np.log2(out_size // in_size)) + 1
+        factor_up = 1. + (out_size % in_size)
+        rank_zero_info(
+            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+        )
+        self.rescaler = LatentRescaler(factor=factor_up,
+                                       in_channels=in_channels,
+                                       mid_channels=2 * in_channels,
                                        out_channels=in_channels)
-        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
-                               attn_resolutions=[], in_channels=None, ch=in_channels,
+        self.decoder = Decoder(out_ch=out_channels,
+                               resolution=out_size,
+                               z_channels=in_channels,
+                               num_res_blocks=2,
+                               attn_resolutions=[],
+                               in_channels=None,
+                               ch=in_channels,
                                ch_mult=[ch_mult for _ in range(num_blocks)])

     def forward(self, x):
@ -834,23 +839,21 @@ class Upsampler(nn.Module):


 class Resize(nn.Module):
+
     def __init__(self, in_channels=None, learned=False, mode="bilinear"):
         super().__init__()
         self.with_conv = learned
         self.mode = mode
         if self.with_conv:
-            rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+            rank_zero_info(
+                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
             raise NotImplementedError()
             assert in_channels is not None
             # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=4,
-                                        stride=2,
-                                        padding=1)
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

     def forward(self, x, scale_factor=1.0):
-        if scale_factor==1.0:
+        if scale_factor == 1.0:
             return x
         else:
             x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
@ -106,7 +106,20 @@ def get_parser(**parser_kwargs):
         nargs="?",
         help="disable test",
     )
-    parser.add_argument("-p", "--project", help="name of new or path to existing project")
+    parser.add_argument(
+        "-p",
+        "--project",
+        help="name of new or path to existing project",
+    )
+    parser.add_argument(
+        "-c",
+        "--ckpt",
+        type=str,
+        const=True,
+        default="",
+        nargs="?",
+        help="load pretrained checkpoint from stable AI",
+    )
     parser.add_argument(
         "-d",
         "--debug",
@ -145,22 +158,7 @@ def get_parser(**parser_kwargs):
         default=True,
         help="scale base-lr by ngpu * batch_size * n_accumulate",
     )
-    parser.add_argument(
-        "--use_fp16",
-        type=str2bool,
-        nargs="?",
-        const=True,
-        default=True,
-        help="whether to use fp16",
-    )
-    parser.add_argument(
-        "--flash",
-        type=str2bool,
-        const=True,
-        default=False,
-        nargs="?",
-        help="whether to use flash attention",
-    )

     return parser
@ -341,6 +339,12 @@ class SetupCallback(Callback):
         except FileNotFoundError:
             pass

+    # def on_fit_end(self, trainer, pl_module):
+    #     if trainer.global_rank == 0:
+    #         ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
+    #         rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
+    #         trainer.save_checkpoint(ckpt_path)
+

 class ImageLogger(Callback):
@ -536,6 +540,7 @@ if __name__ == "__main__":
             "If you want to resume training in a new log folder, "
             "use -n/--name in combination with --resume_from_checkpoint")
     if opt.resume:
+        rank_zero_info("Resuming from {}".format(opt.resume))
         if not os.path.exists(opt.resume):
             raise ValueError("Cannot find {}".format(opt.resume))
         if os.path.isfile(opt.resume):
@ -543,13 +548,13 @@ if __name__ == "__main__":
             # idx = len(paths)-paths[::-1].index("logs")+1
             # logdir = "/".join(paths[:idx])
             logdir = "/".join(paths[:-2])
+            rank_zero_info("logdir: {}".format(logdir))
             ckpt = opt.resume
         else:
             assert os.path.isdir(opt.resume), opt.resume
             logdir = opt.resume.rstrip("/")
             ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

         opt.resume_from_checkpoint = ckpt
         base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
         opt.base = base_configs + opt.base
         _tmp = logdir.split("/")
@ -558,6 +563,7 @@ if __name__ == "__main__":
     if opt.name:
         name = "_" + opt.name
     elif opt.base:
+        rank_zero_info("Using base config {}".format(opt.base))
         cfg_fname = os.path.split(opt.base[0])[-1]
         cfg_name = os.path.splitext(cfg_fname)[0]
         name = "_" + cfg_name
@ -566,6 +572,9 @@ if __name__ == "__main__":
     nowname = now + name + opt.postfix
     logdir = os.path.join(opt.logdir, nowname)

+    if opt.ckpt:
+        ckpt = opt.ckpt
+
     ckptdir = os.path.join(logdir, "checkpoints")
     cfgdir = os.path.join(logdir, "configs")
     seed_everything(opt.seed)
@ -582,14 +591,11 @@ if __name__ == "__main__":
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)

-    print(trainer_config)
     if not trainer_config["accelerator"] == "gpu":
         del trainer_config["accelerator"]
         cpu = True
-        print("Running on CPU")
     else:
         cpu = False
-        print("Running on GPU")
     trainer_opt = argparse.Namespace(**trainer_config)
     lightning_config.trainer = trainer_config
@ -597,10 +603,12 @@ if __name__ == "__main__":
     use_fp16 = trainer_config.get("precision", 32) == 16
     if use_fp16:
         config.model["params"].update({"use_fp16": True})
-        print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
     else:
         config.model["params"].update({"use_fp16": False})
-        print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
+
+    if ckpt is not None:
+        config.model["params"].update({"ckpt": ckpt})
+        rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))

     model = instantiate_from_config(config.model)
     # trainer and callbacks
@ -639,7 +647,6 @@ if __name__ == "__main__":
     # config the strategy, default is ddp
     if "strategy" in trainer_config:
         strategy_cfg = trainer_config["strategy"]
-        print("Using strategy: {}".format(strategy_cfg["target"]))
         strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
     else:
         strategy_cfg = {
@ -648,7 +655,6 @@ if __name__ == "__main__":
                 "find_unused_parameters": False
             }
         }
-        print("Using strategy: DDPStrategy")

     trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
@ -664,7 +670,6 @@ if __name__ == "__main__":
         }
     }
     if hasattr(model, "monitor"):
-        print(f"Monitoring {model.monitor} as checkpoint metric.")
         default_modelckpt_cfg["params"]["monitor"] = model.monitor
         default_modelckpt_cfg["params"]["save_top_k"] = 3
@ -673,7 +678,6 @@ if __name__ == "__main__":
     else:
         modelckpt_cfg = OmegaConf.create()
     modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
-    print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
     if version.parse(pl.__version__) < version.parse('1.4.0'):
         trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
@ -710,8 +714,6 @@ if __name__ == "__main__":
             "target": "main.CUDACallback"
         },
     }
-    if version.parse(pl.__version__) >= version.parse('1.4.0'):
-        default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

     if "callbacks" in lightning_config:
         callbacks_cfg = lightning_config.callbacks
@ -737,15 +739,11 @@ if __name__ == "__main__":
         default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

     callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
-    if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
-        callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
-    elif 'ignore_keys_callback' in callbacks_cfg:
-        del callbacks_cfg['ignore_keys_callback']

     trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

     trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
-    trainer.logdir = logdir ###
+    trainer.logdir = logdir

     # data
     data = instantiate_from_config(config.data)
@ -754,9 +752,9 @@ if __name__ == "__main__":
     # lightning still takes care of proper multiprocessing though
     data.prepare_data()
     data.setup()
-    print("#### Data #####")

     for k in data.datasets:
-        print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
+        rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

     # configure learning rate
     bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
@ -768,17 +766,17 @@ if __name__ == "__main__":
         accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
     else:
         accumulate_grad_batches = 1
-    print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+    rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
     lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
     if opt.scale_lr:
         model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
-        print(
+        rank_zero_info(
             "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
             .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
     else:
         model.learning_rate = base_lr
-        print("++++ NOT USING LR SCALING ++++")
-        print(f"Setting learning rate to {model.learning_rate:.2e}")
+        rank_zero_info("++++ NOT USING LR SCALING ++++")
+        rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")

     # allow checkpointing via USR1
     def melk(*args, **kwargs):
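For intuition, the scaled learning rate set above is a plain product of four factors. A worked example with made-up values:

```
# hypothetical values, purely illustrative
accumulate_grad_batches = 2
ngpu = 4
bs = 16            # per-GPU batch size from the data config
base_lr = 1.0e-6   # config.model.base_learning_rate

learning_rate = accumulate_grad_batches * ngpu * bs * base_lr  # what --scale_lr computes
print(f"{learning_rate:.2e}")  # 1.28e-04
```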
@ -1,5 +1,5 @@
-python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \
+python scripts/txt2img.py --prompt "Teyvat, Medium Female, a woman in a blue outfit holding a sword" --plms \
     --outdir ./output \
-    --ckpt /tmp/2022-11-18T16-38-46_train_colossalai/checkpoints/last.ckpt \
-    --config /tmp/2022-11-18T16-38-46_train_colossalai/configs/2022-11-18T16-38-46-project.yaml \
+    --ckpt checkpoints/last.ckpt \
+    --config configs/2023-02-02T18-06-14-project.yaml \
     --n_samples 4
@ -2,4 +2,4 @@ HF_DATASETS_OFFLINE=1
 TRANSFORMERS_OFFLINE=1
 DIFFUSERS_OFFLINE=1

-python main.py --logdir /tmp -t -b configs/train_colossalai.yaml
+python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt