summaryrefslogtreecommitdiff
path: root/media/libjxl/src/tools/benchmark/metrics/iqa.py
blob: 1be9699926b28ccd1c8a58cb5c704186bfdc110e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
# Copyright (c) the JPEG XL Project Authors. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

import os
import sys
import pathlib
import torch
from torchvision import transforms
import numpy as np

# Add the third_party/IQA-optimization directory, located three levels above
# this script's directory, to the module search path before IQA_pytorch is
# imported.
path = pathlib.Path(__file__).parent.absolute().joinpath(
    '..', '..', '..', 'third_party', 'IQA-optimization')
sys.path.append(str(path))

from IQA_pytorch import SSIM, MS_SSIM, CW_SSIM, GMSD, LPIPSvgg, DISTS, NLPD, FSIM, VSI, VIFs, VIF, MAD


# only really works with the output from JXL, but we don't need more than that.
def read_pfm(fname):
    """Read a PFM image file into a float64 NumPy array.

    The header is parsed as three whitespace-separated tokens (magic, width,
    height), which may be spread over several lines, followed by one line
    holding the scale factor. A negative scale selects little-endian samples;
    otherwise samples are read as big-endian. The rows of the returned array
    are in the reverse of the order in which they are stored in the file.

    Args:
        fname: Path of the PFM file to read.

    Returns:
        An array of dtype float64, shaped (height, width, 3) for a b'PF' file
        or (height, width) for a b'Pf' file.

    Raises:
        ValueError: If the header is truncated or malformed, the magic is
            neither b'PF' nor b'Pf', or the number of samples does not match
            the declared dimensions.
    """
    with open(fname, 'rb') as f:
        tokens = []
        while len(tokens) < 3:
            line = f.readline()
            if not line:
                # readline() returns b'' at EOF; without this check a
                # truncated file would loop forever.
                raise ValueError('truncated PFM header in %r' % (fname,))
            tokens += line.rstrip().split()
        header, width, height = tokens
        # Explicit check instead of assert so it survives `python -O`.
        if header not in (b'PF', b'Pf'):
            raise ValueError('unsupported PFM magic %r in %r' %
                             (header, fname))
        width, height = int(width), int(height)
        scale = float(f.readline().rstrip())
        # The sign of the scale encodes the byte order of the samples.
        fmt = '<f' if scale < 0 else '>f'
        data = np.fromfile(f, fmt)
        if header == b'PF':
            out = np.reshape(data, (height, width, 3))[::-1, :, :]
        else:
            out = np.reshape(data, (height, width))[::-1, :]
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        return out.astype(np.float64)


# Lookup table from metric name to the IQA_pytorch class that computes it.
D_dict = dict(
    cwssim=CW_SSIM,
    dists=DISTS,
    fsim=FSIM,
    gmsd=GMSD,
    lpips=LPIPSvgg,
    mad=MAD,
    msssim=MS_SSIM,
    nlpd=NLPD,
    ssim=SSIM,
    vif=VIF,
    vsi=VSI,
)

# The base name of the first argument, cut at its first '.', must have the
# form "<algo>-<color>".
algo, color = os.path.basename(sys.argv[1]).split('.')[0].split('-')

# A color tag of 'y' selects a single channel; anything else means three.
channels = 1 if color == 'y' else 3


def Load(path):
    """Load a PFM image as a float32 tensor shaped (1, channels, H, W).

    The image is read with read_pfm() and converted to the number of channels
    selected by the module-level `channels` value: a three-channel image is
    reduced to one luma plane when `channels` is 1, and a single-plane image
    is replicated into three identical planes when `channels` is 3.

    Args:
        path: Path of the PFM file to load.

    Returns:
        A torch.FloatTensor with a leading batch dimension of size 1, placed
        on the CUDA device when one is available and on the CPU otherwise.

    Raises:
        ValueError: If a 3-D image has to be reduced to luma but its last
            dimension is not 3.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img = read_pfm(path)
    if img.ndim == 3:
        if channels == 1:  # rgb -> Y
            if img.shape[2] != 3:
                raise ValueError(
                    'expected 3 color channels, got %d' % img.shape[2])
            # Weights match the Rec. 709 luma coefficients.
            luma = (0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] +
                    0.0722 * img[:, :, 2])
            chw = luma[np.newaxis, :, :]
        else:
            # (H, W, C) -> (C, H, W)
            chw = np.transpose(img, axes=(2, 0, 1))
    elif channels == 3:  # Y -> rgb
        chw = np.repeat(img[np.newaxis, :, :], 3, axis=0)
    else:
        # A single-plane image used as-is still needs an explicit channel
        # axis so the result is 4-D like every other case.
        chw = img[np.newaxis, :, :]
    chw = np.ascontiguousarray(chw)
    return torch.FloatTensor(chw).unsqueeze(0).to(device)


# Arguments 2 and 3 are the reference and the encoded image; argument 4 is
# the file the score is written to.
ref_img = Load(sys.argv[2])
enc_img = Load(sys.argv[3])
D = D_dict[algo](channels=channels)
# Load() puts the images on the GPU when CUDA is available. Move the metric
# to the same device so any weights or buffers it holds match the inputs.
if isinstance(D, torch.nn.Module):
    D = D.to(ref_img.device)
score = D(ref_img, enc_img, as_loss=False)

with open(sys.argv[4], 'w') as f:
    print(score.item(), file=f)