# Mirror of https://github.com/hpcaitech/ColossalAI
# (302 lines, 9.1 KiB, Python)
import argparse
|
|
import datetime
|
|
import glob
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
import yaml
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
from ldm.util import instantiate_from_config
|
|
from omegaconf import OmegaConf
|
|
from PIL import Image
|
|
from tqdm import trange
|
|
|
|
def rescale(x):
    """Map values from the ``[-1, 1]`` range to ``[0, 1]``.

    Works element-wise on anything supporting ``+`` and ``/`` (Python
    floats, NumPy arrays, torch tensors).
    """
    return (x + 1.0) / 2.0
|
|
|
|
|
|
def custom_to_pil(x):
    """Convert a single channels-first image tensor in ``[-1, 1]`` to an RGB PIL image.

    Args:
        x: One image laid out as ``(C, H, W)`` (the ``permute(1, 2, 0)``
            below relies on this). Values outside ``[-1, 1]`` are clamped.
            A single-channel ``(1, H, W)`` image is accepted and expanded
            to RGB.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    x = x.detach().cpu()
    x = torch.clamp(x, -1.0, 1.0)
    x = rescale(x)
    x = x.permute(1, 2, 0).numpy()
    # PIL cannot build an image from an (H, W, 1) array; drop the singleton
    # channel so it becomes a grayscale image, converted to RGB below.
    if x.ndim == 3 and x.shape[-1] == 1:
        x = x[..., 0]
    # Scale to 0..255; astype truncates (no rounding), as before.
    x = (255 * x).astype(np.uint8)
    img = Image.fromarray(x)
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img
|
|
|
|
|
|
def custom_to_np(x):
    """Quantise a batch in ``[-1, 1]`` to uint8 and move channels last.

    Mirrors the batch saving in
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py

    Args:
        x: Batch tensor laid out as ``(N, C, H, W)``.

    Returns:
        A contiguous CPU ``torch.uint8`` tensor of shape ``(N, H, W, C)``.
        Despite the name this is a torch tensor, not a NumPy array; NumPy
        functions such as ``np.concatenate``/``np.savez`` convert it
        implicitly.
    """
    on_cpu = x.detach().cpu()
    # [-1, 1] -> [0, 255]; the float->uint8 cast truncates fractional parts.
    as_bytes = ((on_cpu + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    channels_last = as_bytes.permute(0, 2, 3, 1)
    return channels_last.contiguous()
|
|
|
|
|
|
def logs2pil(logs, keys=("sample",)):
    """Convert each tensor entry of ``logs`` into a preview PIL image.

    Args:
        logs: Mapping of name -> value. Entries with a 4-D ``shape`` are
            treated as a batch and only element ``0`` is converted; 3-D
            entries are converted directly. Entries without a ``shape``
            (e.g. float timings) map to ``None``.
        keys: Unused; kept for backward compatibility. (Previously a
            mutable list default.)

    Returns:
        dict mapping every key of ``logs`` to an image, or ``None`` when the
        entry could not be converted.
    """
    imgs = dict()
    for k in logs:
        value = logs[k]
        shape = getattr(value, "shape", None)
        if shape is None:
            # Non-tensor entries are expected; skip them quietly.
            imgs[k] = None
            continue
        try:
            ndim = len(shape)
            if ndim == 4:
                img = custom_to_pil(value[0, ...])
            elif ndim == 3:
                img = custom_to_pil(value)
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception as err:
            # Best-effort preview: report the failure instead of hiding it,
            # but never let KeyboardInterrupt/SystemExit be swallowed.
            print(f"Could not convert key {k} to an image: {err}")
            img = None
        imgs[k] = img
    return imgs
|
|
|
|
|
|
@torch.no_grad()
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
    """Sample with the model's own (vanilla) sampling loop, without conditioning.

    Args:
        model: Object exposing ``p_sample_loop`` and ``progressive_denoising``.
        shape: Full sample shape, batch dimension first.
        return_intermediates: Forwarded to ``p_sample_loop``; not used when
            ``make_prog_row`` is set.
        verbose: Forwarded to whichever sampler is called.
        make_prog_row: Use ``model.progressive_denoising`` instead of
            ``model.p_sample_loop``.

    Returns:
        Whatever the selected model method returns.
    """
    if make_prog_row:
        # Previously hard-coded verbose=True, silently ignoring the argument.
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose)
|
|
|
|
|
|
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
    """Draw samples from ``model`` using a ``DDIMSampler``.

    Args:
        model: Model wrapped by ``DDIMSampler``.
        steps: Number of DDIM steps, passed as the first argument of
            ``DDIMSampler.sample``.
        shape: Full shape with the batch size first; everything after the
            first entry is passed as the per-sample shape.
        eta: DDIM eta, forwarded unchanged.

    Returns:
        The ``(samples, intermediates)`` pair produced by ``DDIMSampler.sample``.
    """
    sampler = DDIMSampler(model)
    n_batch, per_sample_shape = shape[0], shape[1:]
    result = sampler.sample(steps, batch_size=n_batch, shape=per_sample_shape, eta=eta, verbose=False)
    samples, intermediates = result
    return samples, intermediates
|
|
|
|
|
|
@torch.no_grad()
def make_convolutional_sample(
    model,
    batch_size,
    vanilla=False,
    custom_steps=None,
    eta=1.0,
):
    """Sample one unconditional batch, decode it, and time the sampling.

    Args:
        model: Diffusion model; its ``model.diffusion_model`` supplies
            ``in_channels`` and ``image_size`` for the (square) sample shape.
        batch_size: Number of samples to draw.
        vanilla: Use ``convsample`` (progressive denoising) instead of DDIM.
        custom_steps: DDIM step count (only used when ``vanilla`` is false).
        eta: DDIM eta (only used when ``vanilla`` is false).

    Returns:
        dict with ``"sample"`` (output of ``model.decode_first_stage``),
        ``"time"`` (seconds spent sampling; decoding is excluded) and
        ``"throughput"`` (samples per second over that interval).
    """
    log = dict()

    diffusion_model = model.model.diffusion_model
    shape = [
        batch_size,
        diffusion_model.in_channels,
        diffusion_model.image_size,
        diffusion_model.image_size,
    ]

    with model.ema_scope("Plotting"):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        if vanilla:
            sample, _progrow = convsample(model, shape, make_prog_row=True)
        else:
            sample, _intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
        # CUDA kernels run asynchronously; wait for them so the interval
        # covers the actual sampling work rather than only its dispatch.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()

    x_sample = model.decode_first_stage(sample)

    elapsed = t1 - t0
    log["sample"] = x_sample
    log["time"] = elapsed
    log["throughput"] = sample.shape[0] / elapsed
    print(f'Throughput for this batch: {log["throughput"]}')
    return log
|
|
|
|
|
|
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    """Generate ``n_samples`` unconditional samples as PNGs plus one ``.npz``.

    Args:
        model: Diffusion model; its ``cond_stage_model`` must be ``None``.
        logdir: Directory receiving ``sample_XXXXXX.png`` files. PNGs already
            present are counted, so numbering continues after them.
        batch_size: Samples drawn per batch.
        vanilla: Use the model's own sampler instead of DDIM.
        custom_steps: DDIM step count.
        eta: DDIM eta. NOTE(review): the ``None`` default is forwarded to the
            DDIM sampler as-is; pass a value when ``vanilla`` is false.
        n_samples: Target number of samples.
        nplog: Directory for the ``<shape>-samples.npz`` archive; when
            ``None`` the archive is not written.

    Raises:
        NotImplementedError: if the model is conditional.
    """
    if vanilla:
        print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.")
    else:
        print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}")

    tstart = time.time()
    # Continue numbering after existing PNGs. (The old "- 1" made a fresh
    # directory start at sample_-00001.png.)
    n_saved = len(glob.glob(os.path.join(logdir, "*.png")))

    if model.cond_stage_model is not None:
        raise NotImplementedError("Currently only sampling for unconditional models supported.")

    all_images = []
    print(f"Running unconditional sampling for {n_samples} samples")
    # Ceiling division: a final partial batch is still drawn (floor division
    # drew zero batches when n_samples < batch_size). Extra samples are
    # trimmed from the .npz below, and the n_saved check stops early.
    n_batches = (n_samples + batch_size - 1) // batch_size
    for _ in trange(n_batches, desc="Sampling Batches (unconditional)"):
        logs = make_convolutional_sample(
            model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta
        )
        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        all_images.append(custom_to_np(logs["sample"]))
        if n_saved >= n_samples:
            print(f"Finish after generating {n_saved} samples")
            break

    # np.concatenate rejects an empty list, and os.path.join rejects None.
    if all_images and nplog is not None:
        all_img = np.concatenate(all_images, axis=0)[:n_samples]
        shape_str = "x".join(str(dim) for dim in all_img.shape)
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
|
|
|
|
|
|
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Write the ``key`` entry of ``logs`` to disk and return the updated count.

    Args:
        logs: Mapping that may contain ``key``; other entries are ignored.
        path: Directory for per-image PNGs named ``{key}_{index:06}.png``.
        n_saved: Running count of images saved so far; used for file naming.
        key: Entry of ``logs`` to save.
        np_path: When given, save the whole batch as one
            ``{n_saved}-{shape}-samples.npz`` file there instead of PNGs.

    Returns:
        ``n_saved`` advanced by the number of images written (unchanged if
        ``key`` is absent).
    """
    if key not in logs:
        return n_saved

    batch = logs[key]
    if np_path is not None:
        npbatch = custom_to_np(batch)
        shape_str = "x".join(map(str, npbatch.shape))
        np.savez(os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz"), npbatch)
        return n_saved + npbatch.shape[0]

    for image in batch:
        pil_image = custom_to_pil(image)
        pil_image.save(os.path.join(path, f"{key}_{n_saved:06}.png"))
        n_saved += 1
    return n_saved
|
|
|
|
|
|
def get_parser():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"-r",
|
|
"--resume",
|
|
type=str,
|
|
nargs="?",
|
|
help="load from logdir or checkpoint in logdir",
|
|
)
|
|
parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000)
|
|
parser.add_argument(
|
|
"-e",
|
|
"--eta",
|
|
type=float,
|
|
nargs="?",
|
|
help="eta for ddim sampling (0.0 yields deterministic sampling)",
|
|
default=1.0,
|
|
)
|
|
parser.add_argument(
|
|
"-v",
|
|
"--vanilla_sample",
|
|
default=False,
|
|
action="store_true",
|
|
help="vanilla sampling (default option is DDIM sampling)?",
|
|
)
|
|
parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none")
|
|
parser.add_argument(
|
|
"-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50
|
|
)
|
|
parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10)
|
|
return parser
|
|
|
|
|
|
def load_model_from_config(config, sd):
    """Instantiate the model from ``config``, load weights, move it to CUDA and set eval mode.

    Args:
        config: Model config passed to ``instantiate_from_config``.
        sd: State dict to load non-strictly, or ``None`` to keep whatever
            weights ``instantiate_from_config`` produced (``load_model``
            passes ``None`` when no checkpoint is given).

    Returns:
        The instantiated model.
    """
    model = instantiate_from_config(config)
    if sd is None:
        # load_state_dict(None) would raise; skip loading instead.
        print("No state dict given; keeping the weights from instantiate_from_config.")
    else:
        # strict=False tolerates key mismatches; report them so a wrong
        # checkpoint does not go unnoticed.
        result = model.load_state_dict(sd, strict=False)
        missing = getattr(result, "missing_keys", [])
        unexpected = getattr(result, "unexpected_keys", [])
        if missing:
            print(f"Missing keys when loading state dict: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys when loading state dict: {len(unexpected)}")
    model.cuda()
    model.eval()
    return model
