import argparse
import glob
import os
import time
from itertools import islice
from multiprocessing import cpu_count

import numpy as np
import scann
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
from ldm.util import instantiate_from_config, parallel_data_prefetch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.utils import make_grid
from tqdm import tqdm, trange

DATABASES = [
    "openimages",
    "artbench-art_nouveau",
    "artbench-baroque",
    "artbench-expressionism",
    "artbench-impressionism",
    "artbench-post_impressionism",
    "artbench-realism",
    "artbench-romanticism",
    "artbench-renaissance",
    "artbench-surrealism",
    "artbench-ukiyo_e",
]


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


class Searcher(object):
    """Loads precomputed CLIP patch embeddings and a ScaNN index for nearest-neighbor retrieval."""

    def __init__(self, database, retriever_version="ViT-L/14"):
        assert database in DATABASES
        # self.database = self.load_database(database)
        self.database_name = database
        self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
        self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
        self.retriever = self.load_retriever(version=retriever_version)
        self.database = {"embedding": [], "img_id": [], "patch_coords": []}
        self.load_database()
        self.load_searcher()

    def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
        print("Start training searcher")
        searcher = scann.scann_ops_pybind.builder(
            self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis],
            k,
            metric,
        )
        self.searcher = searcher.score_brute_force().build()
        print("Finished training searcher")

        if searcher_savedir is not None:
            print(f'Save trained searcher under "{searcher_savedir}"')
            os.makedirs(searcher_savedir, exist_ok=True)
            self.searcher.serialize(searcher_savedir)

    def load_single_file(self, saved_embeddings):
        compressed = np.load(saved_embeddings)
        self.database = {key: compressed[key] for key in compressed.files}
        print("Finished loading of clip embeddings.")

    def load_multi_files(self, data_archive):
        out_data = {key: [] for key in self.database}
        for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            for key in d.files:
                out_data[key].append(d[key])

        return out_data

    def load_database(self):
        print(f'Load saved patch embeddings from "{self.database_path}"')
        file_content = glob.glob(os.path.join(self.database_path, "*.npz"))

        if len(file_content) == 1:
            self.load_single_file(file_content[0])
        elif len(file_content) > 1:
            data = [np.load(f) for f in file_content]
            prefetched_data = parallel_data_prefetch(
                self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
            )

            self.database = {
                key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
            }
        else:
            raise ValueError(f'No npz-files in specified path "{self.database_path}". Does this directory exist?')
        print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')

    def load_retriever(self, version="ViT-L/14"):
        model = FrozenClipImageEmbedder(model=version)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        return model

    def load_searcher(self):
        print(f"Load searcher for database {self.database_name} from {self.searcher_savedir}")
        self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
        print("Finished loading searcher.")

    def search(self, x, k):
        if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
            self.train_searcher(k)  # quickly fit searcher on the fly for small databases
        assert self.searcher is not None, "Cannot search with uninitialized searcher"
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if len(x.shape) == 3:
            x = x[:, 0]
        query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]

        start = time.time()
        nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
        end = time.time()

        out_embeddings = self.database["embedding"][nns]
        out_img_ids = self.database["img_id"][nns]
        out_pc = self.database["patch_coords"][nns]

        out = {
            "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
            "img_ids": out_img_ids,
            "patch_coords": out_pc,
            "queries": x,
            "exec_time": end - start,
            "nns": nns,
            "q_embeddings": query_embeddings,
        }

        return out

    def __call__(self, x, n):
        return self.search(x, n)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
    # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples",
    )

    parser.add_argument(
        "--skip_grid",
        action="store_true",
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )

    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )

    parser.add_argument(
        "--n_repeat",
        type=int,
        default=1,
        help="number of repeats in CLIP latent space",
    )

    parser.add_argument(
        "--plms",
        action="store_true",
        help="use plms sampling",
    )

    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )

    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )

    parser.add_argument(
        "--H",
        type=int,
        default=768,
        help="image height, in pixel space",
    )

    parser.add_argument(
        "--W",
        type=int,
        default=768,
        help="image width, in pixel space",
    )

    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
A.k.a batch size", ) parser.add_argument( "--n_rows", type=int, default=0, help="rows in the grid (default: n_samples)", ) parser.add_argument( "--scale", type=float, default=5.0, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", ) parser.add_argument( "--from-file", type=str, help="if specified, load prompts from this file", ) parser.add_argument( "--config", type=str, default="configs/retrieval-augmented-diffusion/768x768.yaml", help="path to config which constructs model", ) parser.add_argument( "--ckpt", type=str, default="models/rdm/rdm768x768/model.ckpt", help="path to checkpoint of model", ) parser.add_argument( "--clip_type", type=str, default="ViT-L/14", help="which CLIP model to use for retrieval and NN encoding", ) parser.add_argument( "--database", type=str, default="artbench-surrealism", choices=DATABASES, help="The database used for the search, only applied when --use_neighbors=True", ) parser.add_argument( "--use_neighbors", default=False, action="store_true", help="Include neighbors in addition to text prompt for conditioning", ) parser.add_argument( "--knn", default=10, type=int, help="The number of included neighbors, only applied when --use_neighbors=True", ) opt = parser.parse_args() config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) if opt.plms: sampler = PLMSSampler(model) else: sampler = DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) outpath = opt.outdir batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size if not opt.from_file: prompt = opt.prompt assert prompt is not None data = [batch_size * [prompt]] else: print(f"reading prompts from {opt.from_file}") with open(opt.from_file, "r") as f: data = f.read().splitlines() data = list(chunk(data, batch_size)) sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) base_count = len(os.listdir(sample_path)) grid_count = len(os.listdir(outpath)) - 1 print(f"sampling scale for cfg is {opt.scale:.2f}") searcher = None if opt.use_neighbors: searcher = Searcher(opt.database) with torch.no_grad(): with model.ema_scope(): for n in trange(opt.n_iter, desc="Sampling"): all_samples = list() for prompts in tqdm(data, desc="data"): print("sampling prompts:", prompts) if isinstance(prompts, tuple): prompts = list(prompts) c = clip_text_encoder.encode(prompts) uc = None if searcher is not None: nn_dict = searcher(c, opt.knn) c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1) if opt.scale != 1.0: uc = torch.zeros_like(c) if isinstance(prompts, tuple): prompts = list(prompts) shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model samples_ddim, _ = sampler.sample( S=opt.ddim_steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False, unconditional_guidance_scale=opt.scale, unconditional_conditioning=uc, eta=opt.ddim_eta, ) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples_ddim: x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(x_sample.astype(np.uint8)).save( os.path.join(sample_path, f"{base_count:05}.png") ) base_count += 1 all_samples.append(x_samples_ddim) if not opt.skip_grid: # additionally, 
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, "n b c h w -> (n b) c h w")
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
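
# Example invocation (a sketch, assuming this file lives at scripts/knn2img.py and that the
# default checkpoint, config, retrieval database, and pre-built ScaNN searcher directories
# referenced above exist on disk):
#
#   python scripts/knn2img.py \
#       --prompt "a happy bear reading a newspaper, oil on canvas" \
#       --database artbench-surrealism \
#       --use_neighbors --knn 10 \
#       --ddim_steps 50 --scale 5.0 \
#       --n_samples 3 --outdir outputs/rdm-samples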