LCOV - code coverage report
Current view: top level - frmts/nitf - kdtree.h (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 315 333 94.6 %
Date: 2026-03-05 10:33:42 Functions: 32 32 100.0 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  NITF Read/Write Library
       4             :  * Purpose:  Pairwise Nearest Neighbor (PNN) clustering for Vector
       5             :  *           Quantization (VQ) using a KDTree.
       6             :  * Author:   Even Rouault, even dot rouault at spatialys dot com
       7             :  *
       8             :  **********************************************************************
       9             :  * Copyright (c) 2026, T-Kartor
      10             :  *
      11             :  * SPDX-License-Identifier: MIT
      12             :  ****************************************************************************/
      13             : 
      14             : #ifndef KDTREE_INCLUDED
      15             : #define KDTREE_INCLUDED
      16             : 
      17             : /**
      18             :  * This file implements Pairwise Nearest Neighbor (PNN) clustering for Vector
      19             :  * Quantization (VQ) using a KDTree.
      20             :  *
      21             :  * It implements paper "A New Vector Quantization Clustering Algorithm", by
      22             :  * William H. Equitz, from IEEE Transactions on Acoustics, Speech, and Signal
      23             :  * Processing, Vol. 37, Issue 10, October 1989. DOI: 10.1109/29.35395
      24             :  * https://ieeexplore.ieee.org/document/35395 (behind paywall)
      25             :  *
      26             :  * A higher level (freely accessible) and more generic paper on PNN clustering
      27             :  * is also available at
      28             :  * https://www.researchgate.net/publication/27661047_Pairwise_Nearest_Neighbor_Method_Revisited
      29             :  *
      30             :  * Papers "Analysis of Compression Techniques for Common Mapping Standard (CMS)
      31             :  * Raster Data" by N.J. Markuson, July 1994 (https://apps.dtic.mil/sti/tr/pdf/ADA283396.pdf)
      32             :  * and "Compression of Digitized Map Image" by D.A. Southard, March 1992
      33             :  * (https://apps.dtic.mil/sti/tr/pdf/ADA250707.pdf) analyze VQ compression and
      34             :  * contain a high-level description of the Equitz paper.
      35             :  */
      36             : 
      37             : #include <cassert>
      38             : #include <cstdio>
      39             : 
      40             : #include <algorithm>
      41             : #include <array>
      42             : #include <deque>
      43             : #include <iterator>
      44             : #include <functional>
      45             : #include <limits>
      46             : #include <map>
      47             : #include <memory>
      48             : #include <set>
      49             : #include <utility>
      50             : #include <vector>
      51             : 
      52             : #include "cpl_error.h"
      53             : 
      // Uncomment to enable extra consistency checks, such as verifying in
// PNNKDTree<T>::insert() that no value vector is inserted twice.
// #define DEBUG_INVARIANTS

#ifdef KDTREE_DEBUG_TIMING
#include <sys/time.h>

// Accumulated wall-clock durations (in seconds, e.g. totalTimeStats is
// incremented from gettimeofday() deltas) of the different clustering phases.
// Only collected when KDTREE_DEBUG_TIMING is defined.
// NOTE(review): being `static` in a header, every translation unit that
// includes this file gets its own independent copy of these counters.
static double totalTimeRebalancing = 0;
static double totalTimeCentroid = 0;
static double totalTimeStats = 0;

#endif

// Ask GCC (not clang, which also defines __GNUC__) to unroll loops.
// NOTE(review): unless this is undone further down the header (not visible
// here), the pragma also applies to all code that follows the #include of
// this file in every including translation unit.
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC optimize("unroll-loops")
#endif

// SSE2 is part of the x86-64 baseline instruction set, so it can be used
// without a runtime check on that architecture. Define KDTREE_DISABLE_SIMD
// to force the scalar code paths.
#if (defined(__x86_64__) || defined(_M_X64)) && !defined(KDTREE_DISABLE_SIMD)
#define KDTREE_USE_SSE2
#endif

#ifdef KDTREE_USE_SSE2
#include <emmintrin.h>
#endif
      76             : 
      77             : /************************************************************************/
      78             : /*                              Vector<T>                               */
      79             : /************************************************************************/
      80             : 
      /** "Interface" of a "vector" of dimension DIM_COUNT to insert and cluster in
 * a PNNKDTree.
 *
 * Below functions must be implemented in classes that specialize Vector: there
 * is no default generic implementation.
 *
 * T is an opaque context type, passed through as `ctxt` to the methods that
 * need it. There are no constraints on the T type.
 */
template <class T> class Vector
{
  public:
    /** Dimension of the vector (number of components). */
    static constexpr int DIM_COUNT = -1;

    /** Whether the get() method returns uint8_t values.
     * Used for speed optimizations: PNNKDTree can then compute its
     * per-dimension statistics with integer arithmetic.
     */
    static constexpr bool getReturnUInt8 = false;

    /** Returns the k(th) value of the vector, with k in [0, DIM_COUNT-1] range.
     *
     * The actual returned type might not be double, but must be convertible to
     * double.
     */
    double get(int k, const T &ctxt) const /* = 0 */;

#ifdef KDTREE_USE_SSE2
    /** Whether the computeHeightSumAndSumSquareSSE2() method is implemented.
     * Used for speed optimizations.
     */
    static constexpr bool hasComputeHeightSumAndSumSquareSSE2 = false;

    /** Accumulates, for the eight consecutive components [k, k+7] of this
     * vector, the sums and sums of squares of their values weighted by count.
     * ("Height" in the name presumably stands for "Eight".)
     *
     * The function must do the equivalent of:
     *
     *   for (int i = 0; i < 8; ++i)
     *   {
     *      int val = get(k + i, ctxt);
     *      int valMulCount = val * count;
     *      // 32-bit lanes: sum0 holds i in [0,3], sum1 holds i in [4,7]
     *      {sum0, sum1}[i] += valMulCount;
     *      // 64-bit lanes: sumSquare0_lo holds i in [0,1], sumSquare0_hi
     *      // i in [2,3], sumSquare1_lo i in [4,5], sumSquare1_hi i in [6,7]
     *      {sumSquare0_lo, sumSquare0_hi, sumSquare1_lo, sumSquare1_hi}[i] += val * valMulCount;
     *   }
     *
     * k is in the [0, DIM_COUNT-8] range (and generally a multiple of 8).
     */
    void computeHeightSumAndSumSquareSSE2(int k, const T &ctxt, int count,
                                          __m128i &sum0, __m128i &sumSquare0_lo,
                                          __m128i &sumSquare0_hi, __m128i &sum1,
                                          __m128i &sumSquare1_lo,
                                          __m128i &sumSquare1_hi) const
        /* = 0 */;
#endif

    /** Returns the squared distance between this vector and other.
     * It must be symmetric, that is this->squared_distance(other, ctx) must
     * be equal to other.squared_distance(*this, ctx).
     */
    double squared_distance(const Vector &other, const T &ctxt) const /* = 0 */;

    /** Whether the compute_four_squared_distances() method is implemented.
     * Used for speed optimizations.
     */
    static constexpr bool hasComputeFourSquaredDistances = false;

    /** Equivalent of
     *
     * for(int i = 0; i < 4; ++i)
     * {
     *      tabSquaredDist[i] = squared_distance(*(others[i]), ctxt);
     * }
     */
    void compute_four_squared_distances(
        const std::array<const Vector *const, 4> &others,
        std::array<int, 4> & /* out */ tabSquaredDist, const T &ctxt) const
        /* = 0 */;

    /** Computes a new vector that is the centroid of vector a of weight nA,
     * and vector b of weight nB.
     */
    static Vector centroid(const Vector &a, int nA, const Vector &b, int nB,
                           const T &ctxt) /* = 0 */;
};
     162             : 
     163             : /************************************************************************/
     164             : /*                            BucketItem<T>                             */
     165             : /************************************************************************/
     166             : 
     167             : /** Definition of an item placed in a bucket of a PNNKDTree.
     168             :  *
     169             :  * This class does not need to be specialized.
     170             :  */
     171     8580883 : template <class T> struct BucketItem
     172             : {
     173             :   public:
     174             :     /** Value vector */
     175             :     Vector<T> m_vec;
     176             : 
     177             :     /** Type of elements in m_origVectorIndices */
     178             :     using IdxType = int;
     179             : 
     180             :     /** Vector that points to indices in the original value space that evaluate
     181             :      * to m_vec.
     182             :      * Typically m_origVectorIndices.size() == m_count, but
     183             :      * the clustering algorithm will not enforce it. It will just concatenate
     184             :      * m_origVectorIndices from different BucketItem when merging them.
     185             :      */
     186             :     std::vector<IdxType> m_origVectorIndices;
     187             : 
     188             :     /** Number of samples that have the value of m_vec */
     189             :     int m_count;
     190             : 
     191             :     /** Constructor */
     192      525399 :     BucketItem(const Vector<T> &vec, int count,
     193             :                std::vector<IdxType> &&origVectorIndices)
     194      525399 :         : m_vec(vec), m_origVectorIndices(std::move(origVectorIndices)),
     195      525399 :           m_count(count)
     196             :     {
     197      525399 :     }
     198             : 
     199    11266998 :     BucketItem(BucketItem &&) = default;
     200             :     BucketItem &operator=(BucketItem &&) = default;
     201             : 
     202             :   private:
     203             :     BucketItem(const BucketItem &) = delete;
     204             :     BucketItem &operator=(const BucketItem &) = delete;
     205             : };
     206             : 
     207             : /************************************************************************/
     208             : /*                             PNNKDTree<T>                             */
     209             : /************************************************************************/
     210             : 
     211             : /**
     212             :  * KDTree designed for Pairwise Nearest Neighbor (PNN) clustering for Vector
     213             :  * Quantization (VQ).
     214             :  *
     215             :  * This class does not need to be specialized.
     216             :  */
     217             : template <class T> class PNNKDTree
     218             : {
     219             :   public:
     220       61712 :     PNNKDTree() = default;
     221             : 
     222             :     /* Inserts value vectors with their cardinality in the KD-Tree.
     223             :      *
     224             :      * This method must be called only once.
     225             :      *
     226             :      * Returns the initial count of buckets, that must be passed as an input
     227             :      * to cluster().
     228             :      */
     229             :     int insert(std::vector<BucketItem<T>> &&vectors, const T &ctxt);
     230             : 
     231             :     /** Iterate over leaf nodes (that contain buckets) */
     232             :     void iterateOverLeaves(const std::function<void(PNNKDTree &)> &f);
     233             : 
     234             :     /** Perform clustering to reduce the number of buckets from initialBucketCount
     235             :      * to targetCount.
     236             :      *
     237             :      * It modifies the tree structure, and returns the achieved number of
     238             :      * buckets (<= targetCount).
     239             :      */
     240             :     int cluster(int initialBucketCount, int targetCount, const T &ctxt);
     241             : 
     242             :     /** Returns the bucket items for this node. */
     243             :     inline const std::vector<BucketItem<T>> &bucketItems() const
     244             :     {
     245             :         return m_bucketItems;
     246             :     }
     247             : 
     248             :     /** Returns the bucket items for this node. */
     249        9810 :     inline std::vector<BucketItem<T>> &bucketItems()
     250             :     {
     251        9810 :         return m_bucketItems;
     252             :     }
     253             : 
     254             :   private:
     255             :     static constexpr int BUCKET_MAX_SIZE = 8;
     256             : 
     257             :     /** Left node. When non null, m_right is also non null, and m_bucketItems is empty. */
     258             :     std::unique_ptr<PNNKDTree> m_left{};
     259             : 
     260             :     /** Right node. When non null, m_left is also non null, and m_bucketItems is empty. */
     261             :     std::unique_ptr<PNNKDTree> m_right{};
     262             : 
     263             :     /** Contains items that form a bucket. The bucket is nominally at most BUCKET_MAX_SIZE
     264             :      * large (maybe transiently slightly larger during clustering operations).
     265             :      *
     266             :      * m_bucketItems is non empty only on leaf nodes.
     267             :      */
     268             :     std::vector<BucketItem<T>> m_bucketItems{};
     269             : 
     270             :     /** Data type returned by Vector<T>::get() */
     271             :     using ValType = decltype(std::declval<Vector<T>>().get(
     272             :         0, *static_cast<const T *>(nullptr)));
     273             : 
     274             :     /** Clean the current node and move it to queueNodes.
     275             :      *
     276             :      * This saves dynamic allocation and de-allocation of nodes when rebalancing.
     277             :      */
     278             :     void freeAndMoveToQueue(std::deque<std::unique_ptr<PNNKDTree>> &queueNodes);
     279             : 
     280             :     int insert(std::vector<BucketItem<T>> &&vectors, int totalCount,
     281             :                std::vector<std::pair<ValType, int>> &weightedVals,
     282             :                std::deque<std::unique_ptr<PNNKDTree>> &queueNodes,
     283             :                std::vector<BucketItem<T>> &vectLeft,
     284             :                std::vector<BucketItem<T>> &vectRight, const T &ctxt);
     285             : 
     286             :     /** Rebalance the KD-Tree. Current implementation fully rebuilds a new
     287             :      * KD-Tree using the insert() algorithm
     288             :      */
     289             :     int rebalance(const T &ctxt, std::vector<BucketItem<T>> &newLeaves,
     290             :                   std::deque<std::unique_ptr<PNNKDTree>> &queueNodes);
     291             : };
     292             : 
     293             : /************************************************************************/
     294             : /*                        PNNKDTree<T>::insert()                        */
     295             : /************************************************************************/
     296             : 
     297             : template <class T>
     298          60 : int PNNKDTree<T>::insert(std::vector<BucketItem<T>> &&vectors, const T &ctxt)
     299             : {
     300          60 :     assert(m_left == nullptr);
     301          60 :     assert(m_right == nullptr);
     302          60 :     assert(m_bucketItems.empty());
     303             : 
     304          60 :     int totalCount = 0;
     305      155796 :     for (const auto &it : vectors)
     306             :     {
     307      155736 :         totalCount += it.m_count;
     308             :     }
     309         120 :     std::vector<std::pair<ValType, int>> weightedVals;
     310         120 :     std::deque<std::unique_ptr<PNNKDTree>> queueNodes;
     311         120 :     std::vector<BucketItem<T>> vectLeft;
     312         120 :     std::vector<BucketItem<T>> vectRight;
     313          60 :     if (totalCount == 0)
     314           0 :         return 0;
     315          60 :     return insert(std::move(vectors), totalCount, weightedVals, queueNodes,
     316          60 :                   vectLeft, vectRight, ctxt);
     317             : }
     318             : 
     319             : /************************************************************************/
     320             : /*                        PNNKDTree<T>::insert()                        */
     321             : /************************************************************************/
     322             : 
     323             : template <class T>
     324      130804 : int PNNKDTree<T>::insert(std::vector<BucketItem<T>> &&vectors, int totalCount,
     325             :                          std::vector<std::pair<ValType, int>> &weightedVals,
     326             :                          std::deque<std::unique_ptr<PNNKDTree>> &queueNodes,
     327             :                          std::vector<BucketItem<T>> &vectLeft,
     328             :                          std::vector<BucketItem<T>> &vectRight, const T &ctxt)
     329             : {
     330             : #ifdef DEBUG_INVARIANTS
     331             :     std::map<Vector<T>, int> mapValuesToBucketIdx;
     332             :     for (int i = 0; i < static_cast<int>(vectors.size()); ++i)
     333             :     {
     334             :         CPLAssert(mapValuesToBucketIdx.find(vectors[i].m_vec) ==
     335             :                   mapValuesToBucketIdx.end());
     336             :         mapValuesToBucketIdx[vectors[i].m_vec] = i;
     337             :     }
     338             : #endif
     339             : 
     340      130804 :     if (vectors.size() <= BUCKET_MAX_SIZE)
     341             :     {
     342       65460 :         m_bucketItems = std::move(vectors);
     343       65460 :         return static_cast<int>(m_bucketItems.size());
     344             :     }
     345             : 
     346             : #ifdef KDTREE_DEBUG_TIMING
     347             :     struct timeval tv1, tv2;
     348             :     gettimeofday(&tv1, nullptr);
     349             : #endif
     350             : 
     351             :     // Find dimension with maximum variance
     352       65344 :     double maxM2 = 0;
     353       65344 :     int maxM2_k = 0;
     354             : 
     355      850138 :     for (int k = 0; k < Vector<T>::DIM_COUNT; ++k)
     356             :     {
     357             :         if constexpr (Vector<T>::getReturnUInt8)
     358             :         {
     359      784794 :             constexpr int MAX_BYTE_VALUE = std::numeric_limits<uint8_t>::max();
     360      784794 :             bool canUseOptimization =
     361      784794 :                 (totalCount <= std::numeric_limits<int64_t>::max() /
     362             :                                    (MAX_BYTE_VALUE * MAX_BYTE_VALUE));
     363      784794 :             if (canUseOptimization)
     364             :             {
     365      784794 :                 int maxCountPerVector = 0;
     366    59746727 :                 for (const auto &item : vectors)
     367             :                 {
     368    58961909 :                     maxCountPerVector =
     369    58961909 :                         std::max(maxCountPerVector, item.m_count);
     370             :                 }
     371      784794 :                 canUseOptimization = (maxCountPerVector <=
     372      784794 :                                       std::numeric_limits<int32_t>::max() /
     373             :                                           (MAX_BYTE_VALUE * MAX_BYTE_VALUE));
     374             :             }
     375      784794 :             if (canUseOptimization)
     376             :             {
     377             :                 // Do statistics computation in integer domain if possible.
     378             : 
     379             : #if !(defined(__i386__) || defined(_M_IX86))
     380             :                 // Below code requires more than 8 general purpose registers,
     381             :                 // so exclude i386.
     382             : 
     383      776376 :                 constexpr int VALS_AT_ONCE = 4;
     384             :                 if constexpr ((Vector<T>::DIM_COUNT % VALS_AT_ONCE) == 0)
     385             :                 {
     386             : #ifdef KDTREE_USE_SSE2
     387      776376 :                     constexpr int TWICE_VALS_AT_ONCE = 2 * VALS_AT_ONCE;
     388             :                     if constexpr ((Vector<T>::DIM_COUNT % TWICE_VALS_AT_ONCE) ==
     389             :                                       0 &&
     390             :                                   Vector<
     391             :                                       T>::hasComputeHeightSumAndSumSquareSSE2)
     392             :                     {
     393             :                         __m128i sum0 = _mm_setzero_si128();
     394             :                         __m128i sumSquare0_lo = _mm_setzero_si128();
     395             :                         __m128i sumSquare0_hi = _mm_setzero_si128();
     396             :                         __m128i sum1 = _mm_setzero_si128();
     397             :                         __m128i sumSquare1_lo = _mm_setzero_si128();
     398             :                         __m128i sumSquare1_hi = _mm_setzero_si128();
     399             : 
     400             :                         for (const auto &item : vectors)
     401             :                         {
     402             :                             item.m_vec.computeHeightSumAndSumSquareSSE2(
     403             :                                 k, ctxt, item.m_count, sum0, sumSquare0_lo,
     404             :                                 sumSquare0_hi, sum1, sumSquare1_lo,
     405             :                                 sumSquare1_hi);
     406             :                         }
     407             :                         int64_t sumSquares[TWICE_VALS_AT_ONCE];
     408             :                         _mm_storeu_si128(
     409             :                             reinterpret_cast<__m128i *>(sumSquares + 0),
     410             :                             sumSquare0_lo);
     411             :                         _mm_storeu_si128(
     412             :                             reinterpret_cast<__m128i *>(sumSquares + 2),
     413             :                             sumSquare0_hi);
     414             :                         _mm_storeu_si128(
     415             :                             reinterpret_cast<__m128i *>(sumSquares + 4),
     416             :                             sumSquare1_lo);
     417             :                         _mm_storeu_si128(
     418             :                             reinterpret_cast<__m128i *>(sumSquares + 6),
     419             :                             sumSquare1_hi);
     420             :                         int sums[TWICE_VALS_AT_ONCE];
     421             :                         _mm_storeu_si128(reinterpret_cast<__m128i *>(sums),
     422             :                                          sum0);
     423             :                         _mm_storeu_si128(
     424             :                             reinterpret_cast<__m128i *>(sums + VALS_AT_ONCE),
     425             :                             sum1);
     426             :                         for (int i = 0; i < TWICE_VALS_AT_ONCE; ++i)
     427             :                         {
     428             :                             const double M2 = static_cast<double>(
     429             :                                 sumSquares[i] * totalCount -
     430             :                                 static_cast<int64_t>(sums[i]) * sums[i]);
     431             :                             if (M2 > maxM2)
     432             :                             {
     433             :                                 maxM2 = M2;
     434             :                                 maxM2_k = k + i;
     435             :                             }
     436             :                         }
     437             :                         k += TWICE_VALS_AT_ONCE - 1;
     438             :                     }
     439             :                     else
     440             : #endif
     441             :                     {
     442      776376 :                         int sum0 = 0;
     443      776376 :                         int sum1 = 0;
     444      776376 :                         int sum2 = 0;
     445      776376 :                         int sum3 = 0;
     446      776376 :                         int64_t sumSquare0 = 0;
     447      776376 :                         int64_t sumSquare1 = 0;
     448      776376 :                         int64_t sumSquare2 = 0;
     449      776376 :                         int64_t sumSquare3 = 0;
     450    49671392 :                         for (const auto &item : vectors)
     451             :                         {
     452    48894980 :                             const int val0 = item.m_vec.get(k + 0, ctxt);
     453    48894980 :                             const int val1 = item.m_vec.get(k + 1, ctxt);
     454    48894980 :                             const int val2 = item.m_vec.get(k + 2, ctxt);
     455    48894980 :                             const int val3 = item.m_vec.get(k + 3, ctxt);
     456    48894980 :                             const int val0MulCount = val0 * item.m_count;
     457    48894980 :                             const int val1MulCount = val1 * item.m_count;
     458    48894980 :                             const int val2MulCount = val2 * item.m_count;
     459    48894980 :                             const int val3MulCount = val3 * item.m_count;
     460    48894980 :                             sum0 += val0MulCount;
     461    48894980 :                             sum1 += val1MulCount;
     462    48894980 :                             sum2 += val2MulCount;
     463    48894980 :                             sum3 += val3MulCount;
     464             :                             // It's fine to cast to int64 after multiplication
     465    48894980 :                             sumSquare0 +=
     466    48894980 :                                 cpl::fits_on<int64_t>(val0 * val0MulCount);
     467    48894980 :                             sumSquare1 +=
     468    48894980 :                                 cpl::fits_on<int64_t>(val1 * val1MulCount);
     469    48894980 :                             sumSquare2 +=
     470    48894980 :                                 cpl::fits_on<int64_t>(val2 * val2MulCount);
     471    48894980 :                             sumSquare3 +=
     472    48894980 :                                 cpl::fits_on<int64_t>(val3 * val3MulCount);
     473             :                         }
     474             : 
     475      776376 :                         const double M2[] = {
     476      776376 :                             static_cast<double>(sumSquare0 * totalCount -
     477      776376 :                                                 static_cast<int64_t>(sum0) *
     478      776376 :                                                     sum0),
     479      776376 :                             static_cast<double>(sumSquare1 * totalCount -
     480      776376 :                                                 static_cast<int64_t>(sum1) *
     481      776376 :                                                     sum1),
     482      776376 :                             static_cast<double>(sumSquare2 * totalCount -
     483      776376 :                                                 static_cast<int64_t>(sum2) *
     484      776376 :                                                     sum2),
     485      776376 :                             static_cast<double>(sumSquare3 * totalCount -
     486      776376 :                                                 static_cast<int64_t>(sum3) *
     487      776376 :                                                     sum3)};
     488     3881880 :                         for (int i = 0; i < VALS_AT_ONCE; ++i)
     489             :                         {
     490     3105508 :                             if (M2[i] > maxM2)
     491             :                             {
     492      301946 :                                 maxM2 = M2[i];
     493      301946 :                                 maxM2_k = k + i;
     494             :                             }
     495             :                         }
     496      776376 :                         k += VALS_AT_ONCE - 1;
     497             :                     }
     498             :                 }
     499             :                 else
     500             : #endif
     501             :                 {
     502           0 :                     int sum = 0;
     503           0 :                     int64_t sumSquare = 0;
     504           0 :                     for (const auto &item : vectors)
     505             :                     {
     506           0 :                         const int val = item.m_vec.get(k, ctxt);
     507           0 :                         const int valMulCount = val * item.m_count;
     508           0 :                         sum += valMulCount;
     509             :                         // It's fine to cast to int64 after multiplication
     510           0 :                         sumSquare += cpl::fits_on<int64_t>(val * valMulCount);
     511             :                     }
     512           0 :                     const double M2 =
     513           0 :                         static_cast<double>(sumSquare * totalCount -
     514           0 :                                             static_cast<int64_t>(sum) * sum);
     515           0 :                     if (M2 > maxM2)
     516             :                     {
     517           0 :                         maxM2 = M2;
     518           0 :                         maxM2_k = k;
     519             :                     }
     520             :                 }
     521      776376 :                 continue;
     522             :             }
     523             :         }
     524             : 
     525             :         // Generic code path:
     526             : 
     527             :         // First pass to compute mean value along k(th) dimension
     528        8418 :         double sum = 0;
     529    10075335 :         for (const auto &item : vectors)
     530             :         {
     531    10066929 :             sum += static_cast<double>(item.m_vec.get(k, ctxt)) * item.m_count;
     532             :         }
     533        8418 :         const double mean = sum / totalCount;
     534             :         // Second pass to compute M2 value (n * variance) along k(th) dimension
     535        8418 :         double M2 = 0;
     536    10075335 :         for (const auto &item : vectors)
     537             :         {
     538    10066929 :             const double delta =
     539    10066929 :                 static_cast<double>(item.m_vec.get(k, ctxt)) - mean;
     540    10066929 :             M2 += delta * delta * item.m_count;
     541             :         }
     542        8418 :         if (M2 > maxM2)
     543             :         {
     544        1858 :             maxM2 = M2;
     545        1858 :             maxM2_k = k;
     546             :         }
     547             :     }
     548             : 
     549             : #ifdef KDTREE_DEBUG_TIMING
     550             :     gettimeofday(&tv2, nullptr);
     551             :     totalTimeStats +=
     552             :         (tv2.tv_sec + tv2.tv_usec * 1e-6) - (tv1.tv_sec + tv1.tv_usec * 1e-6);
     553             : #endif
     554             : 
     555             :     // Find median value along that dimension
     556       65344 :     weightedVals.reserve(vectors.size());
     557       65344 :     weightedVals.clear();
     558     4363381 :     for (const auto &item : vectors)
     559             :     {
     560     4298038 :         const auto d = item.m_vec.get(maxM2_k, ctxt);
     561     4298038 :         weightedVals.emplace_back(d, item.m_count);
     562             :     }
     563             : 
     564       65344 :     std::sort(weightedVals.begin(), weightedVals.end(),
     565    33055123 :               [](const auto &a, const auto &b) { return a.first < b.first; });
     566             : 
     567       65344 :     auto median = weightedVals[0].first;
     568       65344 :     int cumulativeCount = 0;
     569       65344 :     const int targetCount = totalCount / 2;
     570     2176236 :     for (const auto &[value, count] : weightedVals)
     571             :     {
     572     2176236 :         cumulativeCount += count;
     573     2176236 :         if (cumulativeCount > targetCount)
     574             :         {
     575       65344 :             median = value;
     576       65344 :             break;
     577             :         }
     578             :     }
     579             : 
     580             :     // Split the original vectors in a "left" half that is below or equal to
     581             :     // the median and a "right" half that is above.
     582       65344 :     vectLeft.clear();
     583       65344 :     vectLeft.reserve(weightedVals.size() / 2);
     584       65344 :     vectRight.clear();
     585       65344 :     vectRight.reserve(weightedVals.size() / 2);
     586       65344 :     int countLeft = 0;
     587       65344 :     int countRight = 0;
     588     4363381 :     for (auto &item : vectors)
     589             :     {
     590     4298038 :         if (item.m_vec.get(maxM2_k, ctxt) <= median)
     591             :         {
     592     2764333 :             countLeft += item.m_count;
     593     2764333 :             vectLeft.push_back(std::move(item));
     594             :         }
     595             :         else
     596             :         {
     597     1533705 :             countRight += item.m_count;
     598     1533705 :             vectRight.push_back(std::move(item));
     599             :         }
     600             :     }
     601             : 
     602             :     // In some cases, the median can actually be the maximum value
     603             :     // Then, retry but exclusing the median itself.
     604       65344 :     if (vectLeft.empty() || vectRight.empty())
     605             :     {
     606       15498 :         if (!vectLeft.empty())
     607       15498 :             vectors = std::move(vectLeft);
     608             :         else
     609           0 :             vectors = std::move(vectRight);
     610       15498 :         vectLeft.clear();
     611       15498 :         vectRight.clear();
     612       15498 :         countLeft = 0;
     613       15498 :         countRight = 0;
     614      498442 :         for (auto &item : vectors)
     615             :         {
     616      482944 :             if (item.m_vec.get(maxM2_k, ctxt) < median)
     617             :             {
     618      239646 :                 countLeft += item.m_count;
     619      239646 :                 vectLeft.push_back(std::move(item));
     620             :             }
     621             :             else
     622             :             {
     623      243298 :                 countRight += item.m_count;
     624      243298 :                 vectRight.push_back(std::move(item));
     625             :             }
     626             :         }
     627             : 
     628             :         // Normally we shouldn't reach that point, unless the initial samples
     629             :         // where all identical, and thus clustering wasn't needed.
     630       15498 :         if (vectLeft.empty() || vectRight.empty())
     631             :         {
     632           0 :             CPLError(CE_Failure, CPLE_AppDefined,
     633             :                      "Unexpected situation in %s:%d\n", __FILE__, __LINE__);
     634           0 :             return 0;
     635             :         }
     636             :     }
     637       65344 :     vectors.clear();
     638             : 
     639             :     // Allocate (or recycle) left and right nodes
     640       65344 :     if (!queueNodes.empty())
     641             :     {
     642       34518 :         m_left = std::move(queueNodes.back());
     643       34518 :         queueNodes.pop_back();
     644             :     }
     645             :     else
     646       30826 :         m_left = std::make_unique<PNNKDTree<T>>();
     647             : 
     648       65344 :     if (!queueNodes.empty())
     649             :     {
     650       34518 :         m_right = std::move(queueNodes.back());
     651       34518 :         queueNodes.pop_back();
     652             :     }
     653             :     else
     654       30826 :         m_right = std::make_unique<PNNKDTree<T>>();
     655             : 
     656             :     // Recursively insert vectLeft in m_left and vectRight in m_right
     657       65344 :     std::vector<BucketItem<T>> vectTmp;
     658             :     // Sort for replicability of results across platforms
     659    34517420 :     const auto sortFunc = [](const BucketItem<T> &a, const BucketItem<T> &b)
     660    34517426 :     { return a.m_vec < b.m_vec; };
     661       65344 :     std::sort(vectLeft.begin(), vectLeft.end(), sortFunc);
     662       65344 :     std::sort(vectRight.begin(), vectRight.end(), sortFunc);
     663       65344 :     int retLeft = m_left->insert(std::move(vectLeft), countLeft, weightedVals,
     664             :                                  queueNodes, vectors, vectTmp, ctxt);
     665      130688 :     int retRight =
     666       65344 :         m_right->insert(std::move(vectRight), countRight, weightedVals,
     667             :                         queueNodes, vectors, vectTmp, ctxt);
     668       65344 :     vectLeft = std::vector<BucketItem<T>>();
     669       65344 :     vectRight = std::vector<BucketItem<T>>();
     670       65344 :     return (retLeft == 0 || retRight == 0) ? 0 : retLeft + retRight;
     671             : }
     672             : 
     673             : /************************************************************************/
     674             : /*                  PNNKDTree<T>::iterateOverLeaves()                   */
     675             : /************************************************************************/
     676             : 
     677             : template <class T>
     678      358846 : void PNNKDTree<T>::iterateOverLeaves(const std::function<void(PNNKDTree &)> &f)
     679             : {
     680      358846 :     if (m_left && m_right)
     681             :     {
     682      179284 :         m_left->iterateOverLeaves(f);
     683      179284 :         m_right->iterateOverLeaves(f);
     684             :     }
     685             :     else
     686             :     {
     687      179562 :         f(*this);
     688             :     }
     689      358846 : }
     690             : 
     691             : /************************************************************************/
     692             : /*                       PNNKDTree<T>::cluster()                        */
     693             : /************************************************************************/
     694             : 
     695             : template <class T>
     696          30 : int PNNKDTree<T>::cluster(int initialBucketCount, int targetCount,
     697             :                           const T &ctxt)
     698             : {
     699          30 :     int curBucketCount = initialBucketCount;
     700             : 
     701          60 :     std::vector<BucketItem<T>> newLeaves;
     702          30 :     newLeaves.reserve(initialBucketCount);
     703          60 :     std::deque<std::unique_ptr<PNNKDTree>> queueNodes;
     704             : 
     705             :     struct TupleInfo
     706             :     {
     707             :         PNNKDTree *bucket;
     708             :         int i;
     709             :         int j;
     710             :         double increasedDistortion;
     711             :     };
     712             : 
     713          30 :     std::vector<TupleInfo> distCollector;
     714          30 :     distCollector.reserve(curBucketCount);
     715             : 
     716          30 :     int iter = 0;
     717             : #ifdef DEBUG_INVARIANTS
     718             :     std::map<Vector<T>, int> mapValuesToBucketIdx;
     719             : #endif
     720         166 :     while (curBucketCount > targetCount)
     721             :     {
     722             :         /* For each bucket (leaf node), compute the increase in distortion
     723             :          * that would result in merging each (i,j) vector it contains.
     724             :          */
     725         136 :         distCollector.clear();
     726         136 :         iterateOverLeaves(
     727      114015 :             [&distCollector, &ctxt](PNNKDTree &bucket)
     728             :             {
     729      114015 :                 const int itemsCount =
     730      114015 :                     static_cast<int>(bucket.m_bucketItems.size());
     731      519267 :                 for (int i = 0; i < itemsCount - 1; ++i)
     732             :                 {
     733      405252 :                     const auto &itemI = bucket.m_bucketItems[i];
     734        4068 :                     int j = i + 1;
     735             :                     if constexpr (Vector<T>::hasComputeFourSquaredDistances)
     736             :                     {
     737      401214 :                         constexpr int CHUNK_SIZE = 4;
     738      401214 :                         if (j + CHUNK_SIZE <= itemsCount)
     739             :                         {
     740             :                             std::array<const Vector<T> *const, CHUNK_SIZE>
     741      494164 :                                 otherVectors = {
     742      123541 :                                     &(bucket.m_bucketItems[j + 0].m_vec),
     743      123541 :                                     &(bucket.m_bucketItems[j + 1].m_vec),
     744      123541 :                                     &(bucket.m_bucketItems[j + 2].m_vec),
     745      123541 :                                     &(bucket.m_bucketItems[j + 3].m_vec)};
     746             :                             std::array<int, CHUNK_SIZE> tabSquaredDist;
     747      123541 :                             itemI.m_vec.compute_four_squared_distances(
     748             :                                 otherVectors, tabSquaredDist, ctxt);
     749      617705 :                             for (int subj = 0; subj < CHUNK_SIZE; ++subj)
     750             :                             {
     751             :                                 const auto &itemJ =
     752      494164 :                                     bucket.m_bucketItems[j + subj];
     753      494164 :                                 const double increasedDistortion =
     754      494164 :                                     static_cast<double>(itemI.m_count) *
     755      494164 :                                     itemJ.m_count * tabSquaredDist[subj] /
     756      494164 :                                     (itemI.m_count + itemJ.m_count);
     757             :                                 TupleInfo ti;
     758      494164 :                                 ti.bucket = &bucket;
     759      494164 :                                 ti.i = i;
     760      494164 :                                 ti.j = j + subj;
     761      494164 :                                 ti.increasedDistortion = increasedDistortion;
     762      494164 :                                 distCollector.push_back(std::move(ti));
     763             :                             }
     764             : 
     765      123541 :                             j += CHUNK_SIZE;
     766             :                         }
     767             :                     }
     768     1049518 :                     for (; j < itemsCount; ++j)
     769             :                     {
     770      644268 :                         const auto &itemJ = bucket.m_bucketItems[j];
     771      644268 :                         const double increasedDistortion =
     772     1288536 :                             static_cast<double>(itemI.m_count) * itemJ.m_count *
     773      644268 :                             itemI.m_vec.squared_distance(itemJ.m_vec, ctxt) /
     774      644268 :                             (itemI.m_count + itemJ.m_count);
     775             :                         TupleInfo ti;
     776      644268 :                         ti.bucket = &bucket;
     777      644268 :                         ti.i = i;
     778      644268 :                         ti.j = j;
     779      644268 :                         ti.increasedDistortion = increasedDistortion;
     780      644268 :                         distCollector.push_back(std::move(ti));
     781             :                     }
     782             :                 }
     783             :             });
     784             : 
     785             :         /** Identify the median of the increased distortion */
     786         136 :         const int bucketCountToMerge =
     787         272 :             std::min(static_cast<int>(distCollector.size() / 2),
     788         136 :                      curBucketCount - targetCount);
     789     7794665 :         const auto sortFunc = [](const TupleInfo &a, const TupleInfo &b)
     790             :         {
     791    11637183 :             return a.increasedDistortion < b.increasedDistortion ||
     792     3843494 :                    (a.increasedDistortion == b.increasedDistortion &&
     793     1779205 :                     (a.bucket->m_bucketItems[0].m_vec <
     794     1779645 :                          b.bucket->m_bucketItems[0].m_vec ||
     795      683271 :                      (a.bucket->m_bucketItems[0].m_vec ==
     796      683271 :                           b.bucket->m_bucketItems[0].m_vec &&
     797     7826682 :                       (a.i < b.i || (a.i == b.i && a.j < b.j)))));
     798             :         };
     799         136 :         const auto median_iter = distCollector.begin() + bucketCountToMerge;
     800         136 :         std::nth_element(distCollector.begin(), median_iter,
     801             :                          distCollector.end(), sortFunc);
     802             :         /** Sort elements by increasing increasedDistortion, but only for
     803             :          * the first half of the array
     804             :          */
     805         136 :         std::sort(distCollector.begin(), median_iter, sortFunc);
     806             : 
     807             :         static_assert(BUCKET_MAX_SIZE <= sizeof(uint32_t) * 8);
     808         136 :         std::map<PNNKDTree *, uint32_t> invalidatedClusters;
     809             : 
     810         136 :         int expectedBucketCount = curBucketCount;
     811             : 
     812             :         // Merge all the tuple of vectors whose increasd distortion is lower
     813             :         // than the median.
     814      291486 :         for (auto oIterCollector = distCollector.begin();
     815      582836 :              oIterCollector != median_iter; ++oIterCollector)
     816             :         {
     817      291350 :             const auto &tupleInfo = *oIterCollector;
     818             :             // assert( tupleInfo.increasedDistortion <= median_iter->increasedDistortion );
     819      291350 :             auto oIter = invalidatedClusters.find(tupleInfo.bucket);
     820      291350 :             if (oIter != invalidatedClusters.end())
     821             :             {
     822             :                 // Be careful not to merge a (i,j) tuple whose at least one of
     823             :                 // the element has already been merged.
     824             :                 // (this aspect is not covered by the Equitz's paper)
     825      221820 :                 if ((oIter->second &
     826      221820 :                      ((1U << tupleInfo.i) | (1U << tupleInfo.j))) != 0)
     827             :                 {
     828      183936 :                     continue;
     829             :                 }
     830             :             }
     831             :             else
     832             :             {
     833       69530 :                 oIter =
     834           0 :                     invalidatedClusters.insert(std::pair(tupleInfo.bucket, 0))
     835       69530 :                         .first;
     836             :             }
     837             : 
     838      107414 :             auto &bucketItemI = tupleInfo.bucket->m_bucketItems[tupleInfo.i];
     839             :             const auto &bucketItemJ =
     840      107414 :                 tupleInfo.bucket->m_bucketItems[tupleInfo.j];
     841      107414 :             auto origVectorIndices = std::move(bucketItemI.m_origVectorIndices);
     842      107414 :             origVectorIndices.insert(origVectorIndices.end(),
     843             :                                      bucketItemJ.m_origVectorIndices.begin(),
     844             :                                      bucketItemJ.m_origVectorIndices.end());
     845             : 
     846             : #ifdef KDTREE_DEBUG_TIMING
     847             :             struct timeval tv1, tv2;
     848             :             gettimeofday(&tv1, nullptr);
     849             : #endif
     850      107414 :             auto newVal = Vector<T>::centroid(
     851      107414 :                 bucketItemI.m_vec, bucketItemI.m_count, bucketItemJ.m_vec,
     852      107414 :                 bucketItemJ.m_count, ctxt);
     853             : 
     854             : #ifdef KDTREE_DEBUG_TIMING
     855             :             gettimeofday(&tv2, nullptr);
     856             :             totalTimeCentroid += (tv2.tv_sec + tv2.tv_usec * 1e-6) -
     857             :                                  (tv1.tv_sec + tv1.tv_usec * 1e-6);
     858             : #endif
     859             : 
     860             :             // Look if there is an existing item in the bucket with the new
     861             :             // vector value
     862      107414 :             int bucketItemIdx = -1;
     863      481110 :             for (int i = 0;
     864      481110 :                  i < static_cast<int>(tupleInfo.bucket->m_bucketItems.size());
     865             :                  ++i)
     866             :             {
     867      820620 :                 if ((oIter->second & (1U << i)) == 0 &&
     868      376917 :                     tupleInfo.bucket->m_bucketItems[i].m_vec == newVal)
     869             :                 {
     870       70007 :                     bucketItemIdx = i;
     871       70007 :                     break;
     872             :                 }
     873             :             }
     874      107414 :             oIter->second |= ((1U << tupleInfo.i) | (1U << tupleInfo.j));
     875      107414 :             int newCount = bucketItemI.m_count + bucketItemJ.m_count;
     876      107414 :             if (bucketItemIdx >= 0 && bucketItemIdx != tupleInfo.i &&
     877       33406 :                 bucketItemIdx != tupleInfo.j)
     878             :             {
     879         204 :                 oIter->second |= ((1U << bucketItemIdx));
     880             :                 auto &existingItem =
     881         204 :                     tupleInfo.bucket->m_bucketItems[bucketItemIdx];
     882         204 :                 newCount += existingItem.m_count;
     883         204 :                 origVectorIndices.insert(
     884             :                     origVectorIndices.end(),
     885             :                     std::make_move_iterator(
     886             :                         existingItem.m_origVectorIndices.begin()),
     887             :                     std::make_move_iterator(
     888             :                         existingItem.m_origVectorIndices.end()));
     889             :             }
     890             :             // Insert the new bucket item
     891      107414 :             tupleInfo.bucket->m_bucketItems.emplace_back(
     892      107414 :                 std::move(newVal), newCount, std::move(origVectorIndices));
     893             : 
     894      107414 :             --expectedBucketCount;
     895             :         }
     896             : 
     897             :         // Remove items that have been merged
     898       69666 :         for (auto [node, indices] : invalidatedClusters)
     899             :         {
     900             :             // Inside a same bucket, be careful to remove from the end so that
     901             :             // noted indices are still valid...
     902      522421 :             for (int i = static_cast<int>(node->m_bucketItems.size()) - 1;
     903      522421 :                  i >= 0; --i)
     904             :             {
     905      452891 :                 if ((indices & (1U << i)) != 0)
     906             :                 {
     907      215032 :                     node->m_bucketItems.erase(node->m_bucketItems.begin() + i);
     908             :                 }
     909             :             }
     910             : 
     911             : #ifdef DEBUG_INVARIANTS
     912             :             mapValuesToBucketIdx.clear();
     913             :             for (int i = 0; i < static_cast<int>(node->m_bucketItems.size());
     914             :                  ++i)
     915             :             {
     916             :                 CPLAssert(
     917             :                     mapValuesToBucketIdx.find(node->m_bucketItems[i].m_vec) ==
     918             :                     mapValuesToBucketIdx.end());
     919             :                 mapValuesToBucketIdx[node->m_bucketItems[i].m_vec] = i;
     920             :             }
     921             : #endif
     922             :         }
     923             : 
     924             :         // Rebalance the tree only half of the time, to speed up things a bit
     925             :         // This is quite arbitrary. Systematic rebalancing could result in
     926             :         // slightly better results.
     927         136 :         if ((iter % 2) == 0)
     928          79 :             curBucketCount = expectedBucketCount;
     929             :         else
     930          57 :             curBucketCount = rebalance(ctxt, newLeaves, queueNodes);
     931         136 :         ++iter;
     932             :     }
     933             : 
     934          60 :     return curBucketCount;
     935             : }
     936             : 
     937             : /************************************************************************/
     938             : /*                  PNNKDTree<T>::freeAndMoveToQueue()                  */
     939             : /************************************************************************/
     940             : 
     941             : template <class T>
     942      111418 : void PNNKDTree<T>::freeAndMoveToQueue(
     943             :     std::deque<std::unique_ptr<PNNKDTree>> &queueNodes)
     944             : {
     945      111418 :     m_bucketItems.clear();
     946      111418 :     if (m_left)
     947             :     {
     948       55681 :         m_left->freeAndMoveToQueue(queueNodes);
     949       55681 :         queueNodes.push_back(std::move(m_left));
     950             :     }
     951      111418 :     if (m_right)
     952             :     {
     953       55681 :         m_right->freeAndMoveToQueue(queueNodes);
     954       55681 :         queueNodes.push_back(std::move(m_right));
     955             :     }
     956      111418 : }
     957             : 
     958             : /************************************************************************/
     959             : /*                      PNNKDTree<T>::rebalance()                       */
     960             : /************************************************************************/
     961             : 
     962             : template <class T>
     963          57 : int PNNKDTree<T>::rebalance(const T &ctxt,
     964             :                             std::vector<BucketItem<T>> &newLeaves,
     965             :                             std::deque<std::unique_ptr<PNNKDTree>> &queueNodes)
     966             : {
     967          57 :     if (m_left && m_right)
     968             :     {
     969             : #ifdef KDTREE_DEBUG_TIMING
     970             :         struct timeval tv1, tv2;
     971             :         gettimeofday(&tv1, nullptr);
     972             : #endif
     973             :         std::map<Vector<T>,
     974             :                  std::pair<int, std::vector<typename BucketItem<T>::IdxType>>>
     975         112 :             mapVectors;
     976          56 :         int totalCount = 0;
     977             :         // Rebuild a new map of vector values -> (count, indices)
     978             :         // This needs to be a map as we cannot guarantee the uniqueness
     979             :         // of vector values after the clustering pass
     980          56 :         iterateOverLeaves(
     981       55737 :             [&mapVectors, &totalCount](PNNKDTree &bucket)
     982             :             {
     983      230664 :                 for (auto &item : bucket.m_bucketItems)
     984             :                 {
     985      174927 :                     totalCount += item.m_count;
     986      174927 :                     auto oIter = mapVectors.find(item.m_vec);
     987      174927 :                     if (oIter == mapVectors.end())
     988             :                     {
     989      174872 :                         mapVectors[item.m_vec] = std::make_pair(
     990      174872 :                             item.m_count, std::move(item.m_origVectorIndices));
     991             :                     }
     992             :                     else
     993             :                     {
     994          55 :                         oIter->second.first += item.m_count;
     995         110 :                         oIter->second.second.insert(
     996          55 :                             oIter->second.second.end(),
     997             :                             std::make_move_iterator(
     998             :                                 item.m_origVectorIndices.begin()),
     999             :                             std::make_move_iterator(
    1000             :                                 item.m_origVectorIndices.end()));
    1001             :                     }
    1002             :                 }
    1003             :             });
    1004             : 
    1005          56 :         freeAndMoveToQueue(queueNodes);
    1006             : 
    1007             :         // Convert the map to an array
    1008          56 :         newLeaves.clear();
    1009      174928 :         for (auto &[key, value] : mapVectors)
    1010             :         {
    1011      174872 :             newLeaves.emplace_back(std::move(key), value.first,
    1012      174872 :                                    std::move(value.second));
    1013             :         }
    1014             : 
    1015         112 :         std::vector<std::pair<ValType, int>> weightedVals;
    1016         112 :         std::vector<BucketItem<T>> vectLeft;
    1017          56 :         std::vector<BucketItem<T>> vectRight;
    1018          56 :         const int ret = insert(std::move(newLeaves), totalCount, weightedVals,
    1019             :                                queueNodes, vectLeft, vectRight, ctxt);
    1020             : #ifdef KDTREE_DEBUG_TIMING
    1021             :         gettimeofday(&tv2, nullptr);
    1022             :         totalTimeRebalancing += (tv2.tv_sec + tv2.tv_usec * 1e-6) -
    1023             :                                 (tv1.tv_sec + tv1.tv_usec * 1e-6);
    1024             : #endif
    1025          56 :         newLeaves = std::vector<BucketItem<T>>();
    1026          56 :         return ret;
    1027             :     }
    1028             :     else
    1029             :     {
    1030           1 :         return static_cast<int>(m_bucketItems.size());
    1031             :     }
    1032             : }
    1033             : 
    1034             : #endif  // KDTREE_INCLUDED

Generated by: LCOV version 1.14