You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/examples/images/diffusion/scripts/knn2img.py

398 lines
13 KiB

import argparse
import glob
import os
import time
from itertools import islice
from multiprocessing import cpu_count
import numpy as np
import scann
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
from ldm.util import instantiate_from_config, parallel_data_prefetch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.utils import make_grid
from tqdm import tqdm, trange
# Names of the retrieval databases this script accepts. The list is used to
# validate the ``database`` argument of ``Searcher`` and as the allowed
# ``choices`` of the ``--database`` command-line option.
DATABASES = [
    "openimages",
    "artbench-art_nouveau",
    "artbench-baroque",
    "artbench-expressionism",
    "artbench-impressionism",
    "artbench-post_impressionism",
    "artbench-realism",
    "artbench-romanticism",
    "artbench-renaissance",
    "artbench-surrealism",
    "artbench-ukiyo_e",
]
def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    Args:
        it: Any iterable; it is consumed lazily as the result is iterated.
        size: Maximum number of items per tuple (passed to ``islice``).

    Returns:
        An iterator of tuples. The last tuple may be shorter than ``size``;
        iteration stops as soon as an empty tuple would be produced.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if batch == ():
                return
            yield batch

    return _batches()
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Checkpoint path passed to ``torch.load``. The loaded object must
            contain a ``"state_dict"`` entry; a ``"global_step"`` entry is
            printed when present.
        verbose: When true, print the missing and unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The model in eval mode. It is moved to the GPU when CUDA is available
        and otherwise left on the CPU.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are reported below instead of raising.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    # Only move to the GPU when one exists, so CPU-only hosts do not crash
    # here (the retriever loader and the main script guard the same way).
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
class Searcher(object):
    """Nearest-neighbour search over a saved embedding database using ScaNN.

    On construction the instance loads a CLIP image embedder (``retriever``),
    the ``.npz`` retrieval database found under
    ``data/rdm/retrieval_databases/<database>``, and, when it exists, a
    serialized ScaNN searcher from ``data/rdm/searchers/<database>``.

    When no serialized searcher is found, ``searcher`` stays ``None`` and
    ``search`` fits a brute-force searcher on the fly for databases with
    fewer than 2e4 entries.
    """

    def __init__(self, database, retriever_version="ViT-L/14"):
        """Load retriever, database and (if available) the saved searcher.

        Args:
            database: Name of the retrieval database; must be in ``DATABASES``.
            retriever_version: Model identifier passed to
                ``FrozenClipImageEmbedder``.
        """
        assert database in DATABASES
        self.database_name = database
        self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
        self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
        self.retriever = self.load_retriever(version=retriever_version)
        self.database = {"embedding": [], "img_id": [], "patch_coords": []}
        # Start uninitialised so ``search`` can detect that no searcher has
        # been loaded and fall back to training one.
        self.searcher = None
        self.load_database()
        self.load_searcher()

    def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
        """Fit a brute-force ScaNN searcher on the row-normalised embeddings.

        Args:
            k: Number of neighbours passed to the ScaNN builder.
            metric: Distance measure name passed to the ScaNN builder.
            searcher_savedir: When given, the trained searcher is serialized
                into this directory (created if necessary).
        """
        print("Start training searcher")
        embeddings = self.database["embedding"]
        normalised = embeddings / np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
        searcher = scann.scann_ops_pybind.builder(normalised, k, metric)
        self.searcher = searcher.score_brute_force().build()
        print("Finish training searcher")

        if searcher_savedir is not None:
            print(f'Save trained searcher under "{searcher_savedir}"')
            os.makedirs(searcher_savedir, exist_ok=True)
            self.searcher.serialize(searcher_savedir)

    def load_single_file(self, saved_embeddings):
        """Replace ``self.database`` with the arrays of one ``.npz`` archive.

        Args:
            saved_embeddings: Path or file-like object accepted by ``np.load``.
        """
        compressed = np.load(saved_embeddings)
        self.database = {key: compressed[key] for key in compressed.files}
        print("Finished loading of clip embeddings.")

    def load_multi_files(self, data_archive):
        """Collect the arrays of several loaded archives into per-key lists.

        Args:
            data_archive: Sequence of loaded archives exposing ``.files`` and
                key indexing (as returned by ``np.load`` for ``.npz`` files).

        Returns:
            Dict with the keys of ``self.database``; each value is the list of
            arrays stored under that key across all archives.
        """
        out_data = {key: [] for key in self.database}
        for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            for key in d.files:
                out_data[key].append(d[key])
        return out_data

    def load_database(self):
        """Load every ``*.npz`` file under ``self.database_path`` into ``self.database``.

        Raises:
            ValueError: If the directory contains no ``.npz`` file.
        """
        print(f'Load saved patch embedding from "{self.database_path}"')
        file_content = glob.glob(os.path.join(self.database_path, "*.npz"))

        if len(file_content) == 1:
            self.load_single_file(file_content[0])
        elif len(file_content) > 1:
            data = [np.load(f) for f in file_content]
            prefetched_data = parallel_data_prefetch(
                self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
            )
            # NOTE(review): concatenating along axis 1 and taking [0]
            # presumably undoes the nesting produced by the prefetch helper —
            # confirm against parallel_data_prefetch.
            self.database = {
                key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
            }
        else:
            raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')

        print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')

    def load_retriever(
        self,
        version="ViT-L/14",
    ):
        """Create the CLIP image embedder in eval mode, on the GPU when available.

        Args:
            version: Model identifier passed to ``FrozenClipImageEmbedder``.

        Returns:
            The embedder instance.
        """
        model = FrozenClipImageEmbedder(model=version)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        return model

    def load_searcher(self):
        """Load the serialized ScaNN searcher from ``self.searcher_savedir``.

        When the directory does not exist, ``self.searcher`` is left as
        ``None`` so that ``search`` can train a searcher on the fly for small
        databases instead of failing during construction.
        """
        print(f"load searcher for database {self.database_name} from {self.searcher_savedir}")
        if not os.path.isdir(self.searcher_savedir):
            self.searcher = None
            print(f'No saved searcher found under "{self.searcher_savedir}".')
            return
        self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
        print("Finished loading searcher.")

    def search(self, x, k):
        """Retrieve the ``k`` nearest database entries for each query embedding.

        Args:
            x: Query embeddings as a numpy array or a torch tensor (tensors
                are detached and copied to the CPU). A 3-D input is reduced to
                2-D by taking index 0 along axis 1.
            k: Number of neighbours to return per query.

        Returns:
            Dict with the normalised neighbour embeddings
            (``"nn_embeddings"``), their ``"img_ids"`` and ``"patch_coords"``,
            the un-normalised ``"queries"``, the normalised
            ``"q_embeddings"``, the neighbour indices ``"nns"`` and the search
            wall time in seconds (``"exec_time"``).

        Raises:
            AssertionError: If no searcher is loaded and the database is too
                large (>= 2e4 entries) to fit one on the fly.
        """
        if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
            self.train_searcher(k)  # quickly fit searcher on the fly for small databases
        assert self.searcher is not None, "Cannot search with uninitialized searcher"

        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if len(x.shape) == 3:
            x = x[:, 0]
        query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]

        start = time.time()
        nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
        end = time.time()

        out_embeddings = self.database["embedding"][nns]
        out_img_ids = self.database["img_id"][nns]
        out_pc = self.database["patch_coords"][nns]

        out = {
            "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
            "img_ids": out_img_ids,
            "patch_coords": out_pc,
            "queries": x,
            "exec_time": end - start,
            "nns": nns,
            "q_embeddings": query_embeddings,
        }
        return out

    def __call__(self, x, n):
        """Shorthand for ``search(x, n)``."""
        return self.search(x, n)
