ColossalAI/examples/inference/stable_diffusion/compute_metric.py

# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
import argparse
import os

import numpy as np
import torch
from cleanfid import fid
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
from torchvision.transforms import Resize
from tqdm import tqdm


def read_image(path: str):
    """
    input: path
    output: tensor (C, H, W)
    """
    img = np.asarray(Image.open(path))
    if len(img.shape) == 2:
        img = np.repeat(img[:, :, None], 3, axis=2)
    img = torch.from_numpy(img).permute(2, 0, 1)
    return img


class MultiImageDataset(Dataset):
    def __init__(self, root0, root1, is_gt=False):
        super().__init__()
        self.root0 = root0
        self.root1 = root1
        file_names0 = os.listdir(root0)
        file_names1 = os.listdir(root1)

        self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
        self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
        self.is_gt = is_gt
        assert len(self.image_names0) == len(self.image_names1)

    def __len__(self):
        return len(self.image_names0)

    def __getitem__(self, idx):
        img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
        if self.is_gt:
            # resize to 1024 x 1024
            img0 = Resize((1024, 1024))(img0)
        img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))

        batch_list = [img0, img1]
        return batch_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--is_gt", action="store_true")
    parser.add_argument("--input_root0", type=str, required=True)
    parser.add_argument("--input_root1", type=str, required=True)
    args = parser.parse_args()

    # PSNR is computed per image over the (C, H, W) dims; both metrics expect inputs in [0, 1].
    psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
    lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")

    dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    progress_bar = tqdm(dataloader)
    with torch.inference_mode():
        for i, batch in enumerate(progress_bar):
            # uint8 images -> float tensors scaled to [0, 1]
            batch = [img.to("cuda") / 255 for img in batch]
            batch_size = batch[0].shape[0]
            psnr.update(batch[0], batch[1])
            lpips.update(batch[0], batch[1])

    # FID is computed by clean-fid directly from the two image directories.
    fid_score = fid.compute_fid(args.input_root0, args.input_root1)

    print("PSNR:", psnr.compute().item())
    print("LPIPS:", lpips.compute().item())
    print("FID:", fid_score)