LCOV - code coverage report
Current view: top level - frmts/nitf - kdtree_vq_cadrg.h (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 200 200 100.0 %
Date: 2026-03-05 10:33:42 Functions: 19 19 100.0 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  NITF Read/Write Library
       4             :  * Purpose:  Specialization of KDTree for CADRG VQ compression
       5             :  * Author:   Even Rouault, even dot rouault at spatialys dot com
       6             :  *
       7             :  **********************************************************************
       8             :  * Copyright (c) 2026, T-Kartor
       9             :  *
      10             :  * SPDX-License-Identifier: MIT
      11             :  ****************************************************************************/
      12             : 
      13             : #ifndef KDTREE_VQ_CADRG_INCLUDED
      14             : #define KDTREE_VQ_CADRG_INCLUDED
      15             : 
      16             : #include "kdtree.h"
      17             : 
      18             : #include <array>
      19             : #include <limits>
      20             : 
      21             : #ifdef __GNUC__
      22             : #pragma GCC diagnostic push
      23             : #pragma GCC diagnostic ignored "-Wold-style-cast"
      24             : #pragma GCC diagnostic ignored "-Weffc++"
      25             : #endif
      26             : #include "../../third_party/libdivide/libdivide.h"
      27             : #ifdef __GNUC__
      28             : #pragma GCC diagnostic pop
      29             : #endif
      30             : 
      31             : #ifdef KDTREE_USE_SSE2
      32             : 
      33             : #include <emmintrin.h>
      34             : #ifdef __SSE4_1__
      35             : #include <smmintrin.h>
      36             : #endif
      37             : #ifdef __AVX2__
      38             : #include <immintrin.h>
      39             : #endif
      40             : 
      41             : namespace
      42             : {
      43    66205824 : inline __m128i blendv_epi8(__m128i a, __m128i b, __m128i mask)
      44             : {
      45             : #ifdef __SSE4_1__
      46             :     return _mm_blendv_epi8(a, b, mask);
      47             : #else
      48   198616872 :     return _mm_or_si128(_mm_andnot_si128(mask, a), _mm_and_si128(mask, b));
      49             : #endif
      50             : }
      51             : }  // namespace
      52             : #endif
      53             : 
      54             : namespace
      55             : {
      56    40775900 : template <class T> T square(T x)
      57             : {
      58    40775900 :     return x * x;
      59             : }
      60             : 
      61             : /************************************************************************/
      62             : /*                            filled_array()                            */
      63             : /************************************************************************/
      64             : 
      65             : template <typename T, std::size_t N>
      66      193532 : constexpr std::array<T, N> filled_array(const T &value)
      67             : {
      68             : #ifdef __COVERITY__
      69             :     std::array<T, N> a{};
      70             : #else
      71             :     std::array<T, N> a;
      72             : #endif
      73     3290046 :     for (auto &x : a)
      74     3096518 :         x = value;
      75      193532 :     return a;
      76             : }
      77             : }  // namespace
      78             : 
      79             : /************************************************************************/
      80             : /*                       ColorTableBased4x4Pixels                       */
      81             : /************************************************************************/
      82             : 
      83             : struct ColorTableBased4x4Pixels
      84             : {
      85             :     static constexpr int COMP_COUNT = 3;
      86          35 :     explicit ColorTableBased4x4Pixels(const std::vector<GByte> &R,
      87             :                                       const std::vector<GByte> &G,
      88             :                                       const std::vector<GByte> &B)
      89          35 :         : m_R(R), m_G(G), m_B(B), m_RGB({&m_R, &m_G, &m_B})
      90             : #if defined(KDTREE_USE_SSE2) && defined(__AVX2__)
      91             :           ,
      92             :           m_RGB32({&m_R32, &m_G32, &m_B32})
      93             : #endif
      94             :     {
      95             : #if defined(KDTREE_USE_SSE2) && defined(__AVX2__)
      96             :         for (size_t i = 0; i < R.size(); ++i)
      97             :         {
      98             :             m_R32.push_back(R[i]);
      99             :             m_G32.push_back(G[i]);
     100             :             m_B32.push_back(B[i]);
     101             :         }
     102             : #endif
     103          35 :     }
     104             : 
     105             :     const std::vector<GByte> &m_R;
     106             :     const std::vector<GByte> &m_G;
     107             :     const std::vector<GByte> &m_B;
     108             :     std::array<const std::vector<GByte> *const, COMP_COUNT> m_RGB;
     109             : #if defined(KDTREE_USE_SSE2) && defined(__AVX2__)
     110             :     std::vector<int32_t> m_R32{}, m_G32{}, m_B32{};
     111             :     std::array<const std::vector<int32_t> *const, COMP_COUNT> m_RGB32;
     112             : #endif
     113             : };
     114             : 
     115             : /************************************************************************/
     116             : /*                   Vector<ColorTableBased4x4Pixels>                   */
     117             : /************************************************************************/
     118             : 
     119             : template <> class Vector<ColorTableBased4x4Pixels>
     120             : {
     121             :   public:
     122             :     static constexpr int PIX_COUNT = 4 * 4;
     123             : 
     124             :   private:
     125             :     static constexpr int COMP_COUNT = ColorTableBased4x4Pixels::COMP_COUNT;
     126             : 
     127             :     std::array<GByte, PIX_COUNT> m_vals;
     128             : 
     129             :     // cppcheck-suppress uninitMemberVarPrivate
     130             :     Vector() = default;
     131             : 
     132             :   public:
     133     1409109 :     explicit Vector(const std::array<GByte, PIX_COUNT> &vals) : m_vals(vals)
     134             :     {
     135     1409109 :     }
     136             : 
     137             :     static constexpr int DIM_COUNT /* specialize */ = COMP_COUNT * PIX_COUNT;
     138             : 
     139     2097152 :     inline GByte val(int i) const
     140             :     {
     141     2097152 :         return m_vals[i];
     142             :     }
     143             : 
     144           1 :     inline GByte *vals()
     145             :     {
     146           1 :         return m_vals.data();
     147             :     }
     148             : 
     149      135140 :     inline const std::array<GByte, PIX_COUNT> &vals() const
     150             :     {
     151      135140 :         return m_vals;
     152             :     }
     153             : 
     154             :     static constexpr bool getReturnUInt8 /* specialize */ = true;
     155             : 
     156   224675756 :     inline int get(int i,
     157             :                    const ColorTableBased4x4Pixels &ctxt) const /* specialize */
     158             :     {
     159   224675756 :         return (*ctxt.m_RGB[i / PIX_COUNT])[m_vals[i % PIX_COUNT]];
     160             :     }
     161             : 
     162             : #if defined(KDTREE_USE_SSE2)
     163             :     static constexpr bool hasComputeFourSquaredDistances /* specialize */ =
     164             :         true;
     165             : 
     166             : #if defined(__SSE4_1__) && defined(__GNUC__)
     167             :     static constexpr bool hasComputeHeightSumAndSumSquareSSE2 /* specialize */ =
     168             :         true;
     169             : 
     170             :     /************************************************************************/
     171             :     /*                    computeHeightSumAndSumSquareSSE2()                  */
     172             :     /************************************************************************/
     173             : 
     174             :     inline void computeHeightSumAndSumSquareSSE2(
     175             :         int k, const ColorTableBased4x4Pixels &ctxt, int count, __m128i &sum0,
     176             :         __m128i &sumSquare0_lo, __m128i &sumSquare0_hi, __m128i &sum1,
     177             :         __m128i &sumSquare1_lo, __m128i &sumSquare1_hi) const
     178             :     {
     179             : #if defined(__AVX2__)
     180             :         const int32_t *comp_data = ctxt.m_RGB32[k / PIX_COUNT]->data();
     181             :         const GByte *pindices = m_vals.data() + (k % PIX_COUNT);
     182             :         const auto idx = _mm256_cvtepu8_epi32(
     183             :             _mm_loadl_epi64(reinterpret_cast<const __m128i *>(pindices)));
     184             :         constexpr int SCALE = static_cast<int>(sizeof(*comp_data));
     185             :         const auto vals = _mm256_i32gather_epi32(comp_data, idx, SCALE);
     186             :         const auto vcount = _mm256_set1_epi32(count);
     187             :         const auto vals_mul_count = _mm256_mullo_epi32(vals, vcount);
     188             :         sum0 = _mm_add_epi32(sum0, _mm256_castsi256_si128(vals_mul_count));
     189             :         sum1 = _mm_add_epi32(sum1, _mm256_extracti128_si256(vals_mul_count, 1));
     190             :         const auto vals_sq_mul_count = _mm256_mullo_epi32(vals, vals_mul_count);
     191             :         const auto vals0_sq_mul_count =
     192             :             _mm256_castsi256_si128(vals_sq_mul_count);
     193             :         const auto vals1_sq_mul_count =
     194             :             _mm256_extracti128_si256(vals_sq_mul_count, 1);
     195             : #else
     196             :         const GByte *comp_data = ctxt.m_RGB[k / PIX_COUNT]->data();
     197             :         const GByte *pindices = m_vals.data() + (k % PIX_COUNT);
     198             :         const auto i32_from_epu8_gather_epu8 =
     199             :             [](const GByte *base_addr, const GByte *pindices)
     200             :         {
     201             :             return _mm_setr_epi32(
     202             :                 base_addr[pindices[0]], base_addr[pindices[1]],
     203             :                 base_addr[pindices[2]], base_addr[pindices[3]]);
     204             :         };
     205             :         const auto vals0 = i32_from_epu8_gather_epu8(comp_data, pindices + 0);
     206             :         const auto vals1 = i32_from_epu8_gather_epu8(comp_data, pindices + 4);
     207             :         const auto vcount = _mm_set1_epi32(count);
     208             :         const auto vals0_mul_count = _mm_mullo_epi32(vals0, vcount);
     209             :         const auto vals1_mul_count = _mm_mullo_epi32(vals1, vcount);
     210             :         sum0 = _mm_add_epi32(sum0, vals0_mul_count);
     211             :         sum1 = _mm_add_epi32(sum1, vals1_mul_count);
     212             :         const auto vals0_sq_mul_count = _mm_mullo_epi32(vals0, vals0_mul_count);
     213             :         const auto vals1_sq_mul_count = _mm_mullo_epi32(vals1, vals1_mul_count);
     214             : #endif
     215             :         sumSquare0_lo = _mm_add_epi64(
     216             :             sumSquare0_lo,
     217             :             _mm_unpacklo_epi32(vals0_sq_mul_count, _mm_setzero_si128()));
     218             :         sumSquare0_hi = _mm_add_epi64(
     219             :             sumSquare0_hi,
     220             :             _mm_unpackhi_epi32(vals0_sq_mul_count, _mm_setzero_si128()));
     221             :         sumSquare1_lo = _mm_add_epi64(
     222             :             sumSquare1_lo,
     223             :             _mm_unpacklo_epi32(vals1_sq_mul_count, _mm_setzero_si128()));
     224             :         sumSquare1_hi = _mm_add_epi64(
     225             :             sumSquare1_hi,
     226             :             _mm_unpackhi_epi32(vals1_sq_mul_count, _mm_setzero_si128()));
     227             :     }
     228             : 
     229             : #else
     230             :     static constexpr bool hasComputeHeightSumAndSumSquareSSE2 /* specialize */ =
     231             :         false;
     232             : #endif
     233             : 
     234             :   private:
     235             :     /************************************************************************/
     236             :     /*                           gatherRGB_epi16()                          */
     237             :     /************************************************************************/
     238             : 
     239     3766864 :     static inline void gatherRGB_epi16(const GByte *indices,
     240             :                                        const ColorTableBased4x4Pixels &ctxt,
     241             :                                        __m128i &r, __m128i &g, __m128i &b)
     242             :     {
     243     3766864 :         const uint8_t i0 = indices[0];
     244     3766864 :         const uint8_t i1 = indices[1];
     245     3766864 :         const uint8_t i2 = indices[2];
     246     3766864 :         const uint8_t i3 = indices[3];
     247     3766864 :         const uint8_t i4 = indices[4];
     248     3766864 :         const uint8_t i5 = indices[5];
     249     3766864 :         const uint8_t i6 = indices[6];
     250     3766864 :         const uint8_t i7 = indices[7];
     251             : 
     252     3766864 :         r = _mm_setr_epi16(ctxt.m_R[i0], ctxt.m_R[i1], ctxt.m_R[i2],
     253     3766864 :                            ctxt.m_R[i3], ctxt.m_R[i4], ctxt.m_R[i5],
     254     3766864 :                            ctxt.m_R[i6], ctxt.m_R[i7]);
     255             : 
     256     3766864 :         g = _mm_setr_epi16(ctxt.m_G[i0], ctxt.m_G[i1], ctxt.m_G[i2],
     257     3766864 :                            ctxt.m_G[i3], ctxt.m_G[i4], ctxt.m_G[i5],
     258     3766864 :                            ctxt.m_G[i6], ctxt.m_G[i7]);
     259             : 
     260     3766864 :         b = _mm_setr_epi16(ctxt.m_B[i0], ctxt.m_B[i1], ctxt.m_B[i2],
     261     3766864 :                            ctxt.m_B[i3], ctxt.m_B[i4], ctxt.m_B[i5],
     262     3766864 :                            ctxt.m_B[i6], ctxt.m_B[i7]);
     263     3766864 :     }
     264             : 
     265             :     /************************************************************************/
     266             :     /*                            updateSums()                              */
     267             :     /************************************************************************/
     268             : 
     269      988328 :     static inline void updateSums(const Vector *other, int i,
     270             :                                   const ColorTableBased4x4Pixels &ctxt,
     271             :                                   __m128i rA, __m128i gA, __m128i bA,
     272             :                                   __m128i &acc)
     273             :     {
     274             :         __m128i rB, gB, bB;
     275      988328 :         gatherRGB_epi16(other->m_vals.data() + i, ctxt, rB, gB, bB);
     276             : 
     277             :         // Compute signed differences
     278      988328 :         const auto diffR = _mm_sub_epi16(rA, rB);
     279      988328 :         const auto diffG = _mm_sub_epi16(gA, gB);
     280     1976652 :         const auto diffB = _mm_sub_epi16(bA, bB);
     281             : 
     282             :         // Square differences
     283      988328 :         const auto sqR = _mm_mullo_epi16(diffR, diffR);
     284      988328 :         const auto sqG = _mm_mullo_epi16(diffG, diffG);
     285      988328 :         const auto sqB = _mm_mullo_epi16(diffB, diffB);
     286             : 
     287             :         // Extend to 32 bit before summing R,G,B
     288     1976652 :         const auto sqR_lo = _mm_unpacklo_epi16(sqR, _mm_setzero_si128());
     289     1976652 :         const auto sqR_hi = _mm_unpackhi_epi16(sqR, _mm_setzero_si128());
     290     1976652 :         const auto sqG_lo = _mm_unpacklo_epi16(sqG, _mm_setzero_si128());
     291     1976652 :         const auto sqG_hi = _mm_unpackhi_epi16(sqG, _mm_setzero_si128());
     292     1976652 :         const auto sqB_lo = _mm_unpacklo_epi16(sqB, _mm_setzero_si128());
     293      988328 :         const auto sqB_hi = _mm_unpackhi_epi16(sqB, _mm_setzero_si128());
     294             : 
     295             :         // Sum RGB
     296      988328 :         acc = _mm_add_epi32(acc, sqR_lo);
     297      988328 :         acc = _mm_add_epi32(acc, sqR_hi);
     298      988328 :         acc = _mm_add_epi32(acc, sqG_lo);
     299      988328 :         acc = _mm_add_epi32(acc, sqG_hi);
     300      988328 :         acc = _mm_add_epi32(acc, sqB_lo);
     301      988328 :         acc = _mm_add_epi32(acc, sqB_hi);
     302      988328 :     }
     303             : 
     304             :   public:
     305             :     /************************************************************************/
     306             :     /*                   compute_four_squared_distances()                   */
     307             :     /************************************************************************/
     308             : 
     309      123541 :     void compute_four_squared_distances(
     310             :         const std::array<const Vector *const, 4> &others,
     311             :         std::array<int, 4> & /* out */ tabSquaredDist,
     312             :         const ColorTableBased4x4Pixels &ctxt) const
     313             :     {
     314      123541 :         __m128i acc_0 = _mm_setzero_si128();
     315      123541 :         __m128i acc_1 = _mm_setzero_si128();
     316      123541 :         __m128i acc_2 = _mm_setzero_si128();
     317      123541 :         __m128i acc_3 = _mm_setzero_si128();
     318             : 
     319      370623 :         for (int i = 0; i < 16; i += 8)
     320             :         {
     321             :             __m128i rA, gA, bA;
     322      247082 :             gatherRGB_epi16(m_vals.data() + i, ctxt, rA, gA, bA);
     323             : 
     324      247082 :             updateSums(others[0], i, ctxt, rA, gA, bA, acc_0);
     325      247082 :             updateSums(others[1], i, ctxt, rA, gA, bA, acc_1);
     326      247082 :             updateSums(others[2], i, ctxt, rA, gA, bA, acc_2);
     327      247082 :             updateSums(others[3], i, ctxt, rA, gA, bA, acc_3);
     328             :         }
     329             : 
     330      494164 :         const auto horizontalSum = [](__m128i acc)
     331             :         {
     332             :             // Horizontal reduction 4 => 1
     333      494164 :             auto tmp = _mm_shuffle_epi32(acc, _MM_SHUFFLE(1, 0, 3, 2));
     334      494164 :             acc = _mm_add_epi32(acc, tmp);
     335      494164 :             tmp = _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1));
     336      494164 :             acc = _mm_add_epi32(acc, tmp);
     337      494164 :             return _mm_cvtsi128_si32(acc);
     338             :         };
     339             : 
     340      123541 :         tabSquaredDist[0] = horizontalSum(acc_0);
     341      123541 :         tabSquaredDist[1] = horizontalSum(acc_1);
     342      123541 :         tabSquaredDist[2] = horizontalSum(acc_2);
     343      123541 :         tabSquaredDist[3] = horizontalSum(acc_3);
     344      123541 :     }
     345             : 
     346             : #else
     347             :     static constexpr bool hasComputeFourSquaredDistances /* specialize */ =
     348             :         false;
     349             : #endif
     350             : 
     351             :     /************************************************************************/
     352             :     /*                          squared_distance()                          */
     353             :     /************************************************************************/
     354             : 
     355      632864 :     int squared_distance(
     356             :         const Vector &other,
     357             :         const ColorTableBased4x4Pixels &ctxt) const /* specialize */
     358             :     {
     359             : #if defined(KDTREE_USE_SSE2) && !defined(__AVX2__)
     360      632864 :         __m128i acc0 = _mm_setzero_si128();
     361      632864 :         __m128i acc1 = _mm_setzero_si128();
     362             : 
     363     1898593 :         for (int i = 0; i < 2; ++i)
     364             :         {
     365             :             __m128i rA, gA, bA;
     366     1265732 :             gatherRGB_epi16(m_vals.data() + i * 8, ctxt, rA, gA, bA);
     367             : 
     368             :             __m128i rB, gB, bB;
     369     1265732 :             gatherRGB_epi16(other.m_vals.data() + i * 8, ctxt, rB, gB, bB);
     370             : 
     371             :             // Compute signed differences
     372     1265732 :             const auto diffR = _mm_sub_epi16(rA, rB);
     373     1265732 :             const auto diffG = _mm_sub_epi16(gA, gB);
     374     2531454 :             const auto diffB = _mm_sub_epi16(bA, bB);
     375             : 
     376             :             // Square differences
     377     1265732 :             const auto sqR = _mm_mullo_epi16(diffR, diffR);
     378     1265732 :             const auto sqG = _mm_mullo_epi16(diffG, diffG);
     379     1265732 :             const auto sqB = _mm_mullo_epi16(diffB, diffB);
     380             : 
     381             :             // Extend to 32 bit before summing R,G,B
     382     2531454 :             const auto sqR_lo = _mm_unpacklo_epi16(sqR, _mm_setzero_si128());
     383     2531454 :             const auto sqR_hi = _mm_unpackhi_epi16(sqR, _mm_setzero_si128());
     384     2531454 :             const auto sqG_lo = _mm_unpacklo_epi16(sqG, _mm_setzero_si128());
     385     2531454 :             const auto sqG_hi = _mm_unpackhi_epi16(sqG, _mm_setzero_si128());
     386     2531454 :             const auto sqB_lo = _mm_unpacklo_epi16(sqB, _mm_setzero_si128());
     387     2531454 :             const auto sqB_hi = _mm_unpackhi_epi16(sqB, _mm_setzero_si128());
     388             : 
     389             :             // Sum RGB
     390     1265732 :             acc0 = _mm_add_epi32(acc0, sqR_lo);
     391     1265732 :             acc1 = _mm_add_epi32(acc1, sqR_hi);
     392     1265732 :             acc0 = _mm_add_epi32(acc0, sqG_lo);
     393     1265732 :             acc1 = _mm_add_epi32(acc1, sqG_hi);
     394     1265732 :             acc0 = _mm_add_epi32(acc0, sqB_lo);
     395     1265732 :             acc1 = _mm_add_epi32(acc1, sqB_hi);
     396             :         }
     397             : 
     398             :         // Horizontal reduction 4 => 1
     399      632864 :         auto acc = _mm_add_epi32(acc0, acc1);
     400      632864 :         auto tmp = _mm_shuffle_epi32(acc, _MM_SHUFFLE(1, 0, 3, 2));
     401      632864 :         acc = _mm_add_epi32(acc, tmp);
     402      632864 :         tmp = _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1));
     403      632864 :         acc = _mm_add_epi32(acc, tmp);
     404      632864 :         return _mm_cvtsi128_si32(acc);
     405             : 
     406             : #elif defined(KDTREE_USE_SSE2) && defined(__AVX2__)
     407             :         const auto idxA =
     408             :             _mm_loadu_si128(reinterpret_cast<const __m128i *>(m_vals.data()));
     409             :         const auto idxB = _mm_loadu_si128(
     410             :             reinterpret_cast<const __m128i *>(other.m_vals.data()));
     411             : 
     412             :         // Convert from 16 uint8_t values into 2 vectors of 8 int32_t
     413             :         const auto idxA_lo = _mm256_cvtepu8_epi32(idxA);
     414             :         const auto idxB_lo = _mm256_cvtepu8_epi32(idxB);
     415             :         const auto idxA_hi = _mm256_cvtepu8_epi32(_mm_srli_si128(idxA, 8));
     416             :         const auto idxB_hi = _mm256_cvtepu8_epi32(_mm_srli_si128(idxB, 8));
     417             : 
     418             :         // Gather R, G, B for A and B (8 at a time)
     419             :         const auto gather_epi32 = [](const std::vector<int> &v, __m256i idx)
     420             :         {
     421             :             constexpr int SCALE = static_cast<int>(sizeof(int32_t));
     422             :             return _mm256_i32gather_epi32(v.data(), idx, SCALE);
     423             :         };
     424             :         const auto rA_lo = gather_epi32(ctxt.m_R32, idxA_lo);
     425             :         const auto rB_lo = gather_epi32(ctxt.m_R32, idxB_lo);
     426             :         const auto gA_lo = gather_epi32(ctxt.m_G32, idxA_lo);
     427             :         const auto gB_lo = gather_epi32(ctxt.m_G32, idxB_lo);
     428             :         const auto bA_lo = gather_epi32(ctxt.m_B32, idxA_lo);
     429             :         const auto bB_lo = gather_epi32(ctxt.m_B32, idxB_lo);
     430             : 
     431             :         const auto rA_hi = gather_epi32(ctxt.m_R32, idxA_hi);
     432             :         const auto rB_hi = gather_epi32(ctxt.m_R32, idxB_hi);
     433             :         const auto gA_hi = gather_epi32(ctxt.m_G32, idxA_hi);
     434             :         const auto gB_hi = gather_epi32(ctxt.m_G32, idxB_hi);
     435             :         const auto bA_hi = gather_epi32(ctxt.m_B32, idxA_hi);
     436             :         const auto bB_hi = gather_epi32(ctxt.m_B32, idxB_hi);
     437             : 
     438             :         // Compute square of differences
     439             :         const auto square_epi32 = [](__m256i x)
     440             :         { return _mm256_mullo_epi32(x, x); };
     441             : 
     442             :         const auto dr_lo = square_epi32(_mm256_sub_epi32(rA_lo, rB_lo));
     443             :         const auto dg_lo = square_epi32(_mm256_sub_epi32(gA_lo, gB_lo));
     444             :         const auto db_lo = square_epi32(_mm256_sub_epi32(bA_lo, bB_lo));
     445             : 
     446             :         const auto dr_hi = square_epi32(_mm256_sub_epi32(rA_hi, rB_hi));
     447             :         const auto dg_hi = square_epi32(_mm256_sub_epi32(gA_hi, gB_hi));
     448             :         const auto db_hi = square_epi32(_mm256_sub_epi32(bA_hi, bB_hi));
     449             : 
     450             :         // Sum RGB
     451             :         const auto sum_lo =
     452             :             _mm256_add_epi32(_mm256_add_epi32(dr_lo, dg_lo), db_lo);
     453             :         const auto sum_hi =
     454             :             _mm256_add_epi32(_mm256_add_epi32(dr_hi, dg_hi), db_hi);
     455             : 
     456             :         // Horizontal reduction 16 => 8
     457             :         const auto sum8 = _mm256_add_epi32(sum_lo, sum_hi);
     458             : 
     459             :         // Horizontal reduction 8 => 4
     460             :         const auto sum8_lo = _mm256_castsi256_si128(sum8);
     461             :         const auto sum8_hi = _mm256_extracti128_si256(sum8, 1);
     462             :         auto sum = _mm_add_epi32(sum8_lo, sum8_hi);
     463             : 
     464             :         // Horizontal reduction 4 => 1
     465             :         auto tmp = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
     466             :         sum = _mm_add_epi32(sum, tmp);
     467             :         tmp = _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1));
     468             :         sum = _mm_add_epi32(sum, tmp);
     469             : 
     470             :         return _mm_cvtsi128_si32(sum);
     471             : 
     472             : #else
     473             :         int nSqDist1 = 0;
     474             :         int nSqDist2 = 0;
     475             :         int nSqDist3 = 0;
     476             :         for (int i = 0; i < PIX_COUNT; ++i)
     477             :         {
     478             :             const int aEntry = m_vals[i];
     479             :             const int bEntry = other.m_vals[i];
     480             :             nSqDist1 += square(ctxt.m_R[aEntry] - ctxt.m_R[bEntry]);
     481             :             nSqDist2 += square(ctxt.m_G[aEntry] - ctxt.m_G[bEntry]);
     482             :             nSqDist3 += square(ctxt.m_B[aEntry] - ctxt.m_B[bEntry]);
     483             :         }
     484             :         return nSqDist1 + nSqDist2 + nSqDist3;
     485             : #endif
     486             :     }
     487             : 
     488             :     /************************************************************************/
     489             :     /*                              centroid()                              */
     490             :     /************************************************************************/
     491             : 
     492             :     static Vector
     493      106111 :     centroid(const Vector &a, int nA, const Vector &b, int nB,
     494             :              const ColorTableBased4x4Pixels &ctxt) /* specialize */
     495             :     {
                               // Computes the centroid of vectors a and b, weighted by nA and nB,
                               // and expresses it again as a vector of color table indices.
                               // For each of the PIX_COUNT positions k:
                               //  1. the R, G and B values of the color table entries referenced
                               //     by a.m_vals[k] and b.m_vals[k] are averaged with weights nA
                               //     and nB (rounded to nearest by adding (nA + nB) / 2 before
                               //     dividing by nA + nB);
                               //  2. the color table of ctxt is scanned for the entry minimizing
                               //     the sum of the squared per-component differences (each one
                               //     shifted right by BIT_SHIFT) to that mean, and its index is
                               //     stored in res.m_vals[k].
                               // On equal distances the lowest color table index is retained.
                               //
                               // a, b:   vectors of color table indices.
                               // nA, nB: weights of a and b. NOTE(review): nA + nB is used as a
                               //         divisor, so callers presumably guarantee nA + nB > 0.
                               // ctxt:   provides the m_R / m_G / m_B color table arrays.
                               // Returns the vector of the best matching color table indices.

                               // NOTE(review): presumably larger than any reachable distance,
                               // assuming COMP_COUNT is the number of color components - confirm.
     496      106111 :         auto minSqDist = filled_array<int, PIX_COUNT>(256 * 256 * COMP_COUNT);
     497             :         Vector res;
                               // nA + nB is the divisor of every mean computed below: the
                               // libdivide divider is built once and reused.
     498      106111 :         libdivide::divider<uint32_t> divisor(static_cast<uint32_t>(nA + nB));
     499     1803891 :         for (int k = 0; k < PIX_COUNT; ++k)
     500             :         {
     501     1697778 :             const int aEntry = a.m_vals[k];
     502     1697778 :             const int bEntry = b.m_vals[k];
                                   // Weighted mean of each color component; the (nA + nB) / 2
                                   // term turns the truncating division into a rounding one.
     503             :             const int meanR =
     504     1697778 :                 static_cast<uint32_t>(ctxt.m_R[aEntry] * nA +
     505     1697778 :                                       ctxt.m_R[bEntry] * nB + (nA + nB) / 2) /
     506     1697778 :                 divisor;
     507             :             const int meanG =
     508     1697778 :                 static_cast<uint32_t>(ctxt.m_G[aEntry] * nA +
     509     1697778 :                                       ctxt.m_G[bEntry] * nB + (nA + nB) / 2) /
     510     1697778 :                 divisor;
     511             :             const int meanB =
     512     1697778 :                 static_cast<uint32_t>(ctxt.m_B[aEntry] * nA +
     513     1697778 :                                       ctxt.m_B[bEntry] * nB + (nA + nB) / 2) /
     514     1697778 :                 divisor;
     515             : 
                                   // The SIMD code below relies on the means fitting in 8 bits
                                   // (they are broadcast to 16-bit lanes and squared there).
     516     1697778 :             assert(meanR <= 255);
     517     1697778 :             assert(meanG <= 255);
     518     1697778 :             assert(meanB <= 255);
     519             : 
                                   // BIT_SHIFT is applied to each squared difference by both the
                                   // SIMD and the scalar code, so that their distances can be
                                   // compared with each other.
     520             : #ifdef PRECISE_DISTANCE_COMPUTATION
     521             :             constexpr int BIT_SHIFT = 0;
     522             : #else
     523             :             // Minimum value to avoid int16 overflow when adding 3 squares of
     524             :             // uint8, because 3 * ((255 * 255) >> 3) = 24384 < INT16_MAX
     525     1697778 :             constexpr int BIT_SHIFT = 3;
     526             : #endif
     527             : 
                                   // i is shared by the SIMD loop and the scalar loop: the latter
                                   // resumes where the former stopped.
     528     1697778 :             int i = 0;
     529             : #if defined(KDTREE_USE_SSE2)
                                   // Broadcast each mean to the 8 x 16-bit lanes of a register
     530     1697778 :             const auto targetR = _mm_set1_epi16(static_cast<short>(meanR));
     531     1697778 :             const auto targetG = _mm_set1_epi16(static_cast<short>(meanG));
     532     1697778 :             const auto targetB = _mm_set1_epi16(static_cast<short>(meanB));
     533             : 
     534             :             // Initialize min distance vectors with the max value of their
                                   // lane type (int32 in precise mode, int16 otherwise)
     535             : #ifdef PRECISE_DISTANCE_COMPUTATION
     536             :             auto minDistVec0 = _mm_set1_epi32(std::numeric_limits<int>::max());
     537             :             auto minDistVec1 = _mm_set1_epi32(std::numeric_limits<int>::max());
     538             :             auto minDistVec2 = _mm_set1_epi32(std::numeric_limits<int>::max());
     539             :             auto minDistVec3 = _mm_set1_epi32(std::numeric_limits<int>::max());
     540             : #else
     541             :             auto minDistVec0 =
     542     1697778 :                 _mm_set1_epi16(std::numeric_limits<short>::max());
     543             :             auto minDistVec1 =
     544     3395556 :                 _mm_set1_epi16(std::numeric_limits<short>::max());
     545             : #endif
     546             : 
     547             :             // Initialize index vectors for tracking best index per lane
                                   // NOTE(review): lanes are 8-bit wide, so this assumes that color
                                   // table indices fit in a byte - confirm the table has at most
                                   // 256 entries.
     548     1697778 :             auto idx = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
     549             :                                      13, 14, 15);
     550     1697778 :             const auto zero = _mm_setzero_si128();
     551     1697778 :             auto idxMin = zero;
     552             : 
                                   // Low 16 bits of x * x for each 16-bit lane. Differences are
                                   // within [-255, 255], so a square is at most 65025 and fits in
                                   // an unsigned 16-bit lane.
     553   132411248 :             const auto square_lo_epi16 = [](__m128i x)
     554   132411248 :             { return _mm_mullo_epi16(x, x); };
     555             : 
                                   // Main SIMD loop: 16 color table entries per iteration. The
                                   // remaining entries are left to the scalar loop further down.
     556     1697778 :             constexpr int VALS_AT_ONCE = 16;
     557    23766416 :             for (; i + VALS_AT_ONCE <= static_cast<int>(ctxt.m_R.size());
     558    22068608 :                  i += VALS_AT_ONCE)
     559             :             {
     560             :                 // Load 16 color components
     561    22068608 :                 const auto valsR = _mm_loadu_si128(
     562    22068608 :                     reinterpret_cast<const __m128i *>(ctxt.m_R.data() + i));
     563    22068608 :                 const auto valsG = _mm_loadu_si128(
     564    22068608 :                     reinterpret_cast<const __m128i *>(ctxt.m_G.data() + i));
     565    22068608 :                 const auto valsB = _mm_loadu_si128(
     566    22068608 :                     reinterpret_cast<const __m128i *>(ctxt.m_B.data() + i));
     567             : 
                                       // Zero-extend the 16 bytes of each component into two
                                       // registers of 8 x 16-bit lanes (entries 0-7 and 8-15)
     568    22068608 :                 const auto valsR_lo = _mm_unpacklo_epi8(valsR, zero);
     569    22068608 :                 const auto valsR_hi = _mm_unpackhi_epi8(valsR, zero);
     570    22068608 :                 const auto valsG_lo = _mm_unpacklo_epi8(valsG, zero);
     571    22068608 :                 const auto valsG_hi = _mm_unpackhi_epi8(valsG, zero);
     572    22068608 :                 const auto valsB_lo = _mm_unpacklo_epi8(valsB, zero);
     573    22068608 :                 const auto valsB_hi = _mm_unpackhi_epi8(valsB, zero);
     574             : 
     575             :                 // Compute signed differences
     576    22068608 :                 const auto diffR_lo = _mm_sub_epi16(valsR_lo, targetR);
     577    22068608 :                 const auto diffR_hi = _mm_sub_epi16(valsR_hi, targetR);
     578    22068608 :                 const auto diffG_lo = _mm_sub_epi16(valsG_lo, targetG);
     579    22068608 :                 const auto diffG_hi = _mm_sub_epi16(valsG_hi, targetG);
     580    22068608 :                 const auto diffB_lo = _mm_sub_epi16(valsB_lo, targetB);
     581    22068608 :                 const auto diffB_hi = _mm_sub_epi16(valsB_hi, targetB);
     582             : 
     583             :                 // Square differences
     584    22068608 :                 const auto sqR_lo = square_lo_epi16(diffR_lo);
     585    22068608 :                 const auto sqR_hi = square_lo_epi16(diffR_hi);
     586    22068608 :                 const auto sqG_lo = square_lo_epi16(diffG_lo);
     587    22068608 :                 const auto sqG_hi = square_lo_epi16(diffG_hi);
     588    22068608 :                 const auto sqB_lo = square_lo_epi16(diffB_lo);
     589    22068608 :                 const auto sqB_hi = square_lo_epi16(diffB_hi);
     590             : 
     591             : #ifdef PRECISE_DISTANCE_COMPUTATION
     592             :                 // Convert squares from 16-bit to 32-bit integers
     593             :                 const auto sqR0 = _mm_unpacklo_epi16(sqR_lo, zero);
     594             :                 const auto sqR1 = _mm_unpackhi_epi16(sqR_lo, zero);
     595             :                 const auto sqR2 = _mm_unpacklo_epi16(sqR_hi, zero);
     596             :                 const auto sqR3 = _mm_unpackhi_epi16(sqR_hi, zero);
     597             : 
     598             :                 const auto sqG0 = _mm_unpacklo_epi16(sqG_lo, zero);
     599             :                 const auto sqG1 = _mm_unpackhi_epi16(sqG_lo, zero);
     600             :                 const auto sqG2 = _mm_unpacklo_epi16(sqG_hi, zero);
     601             :                 const auto sqG3 = _mm_unpackhi_epi16(sqG_hi, zero);
     602             : 
     603             :                 const auto sqB0 = _mm_unpacklo_epi16(sqB_lo, zero);
     604             :                 const auto sqB1 = _mm_unpackhi_epi16(sqB_lo, zero);
     605             :                 const auto sqB2 = _mm_unpacklo_epi16(sqB_hi, zero);
     606             :                 const auto sqB3 = _mm_unpackhi_epi16(sqB_hi, zero);
     607             : 
     608             :                 // Sum squared differences for each 32-bit lane: (R + G + B)
     609             :                 const auto dist0 =
     610             :                     _mm_add_epi32(_mm_add_epi32(sqR0, sqG0), sqB0);
     611             :                 const auto dist1 =
     612             :                     _mm_add_epi32(_mm_add_epi32(sqR1, sqG1), sqB1);
     613             :                 const auto dist2 =
     614             :                     _mm_add_epi32(_mm_add_epi32(sqR2, sqG2), sqB2);
     615             :                 const auto dist3 =
     616             :                     _mm_add_epi32(_mm_add_epi32(sqR3, sqG3), sqB3);
     617             : 
     618             :                 // Compare with current minimum distances
                                       // (strict comparison: a lane keeps the first entry that
                                       // reaches its minimum)
     619             :                 auto mask0 = _mm_cmplt_epi32(dist0, minDistVec0);
     620             :                 auto mask1 = _mm_cmplt_epi32(dist1, minDistVec1);
     621             :                 auto mask2 = _mm_cmplt_epi32(dist2, minDistVec2);
     622             :                 auto mask3 = _mm_cmplt_epi32(dist3, minDistVec3);
     623             : 
     624             :                 // Update minimum distances
                                       // NOTE(review): blendv_epi8() is defined outside this
                                       // block; presumably it selects its second argument where
                                       // the mask is set and its first one elsewhere - confirm.
     625             :                 minDistVec0 = blendv_epi8(minDistVec0, dist0, mask0);
     626             :                 minDistVec1 = blendv_epi8(minDistVec1, dist1, mask1);
     627             :                 minDistVec2 = blendv_epi8(minDistVec2, dist2, mask2);
     628             :                 minDistVec3 = blendv_epi8(minDistVec3, dist3, mask3);
     629             : 
     630             :                 // Merge the 4 masks of 4 x uint32_t into
     631             :                 // a single mask 16 x 1 uint8_t mask
                                       // The shift leaves 0xFF (or 0) in each 32-bit lane, a
                                       // value that both saturating packs keep unchanged.
     632             :                 mask0 = _mm_srli_epi32(mask0, 24);
     633             :                 mask1 = _mm_srli_epi32(mask1, 24);
     634             :                 mask2 = _mm_srli_epi32(mask2, 24);
     635             :                 mask3 = _mm_srli_epi32(mask3, 24);
     636             :                 const auto mask_merged =
     637             :                     _mm_packus_epi16(_mm_packs_epi32(mask0, mask1),
     638             :                                      _mm_packs_epi32(mask2, mask3));
     639             : 
     640             :                 // Update indices
     641             :                 idxMin = blendv_epi8(idxMin, idx, mask_merged);
     642             : 
     643             : #else
     644             :                 // Sum squared differences, by removing a few LSB bits to avoid
     645             :                 // overflows.
                                       // The shift is a logical one, i.e. the 16-bit squares are
                                       // handled as unsigned values.
     646   110343040 :                 const auto dist0 = _mm_add_epi16(
     647             :                     _mm_add_epi16(_mm_srli_epi16(sqR_lo, BIT_SHIFT),
     648             :                                   _mm_srli_epi16(sqG_lo, BIT_SHIFT)),
     649             :                     _mm_srli_epi16(sqB_lo, BIT_SHIFT));
     650   110343040 :                 const auto dist1 = _mm_add_epi16(
     651             :                     _mm_add_epi16(_mm_srli_epi16(sqR_hi, BIT_SHIFT),
     652             :                                   _mm_srli_epi16(sqG_hi, BIT_SHIFT)),
     653             :                     _mm_srli_epi16(sqB_hi, BIT_SHIFT));
     654             : 
     655             :                 // Compare with current minimum distances
                                       // (signed 16-bit comparison, valid because the sums stay
                                       // below INT16_MAX; strict, so that a lane keeps the first
                                       // entry that reaches its minimum)
     656    22068608 :                 auto mask0 = _mm_cmplt_epi16(dist0, minDistVec0);
     657    22068608 :                 auto mask1 = _mm_cmplt_epi16(dist1, minDistVec1);
     658             : 
     659             :                 // Update minimum distances
                                       // NOTE(review): blendv_epi8() is defined outside this
                                       // block; presumably it selects its second argument where
                                       // the mask is set and its first one elsewhere - confirm.
     660    22068608 :                 minDistVec0 = blendv_epi8(minDistVec0, dist0, mask0);
     661    22068608 :                 minDistVec1 = blendv_epi8(minDistVec1, dist1, mask1);
     662             : 
     663             :                 // Merge the 2 masks of 8 x uint16_t into
     664             :                 // a single mask 16 x 1 uint8_t mask
                                       // The shift leaves 0x00FF (or 0) in each 16-bit lane, a
                                       // value that the saturating pack keeps unchanged.
     665    22068608 :                 mask0 = _mm_srli_epi16(mask0, 8);
     666    22068608 :                 mask1 = _mm_srli_epi16(mask1, 8);
     667    22068608 :                 const auto mask_merged = _mm_packus_epi16(mask0, mask1);
     668             : 
     669             :                 // Update indices
     670    22068608 :                 idxMin = blendv_epi8(idxMin, idx, mask_merged);
     671             : #endif
     672             : 
                                       // Move the per-lane candidate indices to the next batch
     673    44137216 :                 idx = _mm_add_epi8(idx, _mm_set1_epi8(VALS_AT_ONCE));
     674             :             }
     675             : 
     676             :             // Horizontal update
                                   // Each of the 16 lanes holds its own minimum distance and the
                                   // index where it was reached: spill them to memory and reduce
                                   // them into minSqDist[k] / res.m_vals[k].
     677             : #ifdef PRECISE_DISTANCE_COMPUTATION
     678             :             int minDistVals[VALS_AT_ONCE];
     679             :             _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 0),
     680             :                              minDistVec0);
     681             :             _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 4),
     682             :                              minDistVec1);
     683             :             _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 8),
     684             :                              minDistVec2);
     685             :             _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 12),
     686             :                              minDistVec3);
     687             : #else
     688             :             short minDistVals[VALS_AT_ONCE];
     689             :             _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 0),
     690             :                              minDistVec0);
     691     1697778 :             _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 8),
     692             :                              minDistVec1);
     693             : #endif
     694             : 
     695             :             GByte minIdxVals[VALS_AT_ONCE];
     696             :             _mm_storeu_si128(reinterpret_cast<__m128i *>(minIdxVals), idxMin);
     697             : 
                                   // Lane order does not follow index order (lane j saw entries
                                   // j, j + 16, j + 32, ...), hence the explicit tie-break on the
                                   // lowest index when two lanes report the same distance.
     698    28862236 :             for (int j = 0; j < VALS_AT_ONCE; ++j)
     699             :             {
     700    48786996 :                 if (minDistVals[j] < minSqDist[k] ||
     701    21622568 :                     (minDistVals[j] == minSqDist[k] &&
     702      740012 :                      minIdxVals[j] < res.m_vals[k]))
     703             :                 {
     704     5872790 :                     minSqDist[k] = minDistVals[j];
     705     5872790 :                     res.m_vals[k] = minIdxVals[j];
     706             :                 }
     707             :             }
     708             : #endif
     709             : 
     710             :             // Generic/scalar code
                                   // Handles the entries not processed by the SIMD loop, that is
                                   // all of them when KDTREE_USE_SSE2 is not defined. The strict
                                   // comparison keeps the earlier (lower) index on ties.
                                   // NOTE(review): square() is defined outside this block;
                                   // presumably it returns x * x - confirm.
     711    15278308 :             for (; i < static_cast<int>(ctxt.m_R.size()); ++i)
     712             :             {
     713    13580500 :                 const int sqDist = (square(meanR - ctxt.m_R[i]) >> BIT_SHIFT) +
     714    13580500 :                                    (square(meanG - ctxt.m_G[i]) >> BIT_SHIFT) +
     715    13580500 :                                    (square(meanB - ctxt.m_B[i]) >> BIT_SHIFT);
     716    13580500 :                 if (sqDist < minSqDist[k])
     717             :                 {
     718       60127 :                     minSqDist[k] = sqDist;
     719       60127 :                     res.m_vals[k] = static_cast<GByte>(i);
     720             :                 }
     721             :             }
     722             :         }
     723      106111 :         return res;
     724             :     }
     725             : 
     726             :     /************************************************************************/
     727             :     /*                           operator == ()                             */
     728             :     /************************************************************************/
     729             : 
     730     1053204 :     inline bool operator==(const Vector &other) const
     731             :     {
                               // Two vectors are equal when their m_vals members compare equal.
     732     1053204 :         return m_vals == other.m_vals;
     733             :     }
     734             : 
     735             :     /************************************************************************/
     736             :     /*                           operator < ()                              */
     737             :     /************************************************************************/
     738             : 
     739             :     // Purely arbitrary for the purpose of distinguishing a vector from
     740             :     // another one
     741    56072279 :     inline bool operator<(const Vector &other) const
     742             :     {
                               // Delegates to the ordering of m_vals.
                               // NOTE(review): m_vals is declared outside this block; if it is a
                               // std::array, this is a lexicographical comparison - confirm.
     743    56072279 :         return m_vals < other.m_vals;
     744             :     }
     745             : };
     746             : 
     747             : #endif  // KDTREE_VQ_CADRG_INCLUDED

Generated by: LCOV version 1.14