Merge pull request #2561 from Fazziekey/v2

bug/fix diffusion ckpt problem
pull/2567/head
Fazzie-Maqianli 2023-02-03 15:42:49 +08:00 committed by GitHub
commit 79079a9d0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 831 additions and 658 deletions

View File

@ -53,27 +53,33 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install -e .
```
#### Step 2: install lightning
Install Lightning version later than 2022.01.04. We suggest you install lightning from source.
##### From Source
```
git clone https://github.com/Lightning-AI/lightning.git
pip install -r requirements.txt
python setup.py install
```
##### From pip
```
pip install pytorch-lightning
```
#### Step 3:Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website
##### From pip
For example, you can install v0.1.12 from our official website.
For example, you can install v0.2.0 from our official website.
```
pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```
##### From source
@ -133,10 +139,9 @@ It is important for you to configure your volume mapping in order to get the bes
3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside container, please add `-v /dev/shm:/dev/shm` to your `docker run` command.
## Download the model checkpoint from pretrained
### stable-diffusion-v2-base
### stable-diffusion-v2-base(Recommand)
```
wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt
@ -144,8 +149,6 @@ wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512
### stable-diffusion-v1-4
Our default model config use the weight from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style)
```
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
@ -153,8 +156,6 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
### stable-diffusion-v1-5 from runway
If you want to useed the Last [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weight from runwayml
```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
@ -171,11 +172,16 @@ We provide the script `train_colossalai.sh` to run the training task with coloss
and can also use `train_ddp.sh` to run the training task with ddp to compare.
In `train_colossalai.sh` the main command is:
```
python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
python main.py --logdir /tmp/ --train --base configs/train_colossalai.yaml --ckpt 512-base-ema.ckpt
```
- you can change the `--logdir` to decide where to save the log information and the last checkpoint.
- You can change the `--logdir` to decide where to save the log information and the last checkpoint.
- You will find your ckpt in `logdir/checkpoints` or `logdir/diff_tb/version_0/checkpoints`
- You will find your train config yaml in `logdir/configs`
- You can add the `--ckpt` if you want to load the pretrained model, for example `512-base-ema.ckpt`
- You can change the `--base` to specify the path of config yaml
### Training config
@ -186,7 +192,8 @@ You can change the trainging config in the yaml file
- precision: the precision type used in training, default 16 (fp16), you must use fp16 if you want to apply colossalai
- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai)
## Finetune Example (Work In Progress)
## Finetune Example
### Training on Teyvat Datasets
We provide the finetuning example on [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, which is create by BLIP generated captions.
@ -201,8 +208,8 @@ you can get yout training last.ckpt and train config.yaml in your `--logdir`, an
```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
--outdir ./output \
--config path/to/logdir/checkpoints/last.ckpt \
--ckpt /path/to/logdir/configs/project.yaml \
--ckpt path/to/logdir/checkpoints/last.ckpt \
--config /path/to/logdir/configs/project.yaml \
```
```commandline

View File

@ -6,6 +6,7 @@ model:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
ckpt: None # use ckpt path
log_every_t: 200
timesteps: 1000
first_stage_key: image
@ -16,7 +17,7 @@ model:
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,11 @@
# pytorch_diffusion + derived encoder decoder
import math
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from typing import Optional, Any
try:
from lightning.pytorch.utilities import rank_zero_info
@ -38,14 +39,14 @@ def get_timestep_embedding(timesteps, embedding_dim):
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
@ -53,15 +54,12 @@ def Normalize(in_channels, num_groups=32):
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@ -71,20 +69,17 @@ class Upsample(nn.Module):
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
@ -93,8 +88,8 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@ -102,34 +97,17 @@ class ResnetBlock(nn.Module):
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
@ -138,7 +116,7 @@ class ResnetBlock(nn.Module):
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
@ -151,35 +129,20 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
return x+h
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
@ -189,23 +152,24 @@ class AttnBlock(nn.Module):
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x+h_
return x + h_
class MemoryEfficientAttnBlock(nn.Module):
"""
@ -213,32 +177,17 @@ class MemoryEfficientAttnBlock(nn.Module):
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
@ -253,27 +202,20 @@ class MemoryEfficientAttnBlock(nn.Module):
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C).
contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C))
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
out = self.proj_out(out)
return x+out
return x + out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
@ -283,10 +225,10 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear",
"none"], f'attn_type {attn_type} unknown'
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
attn_type = "vanilla-xformers"
rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
@ -303,13 +245,26 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
class Model(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla"):
super().__init__()
if use_linear_attn: attn_type = "linear"
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch*4
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
@ -320,39 +275,34 @@ class Model(nn.Module):
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch,
self.temb_ch),
torch.nn.Linear(self.temb_ch,
self.temb_ch),
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
@ -374,15 +324,16 @@ class Model(nn.Module):
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
skip_in = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch*in_ch_mult[i_level]
block.append(ResnetBlock(in_channels=block_in+skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
@ -392,15 +343,11 @@ class Model(nn.Module):
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, x, t=None, context=None):
#assert x.shape[2] == x.shape[3] == self.resolution
@ -425,7 +372,7 @@ class Model(nn.Module):
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
@ -436,9 +383,8 @@ class Model(nn.Module):
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb)
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
@ -455,12 +401,26 @@ class Model(nn.Module):
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
@ -469,33 +429,30 @@ class Encoder(nn.Module):
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
@ -515,7 +472,7 @@ class Encoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2*z_channels if double_z else z_channels,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
@ -532,7 +489,7 @@ class Encoder(nn.Module):
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
@ -549,12 +506,27 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
@ -565,19 +537,14 @@ class Decoder(nn.Module):
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
rank_zero_info("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
@ -596,12 +563,13 @@ class Decoder(nn.Module):
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
@ -611,15 +579,11 @@ class Decoder(nn.Module):
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
@ -638,7 +602,7 @@ class Decoder(nn.Module):
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
@ -658,31 +622,24 @@ class Decoder(nn.Module):
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
nn.Conv2d(2*in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)])
self.model = nn.ModuleList([
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)
])
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1,2,3]:
if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
@ -694,25 +651,26 @@ class SimpleDecoder(nn.Module):
class UpsampleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
ch_mult=(2,2), dropout=0.0):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
curr_res = resolution // 2**(self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
res_block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
@ -721,11 +679,7 @@ class UpsampleDecoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
# upsampling
@ -742,35 +696,35 @@ class UpsampleDecoder(nn.Module):
class LatentRescaler(nn.Module):
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(in_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1)
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)])
self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
self.res_block1 = nn.ModuleList([
ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
for _ in range(depth)
])
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)])
self.res_block2 = nn.ModuleList([
ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
for _ in range(depth)
])
self.conv_out = nn.Conv2d(mid_channels,
out_channels,
kernel_size=1,
)
self.conv_out = nn.Conv2d(
mid_channels,
out_channels,
kernel_size=1,
)
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
x = torch.nn.functional.interpolate(x,
size=(int(round(x.shape[2] * self.factor)),
int(round(x.shape[3] * self.factor))))
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
@ -779,17 +733,37 @@ class LatentRescaler(nn.Module):
class MergedRescaleEncoder(nn.Module):
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True,
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
def __init__(self,
in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
z_channels=intermediate_chn, double_z=False, resolution=resolution,
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
self.encoder = Encoder(in_channels=in_channels,
num_res_blocks=num_res_blocks,
ch=ch,
ch_mult=ch_mult,
z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None)
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
self.rescaler = LatentRescaler(factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth)
def forward(self, x):
x = self.encoder(x)
@ -798,15 +772,36 @@ class MergedRescaleEncoder(nn.Module):
class MergedRescaleDecoder(nn.Module):
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
def __init__(self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1):
super().__init__()
tmp_chn = z_channels*ch_mult[-1]
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
ch_mult=ch_mult, resolution=resolution, ch=ch)
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
out_channels=tmp_chn, depth=rescale_module_depth)
tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(out_ch=out_ch,
z_channels=tmp_chn,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch)
self.rescaler = LatentRescaler(factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth)
def forward(self, x):
x = self.rescaler(x)
@ -815,16 +810,26 @@ class MergedRescaleDecoder(nn.Module):
class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size//in_size))+1
factor_up = 1.+ (out_size % in_size)
rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1. + (out_size % in_size)
rank_zero_info(
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
)
self.rescaler = LatentRescaler(factor=factor_up,
in_channels=in_channels,
mid_channels=2 * in_channels,
out_channels=in_channels)
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
attn_resolutions=[], in_channels=None, ch=in_channels,
self.decoder = Decoder(out_ch=out_channels,
resolution=out_size,
z_channels=in_channels,
num_res_blocks=2,
attn_resolutions=[],
in_channels=None,
ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)])
def forward(self, x):
@ -834,23 +839,21 @@ class Upsampler(nn.Module):
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
rank_zero_info(
f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=4,
stride=2,
padding=1)
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x, scale_factor=1.0):
if scale_factor==1.0:
if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)

View File

@ -106,7 +106,20 @@ def get_parser(**parser_kwargs):
nargs="?",
help="disable test",
)
parser.add_argument("-p", "--project", help="name of new or path to existing project")
parser.add_argument(
"-p",
"--project",
help="name of new or path to existing project",
)
parser.add_argument(
"-c",
"--ckpt",
type=str,
const=True,
default="",
nargs="?",
help="load pretrained checkpoint from stable AI",
)
parser.add_argument(
"-d",
"--debug",
@ -145,22 +158,7 @@ def get_parser(**parser_kwargs):
default=True,
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
parser.add_argument(
"--use_fp16",
type=str2bool,
nargs="?",
const=True,
default=True,
help="whether to use fp16",
)
parser.add_argument(
"--flash",
type=str2bool,
const=True,
default=False,
nargs="?",
help="whether to use flash attention",
)
return parser
@ -341,6 +339,12 @@ class SetupCallback(Callback):
except FileNotFoundError:
pass
# def on_fit_end(self, trainer, pl_module):
# if trainer.global_rank == 0:
# ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
# rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
# trainer.save_checkpoint(ckpt_path)
class ImageLogger(Callback):
@ -536,6 +540,7 @@ if __name__ == "__main__":
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint")
if opt.resume:
rank_zero_info("Resuming from {}".format(opt.resume))
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
@ -543,13 +548,13 @@ if __name__ == "__main__":
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = "/".join(paths[:-2])
rank_zero_info("logdir: {}".format(logdir))
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split("/")
@ -558,6 +563,7 @@ if __name__ == "__main__":
if opt.name:
name = "_" + opt.name
elif opt.base:
rank_zero_info("Using base config {}".format(opt.base))
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = "_" + cfg_name
@ -566,6 +572,9 @@ if __name__ == "__main__":
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
if opt.ckpt:
ckpt = opt.ckpt
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
@ -582,14 +591,11 @@ if __name__ == "__main__":
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
print(trainer_config)
if not trainer_config["accelerator"] == "gpu":
del trainer_config["accelerator"]
cpu = True
print("Running on CPU")
else:
cpu = False
print("Running on GPU")
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
@ -597,10 +603,12 @@ if __name__ == "__main__":
use_fp16 = trainer_config.get("precision", 32) == 16
if use_fp16:
config.model["params"].update({"use_fp16": True})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
else:
config.model["params"].update({"use_fp16": False})
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
if ckpt is not None:
config.model["params"].update({"ckpt": ckpt})
rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
model = instantiate_from_config(config.model)
# trainer and callbacks
@ -639,7 +647,6 @@ if __name__ == "__main__":
# config the strategy, defualt is ddp
if "strategy" in trainer_config:
strategy_cfg = trainer_config["strategy"]
print("Using strategy: {}".format(strategy_cfg["target"]))
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
else:
strategy_cfg = {
@ -648,7 +655,6 @@ if __name__ == "__main__":
"find_unused_parameters": False
}
}
print("Using strategy: DDPStrategy")
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
@ -664,7 +670,6 @@ if __name__ == "__main__":
}
}
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
@ -673,7 +678,6 @@ if __name__ == "__main__":
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
@ -710,8 +714,6 @@ if __name__ == "__main__":
"target": "main.CUDACallback"
},
}
if version.parse(pl.__version__) >= version.parse('1.4.0'):
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
@ -737,15 +739,11 @@ if __name__ == "__main__":
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
elif 'ignore_keys_callback' in callbacks_cfg:
del callbacks_cfg['ignore_keys_callback']
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###
trainer.logdir = logdir
# data
data = instantiate_from_config(config.data)
@ -754,9 +752,9 @@ if __name__ == "__main__":
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
print("#### Data #####")
for k in data.datasets:
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
# configure learning rate
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
@ -768,17 +766,17 @@ if __name__ == "__main__":
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
rank_zero_info(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
.format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
else:
model.learning_rate = base_lr
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
rank_zero_info("++++ NOT USING LR SCALING ++++")
rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1
def melk(*args, **kwargs):

View File

@ -1,5 +1,5 @@
python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \
python scripts/txt2img.py --prompt "Teyvat, Medium Female, a woman in a blue outfit holding a sword" --plms \
--outdir ./output \
--ckpt /tmp/2022-11-18T16-38-46_train_colossalai/checkpoints/last.ckpt \
--config /tmp/2022-11-18T16-38-46_train_colossalai/configs/2022-11-18T16-38-46-project.yaml \
--ckpt checkpoints/last.ckpt \
--config configs/2023-02-02T18-06-14-project.yaml \
--n_samples 4

View File

@ -2,4 +2,4 @@ HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1
python main.py --logdir /tmp -t -b configs/train_colossalai.yaml
python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt