summaryrefslogtreecommitdiff
path: root/media/libjxl/src/tools/benchmark/metrics/iqa.py
diff options
context:
space:
mode:
Diffstat (limited to 'media/libjxl/src/tools/benchmark/metrics/iqa.py')
-rw-r--r--media/libjxl/src/tools/benchmark/metrics/iqa.py90
1 files changed, 90 insertions, 0 deletions
diff --git a/media/libjxl/src/tools/benchmark/metrics/iqa.py b/media/libjxl/src/tools/benchmark/metrics/iqa.py
new file mode 100644
index 0000000000..1be9699926
--- /dev/null
+++ b/media/libjxl/src/tools/benchmark/metrics/iqa.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+# Copyright (c) the JPEG XL Project Authors. All rights reserved.
+#
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+import os
+import sys
+import pathlib
+import torch
+from torchvision import transforms
+import numpy as np
+
+path = pathlib.Path(__file__).parent.absolute(
+) / '..' / '..' / '..' / 'third_party' / 'IQA-optimization'
+sys.path.append(str(path))
+
+from IQA_pytorch import SSIM, MS_SSIM, CW_SSIM, GMSD, LPIPSvgg, DISTS, NLPD, FSIM, VSI, VIFs, VIF, MAD
+
+
+# only really works with the output from JXL, but we don't need more than that.
+def read_pfm(fname):
+ with open(fname, 'rb') as f:
+ header_width_height = []
+ while len(header_width_height) < 3:
+ header_width_height += f.readline().rstrip().split()
+ header, width, height = header_width_height
+ assert header == b'PF' or header == b'Pf'
+ width, height = int(width), int(height)
+ scale = float(f.readline().rstrip())
+ fmt = '<f' if scale < 0 else '>f'
+ data = np.fromfile(f, fmt)
+ if header == b'PF':
+ out = np.reshape(data, (height, width, 3))[::-1, :, :]
+ else:
+ out = np.reshape(data, (height, width))[::-1, :]
+ return out.astype(np.float)
+
+
+D_dict = {
+ 'cwssim': CW_SSIM,
+ 'dists': DISTS,
+ 'fsim': FSIM,
+ 'gmsd': GMSD,
+ 'lpips': LPIPSvgg,
+ 'mad': MAD,
+ 'msssim': MS_SSIM,
+ 'nlpd': NLPD,
+ 'ssim': SSIM,
+ 'vif': VIF,
+ 'vsi': VSI,
+}
+
+algo = os.path.basename(sys.argv[1]).split('.')[0]
+algo, color = algo.split('-')
+
+channels = 3
+
+if color == 'y':
+ channels = 1
+
+
+def Load(path):
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+ img = read_pfm(path)
+ if len(img.shape) == 3 and channels == 1: # rgb -> Y
+ assert img.shape[2] == 3
+ tmp = np.zeros((img.shape[0], img.shape[1], 1), dtype=float)
+ tmp[:, :, 0] = (0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] +
+ 0.0722 * img[:, :, 2])
+ img = tmp
+ if len(img.shape) == 2 and channels == 3: # Y -> rgb
+ gray = img
+ img = np.zeros((img.shape[0], img.shape[1], 3), dtype=float)
+ img[:, :, 0] = img[:, :, 1] = img[:, :, 2] = gray
+ if len(img.shape) == 3:
+ img = np.transpose(img, axes=(2, 0, 1)).copy()
+ return torch.FloatTensor(img).unsqueeze(0).to(device)
+
+
+ref_img = Load(sys.argv[2])
+enc_img = Load(sys.argv[3])
+D = D_dict[algo](channels=channels)
+score = D(ref_img, enc_img, as_loss=False)
+
+with open(sys.argv[4], 'w') as f:
+ print(score.item(), file=f)