"""Draw unconditional samples from a trained diffusion model (DDPM or DDIM).

Samples are written as individual PNG files and, additionally, as a single
``.npz`` archive holding the whole uint8 batch.
"""

import argparse
import datetime
import glob
import os
import sys
import time

import numpy as np
import torch
import yaml
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from tqdm import trange


def rescale(x):
    """Return ``(x + 1) / 2`` (affine map taking -1 to 0 and 1 to 1)."""
    return (x + 1.0) / 2.0


def custom_to_pil(x):
    """Convert a single image tensor into an RGB ``PIL.Image``.

    The tensor is clamped to [-1, 1], rescaled to [0, 255] and its first axis
    is moved to the last position before conversion.
    NOTE(review): presumably ``x`` is channel-first (C, H, W) -- confirm.
    """
    x = x.detach().cpu()
    x = torch.clamp(x, -1.0, 1.0)
    x = rescale(x)
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    img = Image.fromarray(x)
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


def custom_to_np(x):
    """Convert a batch in [-1, 1] into a contiguous uint8 tensor, axis 1 last.

    Saves the batch in adm style as in
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py

    Note that, despite the name, the return value is a CPU ``torch.Tensor``
    (dtype ``torch.uint8``), not a NumPy array.
    """
    sample = x.detach().cpu()
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    return sample.contiguous()


def logs2pil(logs, keys=("sample",)):
    """Convert every entry of ``logs`` to a PIL image where possible.

    Args:
        logs: mapping from name to a logged value.
        keys: accepted for interface compatibility; not used by this function.

    Returns:
        dict with the same keys as ``logs``; the value is a ``PIL.Image`` for
        3-d entries and for the first element of 4-d entries, ``None`` for
        everything else.
    """
    imgs = {}
    for k, value in logs.items():
        try:
            ndim = len(value.shape)
            if ndim == 4:
                img = custom_to_pil(value[0, ...])
            elif ndim == 3:
                img = custom_to_pil(value)
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            # Best effort: entries without a ``shape`` (e.g. plain floats) or
            # that fail to convert are mapped to None. ``Exception`` rather
            # than a bare ``except`` so Ctrl-C / SystemExit still propagate.
            img = None
        imgs[k] = img
    return imgs


@torch.no_grad()
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
    """Sample with the model's own sampling loop.

    Calls ``model.progressive_denoising`` when ``make_prog_row`` is set and
    ``model.p_sample_loop`` otherwise; the result of that call is returned
    unchanged.
    """
    if make_prog_row:
        # ``verbose`` is forwarded instead of being hard-coded to True.
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(
        None, shape, return_intermediates=return_intermediates, verbose=verbose
    )


@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
    """Sample with ``DDIMSampler``.

    Args:
        model: model handed to ``DDIMSampler``.
        steps: number of DDIM steps.
        shape: full shape; ``shape[0]`` is used as the batch size and
            ``shape[1:]`` as the per-sample shape.
        eta: DDIM eta, forwarded to ``DDIMSampler.sample``.

    Returns:
        ``(samples, intermediates)`` as returned by ``DDIMSampler.sample``.
    """
    sampler = DDIMSampler(model)
    batch_size, sample_shape = shape[0], shape[1:]
    samples, intermediates = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=sample_shape,
        eta=eta,
        verbose=False,
    )
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(
    model,
    batch_size,
    vanilla=False,
    custom_steps=None,
    eta=1.0,
):
    """Draw one batch of samples and decode it with the first stage.

    Args:
        model: diffusion model; ``model.model.diffusion_model`` supplies
            ``in_channels`` and ``image_size`` for the sampling shape.
        batch_size: number of samples in the batch.
        vanilla: use the model's own sampling loop instead of DDIM.
        custom_steps: number of DDIM steps (only used when not ``vanilla``).
        eta: DDIM eta (only used when not ``vanilla``).

    Returns:
        dict with ``"sample"`` (decoded batch), ``"time"`` (seconds spent in
        sampling, decoding excluded) and ``"throughput"`` (samples/second).
    """
    diffusion_model = model.model.diffusion_model
    shape = [
        batch_size,
        diffusion_model.in_channels,
        diffusion_model.image_size,
        diffusion_model.image_size,
    ]

    with model.ema_scope("Plotting"):
        t0 = time.time()
        if vanilla:
            sample, _ = convsample(model, shape, make_prog_row=True)
        else:
            sample, _ = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
        t1 = time.time()

    elapsed = t1 - t0
    log = {
        "sample": model.decode_first_stage(sample),
        "time": elapsed,
        "throughput": sample.shape[0] / elapsed,
    }
    print(f'Throughput for this batch: {log["throughput"]}')
    return log


