# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
import argparse
import os

import numpy as np
import torch
from cleanfid import fid
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
from torchvision.transforms import Resize
from tqdm import tqdm


def read_image(path: str):
    """Read the image at `path` and return it as a uint8 tensor of shape (C, H, W)."""
    img = np.asarray(Image.open(path))
    # Grayscale images are repeated along the channel axis so every image has 3 channels.
    if len(img.shape) == 2:
        img = np.repeat(img[:, :, None], 3, axis=2)
    img = torch.from_numpy(img).permute(2, 0, 1)
    return img


class MultiImageDataset(Dataset):
    """Pairs up images from two directories by sorted file name."""

    def __init__(self, root0, root1, is_gt=False):
        super().__init__()
        self.root0 = root0
        self.root1 = root1
        file_names0 = os.listdir(root0)
        file_names1 = os.listdir(root1)

        self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
        self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
        self.is_gt = is_gt
        assert len(self.image_names0) == len(self.image_names1)

    def __len__(self):
        return len(self.image_names0)

    def __getitem__(self, idx):
        img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
        if self.is_gt:
            # Resize ground-truth images to 1024 x 1024 so they match the generated images.
            img0 = Resize((1024, 1024))(img0)
        img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))

        batch_list = [img0, img1]
        return batch_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--is_gt", action="store_true")
    parser.add_argument("--input_root0", type=str, required=True)
    parser.add_argument("--input_root1", type=str, required=True)
    args = parser.parse_args()

    psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
    lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")

    dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    progress_bar = tqdm(dataloader)
    with torch.inference_mode():
        for i, batch in enumerate(progress_bar):
            # Convert uint8 images to float tensors in [0, 1] before updating the metrics.
            batch = [img.to("cuda") / 255 for img in batch]
            batch_size = batch[0].shape[0]
            psnr.update(batch[0], batch[1])
            lpips.update(batch[0], batch[1])
    # FID is computed by clean-fid directly from the two image folders.
    fid_score = fid.compute_fid(args.input_root0, args.input_root1)

    print("PSNR:", psnr.compute().item())
    print("LPIPS:", lpips.compute().item())
    print("FID:", fid_score)
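
# A minimal usage sketch, assuming the file is saved as compute_metrics.py (the name used
# in the upstream distrifuser repo); the directory paths below are hypothetical. Both
# folders are expected to contain the same number of .png/.jpg files, paired by sorted
# file name.
#
#   python compute_metrics.py \
#       --input_root0 /path/to/generated_images \
#       --input_root1 /path/to/reference_images \
#       --batch_size 64 --num_workers 8
#
# Add --is_gt when input_root0 holds ground-truth images that should be resized to
# 1024 x 1024 before PSNR/LPIPS are computed.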