mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
399 lines
13 KiB
399 lines
13 KiB
import argparse, os, sys, glob
|
|
import clip
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from omegaconf import OmegaConf
|
|
from PIL import Image
|
|
from tqdm import tqdm, trange
|
|
from itertools import islice
|
|
from einops import rearrange, repeat
|
|
from torchvision.utils import make_grid
|
|
import scann
|
|
import time
|
|
from multiprocessing import cpu_count
|
|
|
|
from ldm.util import instantiate_from_config, parallel_data_prefetch
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
from ldm.models.diffusion.plms import PLMSSampler
|
|
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
|
|
|
|
# Names of the retrieval databases a Searcher can be built for. Each name is
# also the sub-directory used under data/rdm/retrieval_databases/ and
# data/rdm/searchers/. Order matters: it is the order shown by --database.
DATABASES = ["openimages"] + [
    f"artbench-{style}"
    for style in (
        "art_nouveau",
        "baroque",
        "expressionism",
        "impressionism",
        "post_impressionism",
        "realism",
        "romanticism",
        "renaissance",
        "surrealism",
        "ukiyo_e",
    )
]
|
|
|
|
|
|
def chunk(it, size):
    """Split an iterable into consecutive tuples of ``size`` items.

    The last tuple may be shorter than ``size``; iteration ends once the
    underlying iterable is exhausted (an empty tuple is never yielded).
    """
    source = iter(it)  # resolved eagerly, at call time

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()
|
|
|
|
|
|
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: config object whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint readable by ``torch.load`` that contains a
            ``"state_dict"`` entry and, optionally, a ``"global_step"`` entry.
        verbose: when True, print missing / unexpected state-dict keys.

    Returns:
        The model in eval mode, moved to CUDA only when a GPU is available.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are reported (if verbose) instead of raising.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing and verbose:
        print("missing keys:")
        print(missing)
    if unexpected and verbose:
        print("unexpected keys:")
        print(unexpected)

    # The original called model.cuda() unconditionally, which crashes on
    # CPU-only machines even though callers select a CPU device as fallback.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
