mirror of https://github.com/hpcaitech/ColossalAI
commit 79079a9d0c

@ -53,27 +53,33 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la

```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install -e .
```

#### Step 2: Install Lightning

Install a Lightning version newer than 2022.01.04. We suggest installing Lightning from source.

##### From Source

```
git clone https://github.com/Lightning-AI/lightning.git
cd lightning
pip install -r requirements.txt
python setup.py install
```

##### From pip

```
pip install pytorch-lightning
```
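
A quick sanity check that the pip installation is importable:

```
python -c "import pytorch_lightning; print(pytorch_lightning.__version__)"
```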

#### Step 3: Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website

##### From pip

For example, you can install v0.1.12 from our official website.
For example, you can install v0.2.0 from our official website.

```
pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```

##### From source

@ -133,10 +139,9 @@ It is important for you to configure your volume mapping in order to get the bes

3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside the container, please add `-v /dev/shm:/dev/shm` to your `docker run` command.
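
If you do hit that error, a minimal sketch of such a `docker run` command (the image name and the host data path below are placeholders for your own setup):

```
docker run -it --gpus all \
    -v /path/to/your/data:/data \
    -v /dev/shm:/dev/shm \
    your-stable-diffusion-image /bin/bash
```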
## Download the model checkpoint from pretrained

### stable-diffusion-v2-base
### stable-diffusion-v2-base (Recommended)

```
wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt

@ -144,8 +149,6 @@ wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512

### stable-diffusion-v1-4

Our default model config uses the weights from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style)

```
git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
@ -153,8 +156,6 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4

### stable-diffusion-v1-5 from runway

If you want to use the latest [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) weights from runwayml:

```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
@ -171,11 +172,16 @@ We provide the script `train_colossalai.sh` to run the training task with coloss

and you can also use `train_ddp.sh` to run the training task with DDP for comparison.

In `train_colossalai.sh` the main command is:

```
python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
python main.py --logdir /tmp/ --train --base configs/train_colossalai.yaml --ckpt 512-base-ema.ckpt
```

- you can change the `--logdir` to decide where to save the log information and the last checkpoint.
- You can change `--logdir` to decide where to save the log information and the last checkpoint.
- You will find your ckpt in `logdir/checkpoints` or `logdir/diff_tb/version_0/checkpoints`
- You will find your train config yaml in `logdir/configs`
- You can add `--ckpt` if you want to load a pretrained model, for example `512-base-ema.ckpt`
- You can change `--base` to specify the path of the config yaml

### Training config

@ -186,7 +192,8 @@ You can change the training config in the yaml file

- precision: the precision type used in training, default 16 (fp16); you must use fp16 if you want to apply ColossalAI
- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai)

## Finetune Example (Work In Progress)
## Finetune Example

### Training on Teyvat Datasets

We provide a finetuning example on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset, whose captions were generated with BLIP.
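
The fine-tuning run is launched the same way as the base training; a sketch mirroring the Teyvat training script at the end of this diff (adjust the paths to your setup):

```
python main.py --logdir /tmp --train \
    --base configs/Teyvat/train_colossalai_teyvat.yaml \
    --ckpt diffuser_root_dir/512-base-ema.ckpt
```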
@ -201,8 +208,8 @@ you can get your training last.ckpt and train config.yaml in your `--logdir`, an

```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
--outdir ./output \
--config path/to/logdir/checkpoints/last.ckpt \
--ckpt /path/to/logdir/configs/project.yaml \
--ckpt path/to/logdir/checkpoints/last.ckpt \
--config /path/to/logdir/configs/project.yaml \
```
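
To locate these files quickly (paths here are illustrative; use your own `--logdir`):

```
ls path/to/logdir/checkpoints   # contains last.ckpt
ls path/to/logdir/configs       # contains the saved project.yaml
```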

```commandline

@ -6,6 +6,7 @@ model:

linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
ckpt: None # use ckpt path
log_every_t: 200
timesteps: 1000
first_stage_key: image

@ -16,7 +17,7 @@ model:

conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
use_ema: False

scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
|
|
File diff suppressed because it is too large

@ -1,10 +1,11 @@
|
|||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from typing import Optional, Any
|
||||
|
||||
try:
|
||||
from lightning.pytorch.utilities import rank_zero_info
|
||||
|
@ -53,15 +54,12 @@ def Normalize(in_channels, num_groups=32):
|
|||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
@ -71,16 +69,13 @@ class Upsample(nn.Module):
|
|||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0)
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
|
@ -93,8 +88,8 @@ class Downsample(nn.Module):
|
|||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512):
|
||||
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
|
@ -102,34 +97,17 @@ class ResnetBlock(nn.Module):
|
|||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
|
@ -155,31 +133,16 @@ class ResnetBlock(nn.Module):
|
|||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
|
@ -207,38 +170,24 @@ class AttnBlock(nn.Module):
|
|||
|
||||
return x + h_
|
||||
|
||||
|
||||
class MemoryEfficientAttnBlock(nn.Module):
|
||||
"""
|
||||
Uses xformers efficient implementation,
|
||||
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
Note: this is a single-head self-attention operation
|
||||
"""
|
||||
|
||||
#
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -253,27 +202,20 @@ class MemoryEfficientAttnBlock(nn.Module):
|
|||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(B, t.shape[1], 1, C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * 1, t.shape[1], C)
|
||||
.contiguous(),
|
||||
lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C).
|
||||
contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(B, 1, out.shape[1], C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B, out.shape[1], C)
|
||||
)
|
||||
out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C))
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
||||
|
||||
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
b, c, h, w = x.shape
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
|
@ -283,10 +225,10 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
|||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear",
|
||||
"none"], f'attn_type {attn_type} unknown'
|
||||
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
|
||||
attn_type = "vanilla-xformers"
|
||||
rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
|
@ -303,11 +245,24 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
|||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
use_timestep=True,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch * 4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
|
@ -320,18 +275,12 @@ class Model(nn.Module):
|
|||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
torch.nn.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch,
|
||||
self.temb_ch),
|
||||
torch.nn.Linear(self.ch, self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
|
@ -342,7 +291,8 @@ class Model(nn.Module):
|
|||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
block.append(
|
||||
ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
|
@ -379,7 +329,8 @@ class Model(nn.Module):
|
|||
for i_block in range(self.num_res_blocks + 1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch * in_ch_mult[i_level]
|
||||
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
||||
block.append(
|
||||
ResnetBlock(in_channels=block_in + skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
|
@ -396,11 +347,7 @@ class Model(nn.Module):
|
|||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x, t=None, context=None):
|
||||
#assert x.shape[2] == x.shape[3] == self.resolution
|
||||
|
@ -437,8 +384,7 @@ class Model(nn.Module):
|
|||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
torch.cat([h, hs.pop()], dim=1), temb)
|
||||
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
|
@ -455,12 +401,26 @@ class Model(nn.Module):
|
|||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
|
@ -469,11 +429,7 @@ class Encoder(nn.Module):
|
|||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
|
@ -485,7 +441,8 @@ class Encoder(nn.Module):
|
|||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
block.append(
|
||||
ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
|
@ -549,12 +506,27 @@ class Encoder(nn.Module):
|
|||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
attn_type="vanilla", **ignorekwargs):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
|
@ -569,15 +541,10 @@ class Decoder(nn.Module):
|
|||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
rank_zero_info("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
|
@ -598,7 +565,8 @@ class Decoder(nn.Module):
|
|||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
block.append(
|
||||
ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
|
@ -615,11 +583,7 @@ class Decoder(nn.Module):
|
|||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
#assert z.shape[1:] == self.z_shape[1:]
|
||||
|
@ -658,27 +622,20 @@ class Decoder(nn.Module):
|
|||
|
||||
|
||||
class SimpleDecoder(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
||||
ResnetBlock(in_channels=in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=2 * in_channels,
|
||||
out_channels=4 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=4 * in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
self.model = nn.ModuleList([
|
||||
nn.Conv2d(in_channels, in_channels, 1),
|
||||
ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
|
||||
nn.Conv2d(2 * in_channels, in_channels, 1),
|
||||
Upsample(in_channels, with_conv=True)])
|
||||
Upsample(in_channels, with_conv=True)
|
||||
])
|
||||
# end
|
||||
self.norm_out = Normalize(in_channels)
|
||||
self.conv_out = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.model):
|
||||
|
@ -694,8 +651,8 @@ class SimpleDecoder(nn.Module):
|
|||
|
||||
|
||||
class UpsampleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
||||
ch_mult=(2,2), dropout=0.0):
|
||||
|
||||
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
|
||||
super().__init__()
|
||||
# upsampling
|
||||
self.temb_ch = 0
|
||||
|
@ -709,7 +666,8 @@ class UpsampleDecoder(nn.Module):
|
|||
res_block = []
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
res_block.append(ResnetBlock(in_channels=block_in,
|
||||
res_block.append(
|
||||
ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
|
@ -721,11 +679,7 @@ class UpsampleDecoder(nn.Module):
|
|||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# upsampling
|
||||
|
@ -742,26 +696,24 @@ class UpsampleDecoder(nn.Module):
|
|||
|
||||
|
||||
class LatentRescaler(nn.Module):
|
||||
|
||||
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
||||
super().__init__()
|
||||
# residual block, interpolate, residual block
|
||||
self.factor = factor
|
||||
self.conv_in = nn.Conv2d(in_channels,
|
||||
mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)])
|
||||
self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.res_block1 = nn.ModuleList([
|
||||
ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
|
||||
for _ in range(depth)
|
||||
])
|
||||
self.attn = AttnBlock(mid_channels)
|
||||
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)])
|
||||
self.res_block2 = nn.ModuleList([
|
||||
ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
|
||||
for _ in range(depth)
|
||||
])
|
||||
|
||||
self.conv_out = nn.Conv2d(mid_channels,
|
||||
self.conv_out = nn.Conv2d(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
@ -770,7 +722,9 @@ class LatentRescaler(nn.Module):
|
|||
x = self.conv_in(x)
|
||||
for block in self.res_block1:
|
||||
x = block(x, None)
|
||||
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
|
||||
x = torch.nn.functional.interpolate(x,
|
||||
size=(int(round(x.shape[2] * self.factor)),
|
||||
int(round(x.shape[3] * self.factor))))
|
||||
x = self.attn(x)
|
||||
for block in self.res_block2:
|
||||
x = block(x, None)
|
||||
|
@ -779,17 +733,37 @@ class LatentRescaler(nn.Module):
|
|||
|
||||
|
||||
class MergedRescaleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
||||
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
ch,
|
||||
resolution,
|
||||
out_ch,
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
rescale_factor=1.0,
|
||||
rescale_module_depth=1):
|
||||
super().__init__()
|
||||
intermediate_chn = ch * ch_mult[-1]
|
||||
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
|
||||
z_channels=intermediate_chn, double_z=False, resolution=resolution,
|
||||
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
|
||||
self.encoder = Encoder(in_channels=in_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
z_channels=intermediate_chn,
|
||||
double_z=False,
|
||||
resolution=resolution,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
out_ch=None)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
|
||||
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor,
|
||||
in_channels=intermediate_chn,
|
||||
mid_channels=intermediate_chn,
|
||||
out_channels=out_ch,
|
||||
depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
|
@ -798,15 +772,36 @@ class MergedRescaleEncoder(nn.Module):
|
|||
|
||||
|
||||
class MergedRescaleDecoder(nn.Module):
|
||||
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
|
||||
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
|
||||
|
||||
def __init__(self,
|
||||
z_channels,
|
||||
out_ch,
|
||||
resolution,
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
rescale_factor=1.0,
|
||||
rescale_module_depth=1):
|
||||
super().__init__()
|
||||
tmp_chn = z_channels * ch_mult[-1]
|
||||
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
|
||||
ch_mult=ch_mult, resolution=resolution, ch=ch)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
|
||||
out_channels=tmp_chn, depth=rescale_module_depth)
|
||||
self.decoder = Decoder(out_ch=out_ch,
|
||||
z_channels=tmp_chn,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
in_channels=None,
|
||||
num_res_blocks=num_res_blocks,
|
||||
ch_mult=ch_mult,
|
||||
resolution=resolution,
|
||||
ch=ch)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor,
|
||||
in_channels=z_channels,
|
||||
mid_channels=tmp_chn,
|
||||
out_channels=tmp_chn,
|
||||
depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
|
@ -815,16 +810,26 @@ class MergedRescaleDecoder(nn.Module):
|
|||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
|
||||
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
||||
super().__init__()
|
||||
assert out_size >= in_size
|
||||
num_blocks = int(np.log2(out_size // in_size)) + 1
|
||||
factor_up = 1. + (out_size % in_size)
|
||||
rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
|
||||
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
|
||||
rank_zero_info(
|
||||
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
|
||||
)
|
||||
self.rescaler = LatentRescaler(factor=factor_up,
|
||||
in_channels=in_channels,
|
||||
mid_channels=2 * in_channels,
|
||||
out_channels=in_channels)
|
||||
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
|
||||
attn_resolutions=[], in_channels=None, ch=in_channels,
|
||||
self.decoder = Decoder(out_ch=out_channels,
|
||||
resolution=out_size,
|
||||
z_channels=in_channels,
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=[],
|
||||
in_channels=None,
|
||||
ch=in_channels,
|
||||
ch_mult=[ch_mult for _ in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -834,20 +839,18 @@ class Upsampler(nn.Module):
|
|||
|
||||
|
||||
class Resize(nn.Module):
|
||||
|
||||
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
||||
super().__init__()
|
||||
self.with_conv = learned
|
||||
self.mode = mode
|
||||
if self.with_conv:
|
||||
rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
rank_zero_info(
f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
|
||||
raise NotImplementedError()
|
||||
assert in_channels is not None
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1)
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)
|
||||
|
||||
def forward(self, x, scale_factor=1.0):
|
||||
if scale_factor == 1.0:
|
||||
|
|
|
@ -106,7 +106,20 @@ def get_parser(**parser_kwargs):
|
|||
nargs="?",
|
||||
help="disable test",
|
||||
)
|
||||
parser.add_argument("-p", "--project", help="name of new or path to existing project")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--project",
|
||||
help="name of new or path to existing project",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--ckpt",
|
||||
type=str,
|
||||
const=True,
|
||||
default="",
|
||||
nargs="?",
|
||||
help="load pretrained checkpoint from stable AI",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--debug",
|
||||
|
@ -145,22 +158,7 @@ def get_parser(**parser_kwargs):
|
|||
default=True,
|
||||
help="scale base-lr by ngpu * batch_size * n_accumulate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fp16",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=True,
|
||||
help="whether to use fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash",
|
||||
type=str2bool,
|
||||
const=True,
|
||||
default=False,
|
||||
nargs="?",
|
||||
help="whether to use flash attention",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -341,6 +339,12 @@ class SetupCallback(Callback):
|
|||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# def on_fit_end(self, trainer, pl_module):
|
||||
# if trainer.global_rank == 0:
|
||||
# ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
||||
# rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
|
||||
# trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
|
||||
class ImageLogger(Callback):
|
||||
|
||||
|
@ -536,6 +540,7 @@ if __name__ == "__main__":
|
|||
"If you want to resume training in a new log folder, "
|
||||
"use -n/--name in combination with --resume_from_checkpoint")
|
||||
if opt.resume:
|
||||
rank_zero_info("Resuming from {}".format(opt.resume))
|
||||
if not os.path.exists(opt.resume):
|
||||
raise ValueError("Cannot find {}".format(opt.resume))
|
||||
if os.path.isfile(opt.resume):
|
||||
|
@ -543,13 +548,13 @@ if __name__ == "__main__":
|
|||
# idx = len(paths)-paths[::-1].index("logs")+1
|
||||
# logdir = "/".join(paths[:idx])
|
||||
logdir = "/".join(paths[:-2])
|
||||
rank_zero_info("logdir: {}".format(logdir))
|
||||
ckpt = opt.resume
|
||||
else:
|
||||
assert os.path.isdir(opt.resume), opt.resume
|
||||
logdir = opt.resume.rstrip("/")
|
||||
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
||||
|
||||
opt.resume_from_checkpoint = ckpt
|
||||
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
||||
opt.base = base_configs + opt.base
|
||||
_tmp = logdir.split("/")
|
||||
|
@ -558,6 +563,7 @@ if __name__ == "__main__":
|
|||
if opt.name:
|
||||
name = "_" + opt.name
|
||||
elif opt.base:
|
||||
rank_zero_info("Using base config {}".format(opt.base))
|
||||
cfg_fname = os.path.split(opt.base[0])[-1]
|
||||
cfg_name = os.path.splitext(cfg_fname)[0]
|
||||
name = "_" + cfg_name
|
||||
|
@ -566,6 +572,9 @@ if __name__ == "__main__":
|
|||
nowname = now + name + opt.postfix
|
||||
logdir = os.path.join(opt.logdir, nowname)
|
||||
|
||||
if opt.ckpt:
|
||||
ckpt = opt.ckpt
|
||||
|
||||
ckptdir = os.path.join(logdir, "checkpoints")
|
||||
cfgdir = os.path.join(logdir, "configs")
|
||||
seed_everything(opt.seed)
|
||||
|
@ -582,14 +591,11 @@ if __name__ == "__main__":
|
|||
for k in nondefault_trainer_args(opt):
|
||||
trainer_config[k] = getattr(opt, k)
|
||||
|
||||
print(trainer_config)
|
||||
if not trainer_config["accelerator"] == "gpu":
|
||||
del trainer_config["accelerator"]
|
||||
cpu = True
|
||||
print("Running on CPU")
|
||||
else:
|
||||
cpu = False
|
||||
print("Running on GPU")
|
||||
trainer_opt = argparse.Namespace(**trainer_config)
|
||||
lightning_config.trainer = trainer_config
|
||||
|
||||
|
@ -597,10 +603,12 @@ if __name__ == "__main__":
|
|||
use_fp16 = trainer_config.get("precision", 32) == 16
|
||||
if use_fp16:
|
||||
config.model["params"].update({"use_fp16": True})
|
||||
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
|
||||
else:
|
||||
config.model["params"].update({"use_fp16": False})
|
||||
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
|
||||
|
||||
if ckpt is not None:
|
||||
config.model["params"].update({"ckpt": ckpt})
|
||||
rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
|
||||
|
||||
model = instantiate_from_config(config.model)
|
||||
# trainer and callbacks
|
||||
|
@ -639,7 +647,6 @@ if __name__ == "__main__":
|
|||
# configure the strategy, default is ddp
|
||||
if "strategy" in trainer_config:
|
||||
strategy_cfg = trainer_config["strategy"]
|
||||
print("Using strategy: {}".format(strategy_cfg["target"]))
|
||||
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
|
||||
else:
|
||||
strategy_cfg = {
|
||||
|
@ -648,7 +655,6 @@ if __name__ == "__main__":
|
|||
"find_unused_parameters": False
|
||||
}
|
||||
}
|
||||
print("Using strategy: DDPStrategy")
|
||||
|
||||
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
|
||||
|
||||
|
@ -664,7 +670,6 @@ if __name__ == "__main__":
|
|||
}
|
||||
}
|
||||
if hasattr(model, "monitor"):
|
||||
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
||||
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
||||
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
||||
|
||||
|
@ -673,7 +678,6 @@ if __name__ == "__main__":
|
|||
else:
|
||||
modelckpt_cfg = OmegaConf.create()
|
||||
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
||||
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
|
||||
if version.parse(pl.__version__) < version.parse('1.4.0'):
|
||||
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
||||
|
||||
|
@ -710,8 +714,6 @@ if __name__ == "__main__":
|
|||
"target": "main.CUDACallback"
|
||||
},
|
||||
}
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
|
||||
|
||||
if "callbacks" in lightning_config:
|
||||
callbacks_cfg = lightning_config.callbacks
|
||||
|
@ -737,15 +739,11 @@ if __name__ == "__main__":
|
|||
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
|
||||
|
||||
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
||||
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
|
||||
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
|
||||
elif 'ignore_keys_callback' in callbacks_cfg:
|
||||
del callbacks_cfg['ignore_keys_callback']
|
||||
|
||||
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
||||
|
||||
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
||||
trainer.logdir = logdir ###
|
||||
trainer.logdir = logdir
|
||||
|
||||
# data
|
||||
data = instantiate_from_config(config.data)
|
||||
|
@ -754,9 +752,9 @@ if __name__ == "__main__":
|
|||
# lightning still takes care of proper multiprocessing though
|
||||
data.prepare_data()
|
||||
data.setup()
|
||||
print("#### Data #####")
|
||||
|
||||
for k in data.datasets:
|
||||
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
|
||||
rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
|
||||
|
||||
# configure learning rate
|
||||
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
||||
|
@ -768,17 +766,17 @@ if __name__ == "__main__":
|
|||
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
||||
else:
|
||||
accumulate_grad_batches = 1
|
||||
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
||||
rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
||||
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
||||
if opt.scale_lr:
|
||||
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
||||
print(
|
||||
rank_zero_info(
|
||||
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
|
||||
.format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
|
||||
else:
|
||||
model.learning_rate = base_lr
|
||||
print("++++ NOT USING LR SCALING ++++")
|
||||
print(f"Setting learning rate to {model.learning_rate:.2e}")
|
||||
rank_zero_info("++++ NOT USING LR SCALING ++++")
|
||||
rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
|
||||
|
||||
# allow checkpointing via USR1
|
||||
def melk(*args, **kwargs):
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \
|
||||
python scripts/txt2img.py --prompt "Teyvat, Medium Female, a woman in a blue outfit holding a sword" --plms \
|
||||
--outdir ./output \
|
||||
--ckpt /tmp/2022-11-18T16-38-46_train_colossalai/checkpoints/last.ckpt \
|
||||
--config /tmp/2022-11-18T16-38-46_train_colossalai/configs/2022-11-18T16-38-46-project.yaml \
|
||||
--ckpt checkpoints/last.ckpt \
|
||||
--config configs/2023-02-02T18-06-14-project.yaml \
|
||||
--n_samples 4
|
||||
|
|
|
@ -2,4 +2,4 @@ HF_DATASETS_OFFLINE=1
|
|||
TRANSFORMERS_OFFLINE=1
|
||||
DIFFUSERS_OFFLINE=1
|
||||
|
||||
python main.py --logdir /tmp -t -b configs/train_colossalai.yaml
|
||||
python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt
|
||||
|
|