[example] add diffusion inference (#1986)

pull/1999/head
Fazzie-Maqianli 2022-11-20 18:35:29 +08:00 committed by GitHub
parent a01278e810
commit b5dbb46172
7 changed files with 343 additions and 45 deletions

View File

@ -96,9 +96,53 @@ We provide the finetuning example on CIFAR10 dataset
You can run it with the config `train_colossalai_cifar10.yaml`
```
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
```
## Inference
You can find your trained `last.ckpt` and the training `config.yaml` under your `--logdir`, and run inference with
```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms \
    --outdir ./output \
    --config /path/to/logdir/configs/project.yaml \
    --ckpt /path/to/logdir/checkpoints/last.ckpt
```
```commandline
usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA]
[--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT]
[--seed SEED] [--precision {full,autocast}]
optional arguments:
-h, --help show this help message and exit
--prompt [PROMPT] the prompt to render
--outdir [OUTDIR] dir to write results to
--skip_grid do not save a grid, only individual samples. Helpful when evaluating lots of samples
--skip_save do not save individual samples. For speed measurements.
--ddim_steps DDIM_STEPS
number of ddim sampling steps
--plms use plms sampling
--laion400m uses the LAION400M model
--fixed_code if enabled, uses the same starting code across samples
--ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling
--n_iter N_ITER sample this often
--H H image height, in pixel space
--W W image width, in pixel space
--C C latent channels
--f F downsampling factor
--n_samples N_SAMPLES
how many samples to produce for each given prompt. A.k.a. batch size
--n_rows N_ROWS rows in the grid (default: n_samples)
--scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
--from-file FROM_FILE
if specified, load prompts from this file
--config CONFIG path to config which constructs model
--ckpt CKPT path to checkpoint of model
--seed SEED the seed (for reproducible sampling)
--precision {full,autocast}
evaluate at this precision
```
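Under the hood, `txt2img.py` loads the model from the config, restores the checkpoint, and samples with classifier-free guidance. A rough sketch of that pipeline, assuming the CompVis-style `ldm` API shipped with this example (`PLMSSampler`, `get_learned_conditioning`, `decode_first_stage`) and placeholder paths for the config and checkpoint, might look like:

```python
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.plms import PLMSSampler

# Rough sketch of what txt2img.py does; the config/checkpoint paths are placeholders.
config = OmegaConf.load("/path/to/logdir/configs/project.yaml")
model = instantiate_from_config(config.model)
state = torch.load("/path/to/logdir/checkpoints/last.ckpt", map_location="cpu")
model.load_state_dict(state["state_dict"], strict=False)
model = model.cuda().eval()

sampler = PLMSSampler(model)
prompt = "a photograph of an astronaut riding a horse"
with torch.no_grad(), torch.autocast("cuda"):
    uc = model.get_learned_conditioning([""])        # unconditional (empty prompt) embedding
    c = model.get_learned_conditioning([prompt])     # prompt embedding
    samples, _ = sampler.sample(S=50, conditioning=c, batch_size=1,
                                shape=[4, 64, 64],   # C, H//f, W//f for a 512x512 image
                                verbose=False,
                                unconditional_guidance_scale=7.5,
                                unconditional_conditioning=uc, eta=0.0)
    images = model.decode_first_stage(samples)       # (1, 3, 512, 512), values roughly in [-1, 1]
```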
## Comments

View File

@ -0,0 +1,122 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: txt
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1.e-4 ]
f_min: [ 1.e-10 ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
use_fp16: True
data:
target: main.DataModuleFromConfig
params:
batch_size: 16
num_workers: 4
train:
target: ldm.data.teyvat.hf_dataset
params:
path: Fazzie/Teyvat
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
# - target: torchvision.transforms.RandomCrop
# params:
# size: 256
# - target: torchvision.transforms.RandomHorizontalFlip
lightning:
trainer:
accelerator: 'gpu'
devices: 2
log_gpu_memory: all
max_epochs: 10
precision: 16
auto_select_gpus: False
strategy:
target: lightning.pytorch.strategies.ColossalAIStrategy
params:
use_chunk: False
      enable_distributed_storage: True
placement_policy: cuda
force_outputs_fp32: False
log_every_n_steps: 2
logger: True
default_root_dir: "/tmp/diff_log/"
profiler: pytorch
logger_config:
wandb:
target: lightning.pytorch.loggers.WandbLogger
params:
name: nowname
save_dir: "/tmp/diff_log/"
offline: opt.debug
id: nowname
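
For reference, this config is consumed through `ldm.util.instantiate_from_config`; a minimal sketch of building the model outside the Lightning trainer (the YAML file name below is an assumed save location, and the `from_pretrained` paths inside the config must exist locally for this to run) would be:

```python
# Minimal sketch: build the LatentDiffusion model described by the config above.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/train_colossalai_teyvat.yaml")  # assumed file name
model = instantiate_from_config(config.model)  # ldm.models.diffusion.ddpm.LatentDiffusion
print(type(model).__name__)
```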

View File

@ -0,0 +1,152 @@
from typing import Dict
import numpy as np
from omegaconf import DictConfig, ListConfig
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset
def make_multi_folder_data(paths, caption_files=None, **kwargs):
"""Make a concat dataset from multiple folders
Don't suport captions yet
If paths is a list, that's ok, if it's a Dict interpret it as:
k=folder v=n_times to repeat that
"""
list_of_paths = []
if isinstance(paths, (Dict, DictConfig)):
assert caption_files is None, \
"Caption files not yet supported for repeats"
for folder_path, repeats in paths.items():
list_of_paths.extend([folder_path]*repeats)
paths = list_of_paths
if caption_files is not None:
datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
else:
datasets = [FolderData(p, **kwargs) for p in paths]
return torch.utils.data.ConcatDataset(datasets)
class FolderData(Dataset):
def __init__(self,
root_dir,
caption_file=None,
image_transforms=[],
ext="jpg",
default_caption="",
postprocess=None,
return_paths=False,
) -> None:
"""Create a dataset from a folder of images.
If you pass in a root directory it will be searched for images
ending in ext (ext can be a list)
"""
self.root_dir = Path(root_dir)
self.default_caption = default_caption
self.return_paths = return_paths
if isinstance(postprocess, DictConfig):
postprocess = instantiate_from_config(postprocess)
self.postprocess = postprocess
if caption_file is not None:
with open(caption_file, "rt") as f:
ext = Path(caption_file).suffix.lower()
if ext == ".json":
captions = json.load(f)
elif ext == ".jsonl":
lines = f.readlines()
lines = [json.loads(x) for x in lines]
captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
else:
raise ValueError(f"Unrecognised format: {ext}")
self.captions = captions
else:
self.captions = None
if not isinstance(ext, (tuple, list, ListConfig)):
ext = [ext]
# Only used if there is no caption file
self.paths = []
for e in ext:
self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
if isinstance(image_transforms, ListConfig):
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
image_transforms.extend([transforms.ToTensor(),
transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
image_transforms = transforms.Compose(image_transforms)
self.tform = image_transforms
def __len__(self):
if self.captions is not None:
return len(self.captions.keys())
else:
return len(self.paths)
def __getitem__(self, index):
data = {}
if self.captions is not None:
chosen = list(self.captions.keys())[index]
caption = self.captions.get(chosen, None)
if caption is None:
caption = self.default_caption
filename = self.root_dir/chosen
else:
filename = self.paths[index]
if self.return_paths:
data["path"] = str(filename)
im = Image.open(filename)
im = self.process_im(im)
data["image"] = im
if self.captions is not None:
data["txt"] = caption
else:
data["txt"] = self.default_caption
if self.postprocess is not None:
data = self.postprocess(data)
return data
def process_im(self, im):
im = im.convert("RGB")
return self.tform(im)
def hf_dataset(
path = "Fazzie/Teyvat",
image_transforms=[],
image_column="image",
text_column="text",
image_key='image',
caption_key='txt',
):
"""Make huggingface dataset with appropriate list of transforms applied
"""
ds = load_dataset(path, name="train")
ds = ds["train"]
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
image_transforms.extend([transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]
)
tform = transforms.Compose(image_transforms)
assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
def pre_process(examples):
processed = {}
processed[image_key] = [tform(im) for im in examples[image_column]]
processed[caption_key] = examples[text_column]
return processed
ds.set_transform(pre_process)
return ds
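
A quick usage sketch for the `hf_dataset` helper above, assuming the `datasets` library can fetch `Fazzie/Teyvat` from the Hugging Face Hub:

```python
# Build the Teyvat dataset with the default transforms and inspect one sample.
from ldm.data.teyvat import hf_dataset

ds = hf_dataset(path="Fazzie/Teyvat")
sample = ds[0]
print(sample["image"].shape)  # expected torch.Size([256, 256, 3]), values scaled to [-1, 1]
print(sample["txt"])          # the paired caption string
```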

View File

@ -99,12 +99,12 @@ class DDPM(pl.LightningModule):
self.use_positional_encodings = use_positional_encodings
self.unet_config = unet_config
self.conditioning_key = conditioning_key
# self.model = DiffusionWrapper(unet_config, conditioning_key)
# count_params(self.model, verbose=True)
self.model = DiffusionWrapper(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
# if self.use_ema:
# self.model_ema = LitEma(self.model)
# print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
@ -125,20 +125,20 @@ class DDPM(pl.LightningModule):
self.linear_start = linear_start
self.linear_end = linear_end
self.cosine_s = cosine_s
# if ckpt_path is not None:
# self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
#
# self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
# linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.loss_type = loss_type
self.learn_logvar = learn_logvar
self.logvar_init = logvar_init
# self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
# if self.learn_logvar:
# self.logvar = nn.Parameter(self.logvar, requires_grad=True)
# self.logvar = nn.Parameter(self.logvar, requires_grad=True)
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
self.use_fp16 = use_fp16
if use_fp16:
@ -312,14 +312,6 @@ class DDPM(pl.LightningModule):
def get_loss(self, pred, target, mean=True):
if pred.isnan().any():
print("Warning: Prediction has nan values")
lr = self.optimizers().param_groups[0]['lr']
# self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
print(f"lr: {lr}")
if pred.isinf().any():
print("Warning: Prediction has inf values")
if self.use_fp16:
target = target.half()
@ -334,15 +326,6 @@ class DDPM(pl.LightningModule):
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
raise NotImplementedError("unknown loss type '{loss_type}'")
if loss.isnan().any():
print("Warning: loss has nan values")
print("loss: ", loss[0][0][0])
raise ValueError("loss has nan values")
if loss.isinf().any():
print("Warning: loss has inf values")
print("loss: ", loss)
raise ValueError("loss has inf values")
return loss
@ -382,11 +365,7 @@ class DDPM(pl.LightningModule):
return self.p_losses(x, t, *args, **kwargs)
def get_input(self, batch, k):
# print("+" * 30)
# print(batch['jpg'].shape)
# print(len(batch['txt']))
# print(k)
# print("=" * 30)
if not isinstance(batch, torch.Tensor):
x = batch[k]
else:
@ -534,8 +513,8 @@ class LatentDiffusion(DDPM):
else:
self.cond_stage_config["params"].update({"use_fp16": False})
rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"]))
# self.instantiate_first_stage(first_stage_config)
# self.instantiate_cond_stage(cond_stage_config)
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
@ -561,16 +540,11 @@ class LatentDiffusion(DDPM):
self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
# self.logvar = nn.Parameter(self.logvar, requires_grad=True)
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
if self.ckpt_path is not None:
self.init_from_ckpt(self.ckpt_path, self.ignore_keys)
self.restarted_from_ckpt = True
# TODO()
# for p in self.model.modules():
# if not p.parameters().data.is_contiguous:
# p.data = p.data.contiguous()
self.instantiate_first_stage(self.first_stage_config)
self.instantiate_cond_stage(self.cond_stage_config)
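
The re-enabled `use_ema` path above keeps an exponential moving average of the model weights via `LitEma`. As a generic illustration of that bookkeeping (not the `LitEma` API itself):

```python
import copy
import torch

# Generic EMA illustration: keep a shadow copy of the model's parameters and
# blend it toward the live weights after every optimizer step.
def ema_update(ema_model, model, decay=0.9999):
    with torch.no_grad():
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            p_ema.mul_(decay).add_(p, alpha=1.0 - decay)

model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model)  # shadow weights, typically used for evaluation/sampling
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()
ema_update(ema_model, model)
```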

View File

examples/images/diffusion/scripts/download_models.sh: Normal file → Executable file (mode change only, 0 lines changed)

View File

@ -0,0 +1,6 @@
python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \
--outdir ./output \
--config /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/configs/2022-11-18T16-38-46-project.yaml \
--ckpt /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/checkpoints/last.ckpt \
--n_samples 4