|
|
|
|
|
|
def load_model(config, ckpt, gpu, eval_mode):
    """Build the model from ``config.model`` and load weights from ``ckpt`` if given.

    Args:
        config: Merged OmegaConf config; its ``model`` section is used.
        ckpt: Checkpoint path, or a falsy value to skip loading weights.
        gpu: Currently unused (kept for API compatibility).
        eval_mode: Currently unused (kept for API compatibility).

    Returns:
        ``(model, global_step)``; ``global_step`` is ``None`` when there is
        no checkpoint or the checkpoint does not record it.
    """
    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        # Plain state dicts (not Lightning checkpoints) have neither key;
        # fall back rather than raising KeyError.
        global_step = pl_sd.get("global_step")
        state_dict = pl_sd.get("state_dict", pl_sd)
    else:
        global_step = None
        state_dict = None

    model = load_model_from_config(config.model, state_dict)
    return model, global_step
|
|
|
|
|
|
if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())

    parser = get_parser()
    # Unrecognised arguments become OmegaConf dotlist overrides below.
    opt, unknown = parser.parse_known_args()
    ckpt = None

    # --resume has no default; os.path.exists(None) would raise TypeError.
    if not opt.resume:
        parser.error("the following arguments are required: -r/--resume")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        # A checkpoint file: its parent directory is the log directory.
        logdir = os.path.dirname(opt.resume)
        print(f"Logdir is {logdir}")
        ckpt = opt.resume
    else:
        if not os.path.isdir(opt.resume):
            raise ValueError(f"{opt.resume} is not a directory")
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    if not base_configs:
        raise FileNotFoundError(f"No config.yaml found in '{logdir}'")
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        # Keep the run directory's name but relocate it under --logdir.
        # abspath() drops trailing separators and resolves "" to the cwd,
        # where the old split(os.sep)[-2] fallback could raise IndexError.
        locallog = os.path.basename(os.path.abspath(logdir))
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    # global_step is None when no step is recorded; format(None, "08") fails.
    step_dir = f"{global_step:08}" if global_step is not None else "unknown_step"
    logdir = os.path.join(logdir, "samples", step_dir, now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # write config out
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, "w", encoding="utf-8") as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(
        model,
        imglogdir,
        eta=opt.eta,
        vanilla=opt.vanilla_sample,
        n_samples=opt.n_samples,
        custom_steps=opt.custom_steps,
        batch_size=opt.batch_size,
        nplog=numpylogdir,
    )

    print("done.")
|