|
|
|
|
|
|
class Searcher(object):
    """Nearest-neighbour lookup over a precomputed embedding database.

    Embeddings are read from ``data/rdm/retrieval_databases/<database>/*.npz``.
    A serialized ScaNN searcher is loaded from ``data/rdm/searchers/<database>``
    when that directory exists. Otherwise a brute-force searcher is fitted on
    the fly at the first query, provided the database holds fewer than 2e4
    entries.
    """

    def __init__(self, database, retriever_version='ViT-L/14'):
        assert database in DATABASES
        self.database_name = database
        self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
        self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
        self.retriever = self.load_retriever(version=retriever_version)
        # Placeholder schema: load_multi_files uses these keys; load_database
        # replaces the dict with the loaded arrays.
        self.database = {'embedding': [],
                         'img_id': [],
                         'patch_coords': []}
        self.load_database()
        self.load_searcher()

    def train_searcher(self, k,
                       metric='dot_product',
                       searcher_savedir=None):
        """Fit a brute-force ScaNN searcher on the row-normalised database embeddings.

        Args:
            k: number of neighbours passed to the ScaNN builder.
            metric: ScaNN distance measure name.
            searcher_savedir: if given, the fitted searcher is serialized there.
        """
        print('Start training searcher')
        embeddings = self.database['embedding']
        normalized = embeddings / np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
        searcher = scann.scann_ops_pybind.builder(normalized, k, metric)
        self.searcher = searcher.score_brute_force().build()
        print('Finish training searcher')

        if searcher_savedir is not None:
            print(f'Save trained searcher under "{searcher_savedir}"')
            os.makedirs(searcher_savedir, exist_ok=True)
            self.searcher.serialize(searcher_savedir)

    def load_single_file(self, saved_embeddings):
        """Replace ``self.database`` with every array stored in one ``.npz`` file."""
        # Use the NpzFile as a context manager so its file handle is closed
        # once all arrays have been materialized.
        with np.load(saved_embeddings) as compressed:
            self.database = {key: compressed[key] for key in compressed.files}
        print('Finished loading of clip embeddings.')

    def load_multi_files(self, data_archive):
        """Group the arrays of several loaded ``.npz`` archives by key.

        Returns a dict mapping each key of ``self.database`` to the list of
        arrays found under that key, one entry per archive in ``data_archive``.
        """
        out_data = {key: [] for key in self.database}
        for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
            for key in d.files:
                out_data[key].append(d[key])

        return out_data

    def load_database(self):
        """Load the embedding database from the ``.npz`` file(s) in ``self.database_path``.

        Raises:
            ValueError: if the directory contains no ``.npz`` files.
        """
        print(f'Load saved patch embedding from "{self.database_path}"')
        file_content = glob.glob(os.path.join(self.database_path, '*.npz'))

        if len(file_content) == 1:
            self.load_single_file(file_content[0])
        elif len(file_content) > 1:
            data = [np.load(f) for f in file_content]
            prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
                                                     n_proc=min(len(data), cpu_count()), target_data_type='dict')

            # NOTE(review): the axis=1 / [0] indexing depends on the output layout of
            # parallel_data_prefetch with target_data_type='dict' — confirm there.
            self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
                             self.database}
        else:
            raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')

        print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')

    def load_retriever(self, version='ViT-L/14', ):
        """Build a frozen CLIP image embedder in eval mode, on CUDA when available."""
        model = FrozenClipImageEmbedder(model=version)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        return model

    def load_searcher(self):
        """Load the serialized ScaNN searcher, or set ``self.searcher`` to None.

        Previously a missing searcher directory made ScaNN raise here, so the
        on-the-fly fallback in ``search`` could never run. Now a missing
        directory leaves ``self.searcher`` as None and defers to that fallback.
        """
        if not os.path.isdir(self.searcher_savedir):
            print(f'No saved searcher found under "{self.searcher_savedir}"; '
                  f'one will be fitted on the fly for small databases.')
            self.searcher = None
            return
        print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
        self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
        print('Finished loading searcher.')

    def search(self, x, k):
        """Return the ``k`` nearest database entries for each query embedding.

        Args:
            x: query embeddings as a torch tensor or numpy array. If it is 3-D,
                only index 0 along axis 1 is used.
            k: number of neighbours per query.

        Returns:
            dict with keys 'nn_embeddings' (row-normalised neighbour embeddings),
            'img_ids', 'patch_coords', 'queries' (the input as numpy),
            'exec_time' (wall-clock seconds of the search call), 'nns'
            (neighbour indices) and 'q_embeddings' (normalised queries).
        """
        if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
            self.train_searcher(k)  # quickly fit searcher on the fly for small databases
        assert self.searcher is not None, 'Cannot search with uninitialized searcher'
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if len(x.shape) == 3:
            x = x[:, 0]
        query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]

        start = time.time()
        nns, _ = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
        end = time.time()

        out_embeddings = self.database['embedding'][nns]
        out_img_ids = self.database['img_id'][nns]
        out_pc = self.database['patch_coords'][nns]

        out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
               'img_ids': out_img_ids,
               'patch_coords': out_pc,
               'queries': x,
               'exec_time': end - start,
               'nns': nns,
               'q_embeddings': query_embeddings}

        return out

    def __call__(self, x, n):
        """Alias for ``search(x, n)``."""
        return self.search(x, n)
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
    # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    # NOTE(review): --n_repeat is parsed but never used below; it presumably
    # should be forwarded to FrozenCLIPTextEmbedder — confirm that signature.
    parser.add_argument(
        "--n_repeat",
        type=int,
        default=1,
        help="number of repeats in CLIP latent space",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=768,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=768,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/retrieval-augmented-diffusion/768x768.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/rdm/rdm768x768/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--clip_type",
        type=str,
        default="ViT-L/14",
        help="which CLIP model to use for retrieval and NN encoding",
    )
    parser.add_argument(
        "--database",
        type=str,
        default='artbench-surrealism',
        choices=DATABASES,
        help="The database used for the search, only applied when --use_neighbors=True",
    )
    parser.add_argument(
        "--use_neighbors",
        default=False,
        action='store_true',
        help="Include neighbors in addition to text prompt for conditioning",
    )
    parser.add_argument(
        "--knn",
        default=10,
        type=int,
        help="The number of included neighbors, only applied when --use_neighbors=True",
    )

    opt = parser.parse_args()

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
        # The last chunk may hold fewer than batch_size prompts.
        data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    # Minus one so the "samples" sub-directory is not counted as a grid.
    grid_count = len(os.listdir(outpath)) - 1

    print(f"sampling scale for cfg is {opt.scale:.2f}")

    searcher = None
    if opt.use_neighbors:
        searcher = Searcher(opt.database)

    with torch.no_grad():
        with model.ema_scope():
            for n in trange(opt.n_iter, desc="Sampling"):
                all_samples = list()
                for prompts in tqdm(data, desc="data"):
                    print("sampling prompts:", prompts)
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = clip_text_encoder.encode(prompts)
                    uc = None
                    if searcher is not None:
                        nn_dict = searcher(c, opt.knn)
                        # Match c's device (instead of assuming CUDA) and dtype so
                        # the concatenation works on CPU and for any stored dtype.
                        nn_embeddings = torch.from_numpy(nn_dict['nn_embeddings']).to(
                            device=c.device, dtype=c.dtype)
                        c = torch.cat([c, nn_embeddings], dim=1)
                    if opt.scale != 1.0:
                        uc = torch.zeros_like(c)
                    shape = [16, opt.H // 16, opt.W // 16]  # note: currently hardcoded for f16 model
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                     conditioning=c,
                                                     batch_size=c.shape[0],
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=opt.scale,
                                                     unconditional_conditioning=uc,
                                                     eta=opt.ddim_eta,
                                                     )

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    for x_sample in x_samples_ddim:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png"))
                        base_count += 1
                    all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid. torch.cat along the batch dim
                    # equals stack + '(n b)' flattening for equal-sized batches,
                    # and also handles a smaller final batch from --from-file.
                    grid = torch.cat(all_samples, dim=0)
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|