def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    """Sample ``n_samples // batch_size`` batches and write them to disk.

    Each sample is stored as a PNG in ``logdir``; all samples together are
    stored as one ``.npz`` file in ``nplog``.

    Args:
        model: diffusion model; only models whose ``cond_stage_model`` is
            ``None`` are supported.
        logdir: directory receiving the PNG files (must exist).
        batch_size: samples per batch.
        vanilla: use the model's own sampling loop instead of DDIM.
        custom_steps: number of DDIM steps.
        eta: DDIM eta; forwarded as is.
        n_samples: total number of samples requested.
        nplog: directory for the ``.npz`` archive; falls back to ``logdir``
            when ``None``.

    Raises:
        NotImplementedError: if the model has a conditioning stage.
        ValueError: if ``batch_size`` is not positive or larger than
            ``n_samples`` (no batch would be drawn).
    """
    if vanilla:
        print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.")
    else:
        print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}")

    tstart = time.time()
    # Continue numbering after the PNGs already present. The count itself is
    # the next free index; subtracting one made the first file of an empty
    # directory "sample_-00001.png" and under-reported the total.
    n_saved = len(glob.glob(os.path.join(logdir, "*.png")))

    if model.cond_stage_model is not None:
        raise NotImplementedError("Currently only sampling for unconditional models supported.")

    if batch_size < 1 or n_samples < batch_size:
        raise ValueError(
            f"n_samples ({n_samples}) must be at least batch_size ({batch_size}) "
            "and batch_size must be positive; no batch would be sampled."
        )
    if nplog is None:
        nplog = logdir

    all_images = []
    print(f"Running unconditional sampling for {n_samples} samples")
    for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
        logs = make_convolutional_sample(
            model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta
        )
        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        all_images.append(custom_to_np(logs["sample"]).numpy())
        if n_saved >= n_samples:
            print(f"Finish after generating {n_saved} samples")
            break

    all_img = np.concatenate(all_images, axis=0)[:n_samples]
    shape_str = "x".join(str(dim) for dim in all_img.shape)
    nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
    np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")


def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Write the batch stored under ``logs[key]`` to disk.

    Args:
        logs: mapping holding the batch under ``key``; nothing is written if
            ``key`` is absent.
        path: directory for PNG output (used when ``np_path`` is ``None``).
        n_saved: running counter used for file names.
        key: entry of ``logs`` to save.
        np_path: if given, the batch is written there as one ``.npz`` file
            instead of individual PNGs.

    Returns:
        The updated running counter.
    """
    if key not in logs:
        return n_saved

    batch = logs[key]
    if np_path is None:
        for x in batch:
            img = custom_to_pil(x)
            img.save(os.path.join(path, f"{key}_{n_saved:06}.png"))
            n_saved += 1
    else:
        npbatch = custom_to_np(batch)
        shape_str = "x".join(str(dim) for dim in npbatch.shape)
        nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
        np.savez(nppath, npbatch)
        n_saved += npbatch.shape[0]
    return n_saved


def get_parser():
    """Build the command-line parser of the sampling script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000)
    parser.add_argument(
        "-e",
        "--eta",
        type=float,
        nargs="?",
        help="eta for ddim sampling (0.0 yields deterministic sampling)",
        default=1.0,
    )
    parser.add_argument(
        "-v",
        "--vanilla_sample",
        default=False,
        action="store_true",
        help="vanilla sampling (default option is DDIM sampling)?",
    )
    parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none")
    parser.add_argument(
        "-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50
    )
    parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10)
    return parser


def load_model_from_config(config, sd):
    """Instantiate the model from ``config``, load ``sd`` and move it to CUDA.

    Args:
        config: model config handed to ``instantiate_from_config``.
        sd: state dict loaded non-strictly, or ``None`` to keep the freshly
            instantiated weights.

    Returns:
        The model on the GPU, in eval mode.
    """
    model = instantiate_from_config(config)
    if sd is not None:
        model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model


def load_model(config, ckpt, gpu, eval_mode):
    """Load the model described by ``config.model``, optionally from ``ckpt``.

    Args:
        config: full config; only ``config.model`` is used.
        ckpt: checkpoint path; a falsy value skips checkpoint loading.
        gpu: accepted for interface compatibility; not used here.
        eval_mode: accepted for interface compatibility; not used here.

    Returns:
        ``(model, global_step)``; ``global_step`` is ``None`` without a
        checkpoint.
    """
    if ckpt:
        print(f"Loading model from {ckpt}")
        # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model, pl_sd["state_dict"])
    return model, global_step


if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()

    # --resume has no default; fail with a usage message instead of a
    # TypeError from os.path.exists(None).
    if opt.resume is None:
        parser.error("-r/--resume is required (logdir or checkpoint inside a logdir)")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))

    if os.path.isfile(opt.resume):
        # A checkpoint file was given: its directory is the logdir.
        logdir = os.path.dirname(opt.resume)
        print(f"Logdir is {logdir}")
        ckpt = opt.resume
    elif os.path.isdir(opt.resume):
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")
    else:
        raise ValueError(f"{opt.resume} is not a directory")

    # Config found next to the checkpoint, overridden by unknown CLI args
    # given in dotlist form.
    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        # Re-root the output below the user-supplied directory, keeping the
        # last component of the original logdir.
        parts = logdir.split(os.sep)
        locallog = parts[-1]
        if locallog == "":
            locallog = parts[-2]
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # write config out
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, "w") as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(
        model,
        imglogdir,
        eta=opt.eta,
        vanilla=opt.vanilla_sample,
        n_samples=opt.n_samples,
        custom_steps=opt.custom_steps,
        batch_size=opt.batch_size,
        nplog=numpylogdir,
    )

    print("done.")