summaryrefslogtreecommitdiff
path: root/media/libjxl/src/lib/jxl/optimize_test.cc
blob: c606a035c6c289bcb5daaaeaada1fdd8cff4abd7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/optimize.h"

#include <stdio.h>

#include <algorithm>
#include <cmath>
#include <vector>

#include "gtest/gtest.h"

namespace jxl {
namespace optimize {
namespace {

// Iteration cap handed to the optimizers exercised by the tests in this file.
constexpr size_t kMaxTestIter = 100000;

// F(w) = (w - w_min)^2.
struct SimpleQuadraticFunction {
  typedef Array<double, 2> ArrayType;
  explicit SimpleQuadraticFunction(const ArrayType& w0) : w_min(w0) {}

  double Compute(const ArrayType& w, ArrayType* df) const {
    ArrayType dw = w - w_min;
    *df = -2.0 * dw;
    return dw * dw;
  }

  ArrayType w_min;
};

// Least-squares loss for fitting the model y ~ alpha * x^gamma + beta:
//   F(w | x, y) = \sum_i (y_i - (w[0] * x_i^w[1] + w[2]))^2,
// where w[0] = alpha (scale), w[1] = gamma (exponent), w[2] = beta (offset).
// Samples with x_i == 0 are skipped, which keeps the log(x_i) factor in the
// exponent derivative finite.
struct PowerFunction {
  explicit PowerFunction(const std::vector<double>& x0,
                         const std::vector<double>& y0)
      : x(x0), y(y0) {}

  typedef Array<double, 3> ArrayType;

  // Returns F(w) and stores the negated gradient -dF/dw in *df.
  double Compute(const ArrayType& w, ArrayType* df) const {
    double loss_function = 0;
    (*df)[0] = 0;
    (*df)[1] = 0;
    (*df)[2] = 0;
    // Only pair up indices present in both vectors; this avoids reading past
    // the end of x when x is shorter than y.
    const size_t num_samples = std::min(x.size(), y.size());
    for (size_t ind = 0; ind < num_samples; ++ind) {
      if (x[ind] == 0) continue;
      // x_i^gamma feeds the residual and two gradient terms; compute it once.
      const double x_pow = std::pow(x[ind], w[1]);
      const double l_f = y[ind] - (w[0] * x_pow + w[2]);
      (*df)[0] += 2.0 * l_f * x_pow;
      (*df)[1] += 2.0 * l_f * w[0] * x_pow * std::log(x[ind]);
      (*df)[2] += 2.0 * l_f;
      loss_function += l_f * l_f;
    }
    return loss_function;
  }

  // Sample abscissas.
  std::vector<double> x;
  // Sample ordinates.
  std::vector<double> y;
};

// Starting from the origin, the scaled conjugate gradient optimizer should
// land on the minimizer (1, 2) of the quadratic objective.
TEST(OptimizeTest, SimpleQuadraticFunction) {
  constexpr double kPrecision = 1e-8;

  SimpleQuadraticFunction::ArrayType target;
  target[0] = 1.0;
  target[1] = 2.0;
  SimpleQuadraticFunction objective(target);

  const SimpleQuadraticFunction::ArrayType start(0.);
  const SimpleQuadraticFunction::ArrayType result =
      OptimizeWithScaledConjugateGradientMethod(objective, start, kPrecision,
                                                kMaxTestIter);

  EXPECT_NEAR(result[0], 1.0, kPrecision);
  EXPECT_NEAR(result[1], 2.0, kPrecision);
}

// Samples y = 2 * x^3 + 5 at x = 0..9 and checks that the optimizer, started
// from all zeros, recovers (scale, exponent, offset) = (2, 3, 5).
TEST(OptimizeTest, PowerFunction) {
  constexpr size_t kNumSamples = 10;
  std::vector<double> x(kNumSamples);
  std::vector<double> y(kNumSamples);
  for (size_t i = 0; i < kNumSamples; ++i) {
    x[i] = static_cast<double>(i);
    y[i] = 2. * pow(x[i], 3) + 5.;
  }
  PowerFunction objective(x, y);

  constexpr double kPrecision = 0.01;
  const PowerFunction::ArrayType start(0.);
  const PowerFunction::ArrayType fit =
      OptimizeWithScaledConjugateGradientMethod(objective, start, kPrecision,
                                                kMaxTestIter);

  EXPECT_NEAR(fit[0], 2.0, kPrecision);
  EXPECT_NEAR(fit[1], 3.0, kPrecision);
  EXPECT_NEAR(fit[2], 5.0, kPrecision);
}

// Runs RunSimplex in 2 dimensions on f(p) = 2 + (p0 - 1)^2 + (p1 + 1.5)^2.
// The returned vector is expected to hold the minimum value (2) followed by
// the minimizing point (1, -1.5).
TEST(OptimizeTest, SimplexOptTest) {
  auto objective = [](const std::vector<double>& p) -> double {
    const double d0 = p[0] - 1.0;
    const double d1 = p[1] + 1.5;
    return 2.0 + d0 * d0 + d1 * d1;
  };
  const auto best = RunSimplex(2, 0.01, 100, objective);
  EXPECT_EQ(best.size(), 3u);

  constexpr double kPrecision = 0.01;
  EXPECT_NEAR(best[0], 2.0, kPrecision);
  EXPECT_NEAR(best[1], 1.0, kPrecision);
  EXPECT_NEAR(best[2], -1.5, kPrecision);
}

}  // namespace
}  // namespace optimize
}  // namespace jxl