# Script entry point: parse the command line, load the diffusion model and the
# CLIP text encoder, optionally set up nearest-neighbour retrieval, then sample
# images for every prompt batch and write them (plus an optional grid) to disk.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
    # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )
    parser.add_argument(
        "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action="store_true",
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--n_repeat",
        type=int,
        default=1,
        help="number of repeats in CLIP latent space",
    )
    parser.add_argument(
        "--plms",
        action="store_true",
        help="use plms sampling",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=768,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=768,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/retrieval-augmented-diffusion/768x768.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/rdm/rdm768x768/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--clip_type",
        type=str,
        default="ViT-L/14",
        help="which CLIP model to use for retrieval and NN encoding",
    )
    parser.add_argument(
        "--database",
        type=str,
        default="artbench-surrealism",
        choices=DATABASES,
        help="The database used for the search, only applied when --use_neighbors=True",
    )
    parser.add_argument(
        "--use_neighbors",
        default=False,
        action="store_true",
        help="Include neighbors in addition to text prompt for conditioning",
    )
    parser.add_argument(
        "--knn",
        default=10,
        type=int,
        help="The number of included neighbors, only applied when --use_neighbors=True",
    )
    opt = parser.parse_args()

    # --- model, text encoder and sampler -------------------------------------
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    # --- output locations and prompt batches ---------------------------------
    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            # The last batch is shorter when the number of prompts is not a
            # multiple of the batch size.
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    # Continue numbering after whatever is already present in the directories.
    base_count = len(os.listdir(sample_path))
    # Presumably "- 1" discounts the "samples" sub-directory created above.
    grid_count = len(os.listdir(outpath)) - 1

    print(f"sampling scale for cfg is {opt.scale:.2f}")

    searcher = None
    if opt.use_neighbors:
        searcher = Searcher(opt.database)

    # --- sampling ------------------------------------------------------------
    with torch.no_grad():
        with model.ema_scope():
            for n in trange(opt.n_iter, desc="Sampling"):
                all_samples = list()
                for prompts in tqdm(data, desc="data"):
                    print("sampling prompts:", prompts)
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = clip_text_encoder.encode(prompts)
                    uc = None
                    if searcher is not None:
                        nn_dict = searcher(c, opt.knn)
                        # Follow the device of the text embedding instead of
                        # assuming CUDA, so the CPU fallback above keeps working.
                        neighbours = torch.from_numpy(nn_dict["nn_embeddings"]).to(c.device)
                        c = torch.cat([c, neighbours], dim=1)
                    if opt.scale != 1.0:
                        uc = torch.zeros_like(c)
                    shape = [16, opt.H // 16, opt.W // 16]  # note: currently hardcoded for f16 model
                    samples_ddim, _ = sampler.sample(
                        S=opt.ddim_steps,
                        conditioning=c,
                        batch_size=c.shape[0],
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=opt.scale,
                        unconditional_conditioning=uc,
                        eta=opt.ddim_eta,
                    )

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    # Map decoder output from [-1, 1] to [0, 1].
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    for x_sample in x_samples_ddim:
                        x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png")
                        )
                        base_count += 1
                    all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid
                    # Concatenate along the batch axis; unlike stack + flatten
                    # this also works when the last prompt batch is smaller.
                    grid = torch.cat(all_samples, 0)
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")