diff options
Diffstat (limited to 'media/libjxl/src/lib/jxl/dec_ans.cc')
-rw-r--r-- | media/libjxl/src/lib/jxl/dec_ans.cc | 374 |
1 files changed, 374 insertions, 0 deletions
diff --git a/media/libjxl/src/lib/jxl/dec_ans.cc b/media/libjxl/src/lib/jxl/dec_ans.cc new file mode 100644 index 0000000000..a64493237e --- /dev/null +++ b/media/libjxl/src/lib/jxl/dec_ans.cc @@ -0,0 +1,374 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_ans.h" + +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/fields.h" + +namespace jxl { +namespace { + +// Decodes a number in the range [0..255], by reading 1 - 11 bits. +inline int DecodeVarLenUint8(BitReader* input) { + if (input->ReadFixedBits<1>()) { + int nbits = static_cast<int>(input->ReadFixedBits<3>()); + if (nbits == 0) { + return 1; + } else { + return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits); + } + } + return 0; +} + +// Decodes a number in the range [0..65535], by reading 1 - 21 bits. +inline int DecodeVarLenUint16(BitReader* input) { + if (input->ReadFixedBits<1>()) { + int nbits = static_cast<int>(input->ReadFixedBits<4>()); + if (nbits == 0) { + return 1; + } else { + return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits); + } + } + return 0; +} + +Status ReadHistogram(int precision_bits, std::vector<int>* counts, + BitReader* input) { + int simple_code = input->ReadBits(1); + if (simple_code == 1) { + int i; + int symbols[2] = {0}; + int max_symbol = 0; + const int num_symbols = input->ReadBits(1) + 1; + for (i = 0; i < num_symbols; ++i) { + symbols[i] = DecodeVarLenUint8(input); + if (symbols[i] > max_symbol) max_symbol = symbols[i]; + } + counts->resize(max_symbol + 1); + if (num_symbols == 1) { + (*counts)[symbols[0]] = 1 << precision_bits; + } else { + if (symbols[0] == symbols[1]) { // corrupt data + return false; + } + (*counts)[symbols[0]] = input->ReadBits(precision_bits); + (*counts)[symbols[1]] = (1 << precision_bits) - (*counts)[symbols[0]]; + } + } else { + int is_flat = input->ReadBits(1); + if (is_flat == 1) { + int alphabet_size = DecodeVarLenUint8(input) + 1; + *counts = CreateFlatHistogram(alphabet_size, 1 << precision_bits); + return true; + } + + uint32_t shift; + { + // TODO(veluca): speed up reading with table lookups. + int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1); + int log = 0; + for (; log < upper_bound_log; log++) { + if (input->ReadFixedBits<1>() == 0) break; + } + shift = (input->ReadBits(log) | (1 << log)) - 1; + if (shift > ANS_LOG_TAB_SIZE + 1) { + return JXL_FAILURE("Invalid shift value"); + } + } + + int length = DecodeVarLenUint8(input) + 3; + counts->resize(length); + int total_count = 0; + + static const uint8_t huff[128][2] = { + {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + }; + + std::vector<int> logcounts(counts->size()); + int omit_log = -1; + int omit_pos = -1; + // This array remembers which symbols have an RLE length. + std::vector<int> same(counts->size(), 0); + for (size_t i = 0; i < logcounts.size(); ++i) { + input->Refill(); // for PeekFixedBits + Advance + int idx = input->PeekFixedBits<7>(); + input->Consume(huff[idx][0]); + logcounts[i] = huff[idx][1]; + // The RLE symbol. + if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) { + int rle_length = DecodeVarLenUint8(input); + same[i] = rle_length + 5; + i += rle_length + 3; + continue; + } + if (logcounts[i] > omit_log) { + omit_log = logcounts[i]; + omit_pos = i; + } + } + // Invalid input, e.g. due to invalid usage of RLE. + if (omit_pos < 0) return JXL_FAILURE("Invalid histogram."); + if (static_cast<size_t>(omit_pos) + 1 < logcounts.size() && + logcounts[omit_pos + 1] == ANS_TAB_SIZE + 1) { + return JXL_FAILURE("Invalid histogram."); + } + int prev = 0; + int numsame = 0; + for (size_t i = 0; i < logcounts.size(); ++i) { + if (same[i]) { + // RLE sequence, let this loop output the same count for the next + // iterations. + numsame = same[i] - 1; + prev = i > 0 ? (*counts)[i - 1] : 0; + } + if (numsame > 0) { + (*counts)[i] = prev; + numsame--; + } else { + int code = logcounts[i]; + // omit_pos may not be negative at this point (checked before). + if (i == static_cast<size_t>(omit_pos)) { + continue; + } else if (code == 0) { + continue; + } else if (code == 1) { + (*counts)[i] = 1; + } else { + int bitcount = GetPopulationCountPrecision(code - 1, shift); + (*counts)[i] = (1 << (code - 1)) + + (input->ReadBits(bitcount) << (code - 1 - bitcount)); + } + } + total_count += (*counts)[i]; + } + (*counts)[omit_pos] = (1 << precision_bits) - total_count; + if ((*counts)[omit_pos] <= 0) { + // The histogram we've read sums to more than total_count (including at + // least 1 for the omitted value). + return JXL_FAILURE("Invalid histogram count."); + } + } + return true; +} + +} // namespace + +Status DecodeANSCodes(const size_t num_histograms, + const size_t max_alphabet_size, BitReader* in, + ANSCode* result) { + result->degenerate_symbols.resize(num_histograms, -1); + if (result->use_prefix_code) { + JXL_ASSERT(max_alphabet_size <= 1 << PREFIX_MAX_BITS); + result->huffman_data.resize(num_histograms); + std::vector<uint16_t> alphabet_sizes(num_histograms); + for (size_t c = 0; c < num_histograms; c++) { + alphabet_sizes[c] = DecodeVarLenUint16(in) + 1; + if (alphabet_sizes[c] > max_alphabet_size) { + return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]); + } + } + for (size_t c = 0; c < num_histograms; c++) { + if (alphabet_sizes[c] > 1) { + if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) { + if (!in->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for huffman code"); + } + return JXL_FAILURE("Invalid huffman tree number %" PRIuS + ", alphabet size %u", + c, alphabet_sizes[c]); + } + } else { + // 0-bit codes does not require extension tables. + result->huffman_data[c].table_.clear(); + result->huffman_data[c].table_.resize(1u << kHuffmanTableBits); + } + for (const auto& h : result->huffman_data[c].table_) { + if (h.bits <= kHuffmanTableBits) { + result->UpdateMaxNumBits(c, h.value); + } + } + } + } else { + JXL_ASSERT(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE); + result->alias_tables = + AllocateArray(num_histograms * (1 << result->log_alpha_size) * + sizeof(AliasTable::Entry)); + AliasTable::Entry* alias_tables = + reinterpret_cast<AliasTable::Entry*>(result->alias_tables.get()); + for (size_t c = 0; c < num_histograms; ++c) { + std::vector<int> counts; + if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) { + return JXL_FAILURE("Invalid histogram bitstream."); + } + if (counts.size() > max_alphabet_size) { + return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size()); + } + while (!counts.empty() && counts.back() == 0) { + counts.pop_back(); + } + for (size_t s = 0; s < counts.size(); s++) { + if (counts[s] != 0) { + result->UpdateMaxNumBits(c, s); + } + } + // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol. + int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1); + for (int s = 0; s < degenerate_symbol; ++s) { + if (counts[s] != 0) { + degenerate_symbol = -1; + break; + } + } + result->degenerate_symbols[c] = degenerate_symbol; + InitAliasTable(counts, ANS_TAB_SIZE, result->log_alpha_size, + alias_tables + c * (1 << result->log_alpha_size)); + } + } + return true; +} +Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config, + BitReader* br) { + br->Refill(); + size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1)); + size_t msb_in_token = 0, lsb_in_token = 0; + if (split_exponent != log_alpha_size) { + // otherwise, msb/lsb don't matter. + size_t nbits = CeilLog2Nonzero(split_exponent + 1); + msb_in_token = br->ReadBits(nbits); + if (msb_in_token > split_exponent) { + // This could be invalid here already and we need to check this before + // we use its value to read more bits. + return JXL_FAILURE("Invalid HybridUintConfig"); + } + nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1); + lsb_in_token = br->ReadBits(nbits); + } + if (lsb_in_token + msb_in_token > split_exponent) { + return JXL_FAILURE("Invalid HybridUintConfig"); + } + *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token); + return true; +} + +Status DecodeUintConfigs(size_t log_alpha_size, + std::vector<HybridUintConfig>* uint_config, + BitReader* br) { + // TODO(veluca): RLE? + for (size_t i = 0; i < uint_config->size(); i++) { + JXL_RETURN_IF_ERROR( + DecodeUintConfig(log_alpha_size, &(*uint_config)[i], br)); + } + return true; +} + +LZ77Params::LZ77Params() { Bundle::Init(this); } +Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled)); + if (!visitor->Conditional(enabled)) return true; + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096), + BitsOffset(15, 8), 224, &min_symbol)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5), + BitsOffset(8, 9), 3, &min_length)); + return true; +} + +void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) { + HybridUintConfig* cfg = &uint_config[ctx]; + // LZ77 symbols use a different uint config. + if (lz77.enabled && lz77.nonserialized_distance_context != ctx && + symbol >= lz77.min_symbol) { + symbol -= lz77.min_symbol; + cfg = &lz77.length_uint_config; + } + size_t split_token = cfg->split_token; + size_t msb_in_token = cfg->msb_in_token; + size_t lsb_in_token = cfg->lsb_in_token; + size_t split_exponent = cfg->split_exponent; + if (symbol < split_token) { + max_num_bits = std::max(max_num_bits, split_exponent); + return; + } + uint32_t n_extra_bits = + split_exponent - (msb_in_token + lsb_in_token) + + ((symbol - split_token) >> (msb_in_token + lsb_in_token)); + size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1; + max_num_bits = std::max(max_num_bits, total_bits); +} + +Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, + std::vector<uint8_t>* context_map, bool disallow_lz77) { + PROFILER_FUNC; + JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77)); + if (code->lz77.enabled) { + num_contexts++; + JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8, + &code->lz77.length_uint_config, br)); + } + if (code->lz77.enabled && disallow_lz77) { + return JXL_FAILURE("Using LZ77 when explicitly disallowed"); + } + size_t num_histograms = 1; + context_map->resize(num_contexts); + if (num_contexts > 1) { + JXL_RETURN_IF_ERROR(DecodeContextMap(context_map, &num_histograms, br)); + } + code->lz77.nonserialized_distance_context = context_map->back(); + code->use_prefix_code = br->ReadFixedBits<1>(); + if (code->use_prefix_code) { + code->log_alpha_size = PREFIX_MAX_BITS; + } else { + code->log_alpha_size = br->ReadFixedBits<2>() + 5; + } + code->uint_config.resize(num_histograms); + JXL_RETURN_IF_ERROR( + DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br)); + const size_t max_alphabet_size = 1 << code->log_alpha_size; + JXL_RETURN_IF_ERROR( + DecodeANSCodes(num_histograms, max_alphabet_size, br, code)); + // When using LZ77, flat codes might result in valid codestreams with + // histograms that potentially allow very large bit counts. + // TODO(veluca): in principle, a valid codestream might contain a histogram + // that could allow very large numbers of bits that is never used during ANS + // decoding. There's no benefit to doing that, though. + if (!code->lz77.enabled && code->max_num_bits > 32) { + // Just emit a warning as there are many opportunities for false positives. + JXL_WARNING("Histogram can represent numbers that are too large: %" PRIuS + "\n", + code->max_num_bits); + } + return true; +} + +} // namespace jxl |