Line data Source code
1 : /******************************************************************************
2 : *
3 : * Project: NITF Read/Write Library
4 : * Purpose: Pairwise Nearest Neighbor (PNN) clustering for Vector
5 : * Quantization (VQ) using a KDTree.
6 : * Author: Even Rouault, even dot rouault at spatialys dot com
7 : *
8 : **********************************************************************
9 : * Copyright (c) 2026, T-Kartor
10 : *
11 : * SPDX-License-Identifier: MIT
12 : ****************************************************************************/
13 :
14 : #ifndef KDTREE_INCLUDED
15 : #define KDTREE_INCLUDED
16 :
17 : /**
18 : * This file implements Pairwise Nearest Neighbor (PNN) clustering for Vector
19 : * Quantization (VQ) using a KDTree.
20 : *
21 : * It implements paper "A New Vector Quantization Clustering Algorithm", by
22 : * William H. Equitz, from IEEE Transactions on Acoustics, Speech, and Signal
23 : * Processing, Vol. 37, Issue 10, October 1989. DOI: 10.1109/29.35395
24 : * https://ieeexplore.ieee.org/document/35395 (behind paywall)
25 : *
26 : * A higher level (freely accessible) and more generic paper on PNN clustering
27 : * is also available at
28 : * https://www.researchgate.net/publication/27661047_Pairwise_Nearest_Neighbor_Method_Revisited
29 : *
 * Papers "Analysis of Compression Techniques for Common Mapping Standard (CMS)
 * Raster Data" by N.J. Markuson, July 1994 (https://apps.dtic.mil/sti/tr/pdf/ADA283396.pdf)
 * and "Compression of Digitized Map Image" by D.A. Southard, March 1992
 * (https://apps.dtic.mil/sti/tr/pdf/ADA250707.pdf) analyse VQ compression and
 * contain a high-level description of the Equitz paper.
 */
36 :
37 : #include <cassert>
38 : #include <cstdio>
39 :
40 : #include <algorithm>
41 : #include <array>
42 : #include <deque>
43 : #include <iterator>
44 : #include <functional>
45 : #include <limits>
46 : #include <map>
47 : #include <memory>
48 : #include <set>
49 : #include <utility>
50 : #include <vector>
51 :
52 : #include "cpl_error.h"
53 :
54 : // #define DEBUG_INVARIANTS
55 :
56 : #ifdef KDTREE_DEBUG_TIMING
57 : #include <sys/time.h>
58 :
59 : static double totalTimeRebalancing = 0;
60 : static double totalTimeCentroid = 0;
61 : static double totalTimeStats = 0;
62 :
63 : #endif
64 :
65 : #if defined(__GNUC__) && !defined(__clang__)
66 : #pragma GCC optimize("unroll-loops")
67 : #endif
68 :
69 : #if (defined(__x86_64__) || defined(_M_X64)) && !defined(KDTREE_DISABLE_SIMD)
70 : #define KDTREE_USE_SSE2
71 : #endif
72 :
73 : #ifdef KDTREE_USE_SSE2
74 : #include <emmintrin.h>
75 : #endif
76 :
77 : /************************************************************************/
78 : /* Vector<T> */
79 : /************************************************************************/
80 :
81 : /** "Interface" of a "vector" of dimension DIM_COUNT to insert and cluster in
82 : * a PNNKDTree.
83 : *
84 : * Below functions must be implemented in classes that specialize Vector: there
85 : * is no default generic implementation.
86 : *
87 : * There are no constraints on the T type.
88 : */
89 : template <class T> class Vector
90 : {
91 : public:
92 : /** Returns the dimension of the vector. */
93 : static constexpr int DIM_COUNT = -1;
94 :
95 : /** Whether the get() method returns uint8_t values.
96 : * Used for speed optimizations.
97 : */
98 : static constexpr bool getReturnUInt8 = false;
99 :
100 : /** Returns the k(th) value of the vector, with k in [0, DIM_COUNT-1] range.
101 : *
102 : * The actual returned type might not be double, but must be convertible to
103 : * double.
104 : */
105 : double get(int k, const T &ctxt) const /* = 0 */;
106 :
107 : #ifdef KDTREE_USE_SSE2
108 : /** Whether the computeHeightSumAndSumSquareSSE2() method is implemented.
109 : * Used for speed optimizations.
110 : */
111 : static constexpr bool hasComputeHeightSumAndSumSquareSSE2 = false;
112 :
113 : /** The function must do the equivalent of:
114 : *
115 : * for (int i = 0; i < 8; ++i )
116 : * {
117 : * int val = item.m_vec.get(k + i, ctxt);
118 : * int valMulCount = val * item.m_count;
119 : * {sum0, sum1}[i] += valMulCount;
120 : * {sumSquare0_lo, sumSquare0_hi,sumSquare1_lo, sumSquare1_hi}[i] += val * valMulCount;
121 : * }
122 : *
123 : * k is in the [0, DIM_COUNT-8-1] range (and generally a multiple of 8).
124 : */
125 : void computeHeightSumAndSumSquareSSE2(int k, const T &ctxt, int count,
126 : __m128i &sum0, __m128i &sumSquare0_lo,
127 : __m128i &sumSquare0_hi, __m128i &sum1,
128 : __m128i &sumSquare1_lo,
129 : __m128i &sumSquare1_hi) const
130 : /* = 0 */;
131 : #endif
132 :
133 : /** Returns the squared distance between this vector and other.
134 : * It must be symmetric, that is this->squared_distance(other, ctx) must
135 : * be equal to other.squared_distance(*this, ctx).
136 : */
137 : double squared_distance(const Vector &other, const T &ctxt) const /* = 0 */;
138 :
139 : /** Whether the compute_four_squared_distances() method is implemented
140 : * Used for speed optimizations.
141 : */
142 : static constexpr bool hasComputeFourSquaredDistances = false;
143 :
144 : /** Equivalent of
145 : *
146 : * for(int i = 0; i < 4; ++i)
147 : * {
148 : * tabSquaredDist[i] = squared_distance(*(other[i]), ctxt);
149 : * }
150 : */
151 : void compute_four_squared_distances(
152 : const std::array<const Vector *const, 4> &others,
153 : std::array<int, 4> & /* out */ tabSquaredDist, const T &ctxt) const
154 : /* = 0 */;
155 :
156 : /** Computes a new vector that is the centroid of vector a of weight nA,
157 : * and vector b of weight nB.
158 : */
159 : static Vector centroid(const Vector &a, int nA, const Vector &b, int nB,
160 : const T &ctxt) /* = 0 */;
161 : };
162 :
163 : /************************************************************************/
164 : /* BucketItem<T> */
165 : /************************************************************************/
166 :
167 : /** Definition of an item placed in a bucket of a PNNKDTree.
168 : *
169 : * This class does not need to be specialized.
170 : */
171 8580883 : template <class T> struct BucketItem
172 : {
173 : public:
174 : /** Value vector */
175 : Vector<T> m_vec;
176 :
177 : /** Type of elements in m_origVectorIndices */
178 : using IdxType = int;
179 :
180 : /** Vector that points to indices in the original value space that evaluate
181 : * to m_vec.
182 : * Typically m_origVectorIndices.size() == m_count, but
183 : * the clustering algorithm will not enforce it. It will just concatenate
184 : * m_origVectorIndices from different BucketItem when merging them.
185 : */
186 : std::vector<IdxType> m_origVectorIndices;
187 :
188 : /** Number of samples that have the value of m_vec */
189 : int m_count;
190 :
191 : /** Constructor */
192 525399 : BucketItem(const Vector<T> &vec, int count,
193 : std::vector<IdxType> &&origVectorIndices)
194 525399 : : m_vec(vec), m_origVectorIndices(std::move(origVectorIndices)),
195 525399 : m_count(count)
196 : {
197 525399 : }
198 :
199 11266998 : BucketItem(BucketItem &&) = default;
200 : BucketItem &operator=(BucketItem &&) = default;
201 :
202 : private:
203 : BucketItem(const BucketItem &) = delete;
204 : BucketItem &operator=(const BucketItem &) = delete;
205 : };
206 :
207 : /************************************************************************/
208 : /* PNNKDTree<T> */
209 : /************************************************************************/
210 :
211 : /**
212 : * KDTree designed for Pairwise Nearest Neighbor (PNN) clustering for Vector
213 : * Quantization (VQ).
214 : *
215 : * This class does not need to be specialized.
216 : */
217 : template <class T> class PNNKDTree
218 : {
219 : public:
220 61712 : PNNKDTree() = default;
221 :
222 : /* Inserts value vectors with their cardinality in the KD-Tree.
223 : *
224 : * This method must be called only once.
225 : *
226 : * Returns the initial count of buckets, that must be passed as an input
227 : * to cluster().
228 : */
229 : int insert(std::vector<BucketItem<T>> &&vectors, const T &ctxt);
230 :
231 : /** Iterate over leaf nodes (that contain buckets) */
232 : void iterateOverLeaves(const std::function<void(PNNKDTree &)> &f);
233 :
234 : /** Perform clustering to reduce the number of buckets from initialBucketCount
235 : * to targetCount.
236 : *
237 : * It modifies the tree structure, and returns the achieved number of
238 : * buckets (<= targetCount).
239 : */
240 : int cluster(int initialBucketCount, int targetCount, const T &ctxt);
241 :
242 : /** Returns the bucket items for this node. */
243 : inline const std::vector<BucketItem<T>> &bucketItems() const
244 : {
245 : return m_bucketItems;
246 : }
247 :
248 : /** Returns the bucket items for this node. */
249 9810 : inline std::vector<BucketItem<T>> &bucketItems()
250 : {
251 9810 : return m_bucketItems;
252 : }
253 :
254 : private:
255 : static constexpr int BUCKET_MAX_SIZE = 8;
256 :
257 : /** Left node. When non null, m_right is also non null, and m_bucketItems is empty. */
258 : std::unique_ptr<PNNKDTree> m_left{};
259 :
260 : /** Right node. When non null, m_left is also non null, and m_bucketItems is empty. */
261 : std::unique_ptr<PNNKDTree> m_right{};
262 :
263 : /** Contains items that form a bucket. The bucket is nominally at most BUCKET_MAX_SIZE
264 : * large (maybe transiently slightly larger during clustering operations).
265 : *
266 : * m_bucketItems is non empty only on leaf nodes.
267 : */
268 : std::vector<BucketItem<T>> m_bucketItems{};
269 :
270 : /** Data type returned by Vector<T>::get() */
271 : using ValType = decltype(std::declval<Vector<T>>().get(
272 : 0, *static_cast<const T *>(nullptr)));
273 :
274 : /** Clean the current node and move it to queueNodes.
275 : *
276 : * This saves dynamic allocation and de-allocation of nodes when rebalancing.
277 : */
278 : void freeAndMoveToQueue(std::deque<std::unique_ptr<PNNKDTree>> &queueNodes);
279 :
280 : int insert(std::vector<BucketItem<T>> &&vectors, int totalCount,
281 : std::vector<std::pair<ValType, int>> &weightedVals,
282 : std::deque<std::unique_ptr<PNNKDTree>> &queueNodes,
283 : std::vector<BucketItem<T>> &vectLeft,
284 : std::vector<BucketItem<T>> &vectRight, const T &ctxt);
285 :
286 : /** Rebalance the KD-Tree. Current implementation fully rebuilds a new
287 : * KD-Tree using the insert() algorithm
288 : */
289 : int rebalance(const T &ctxt, std::vector<BucketItem<T>> &newLeaves,
290 : std::deque<std::unique_ptr<PNNKDTree>> &queueNodes);
291 : };
292 :
293 : /************************************************************************/
294 : /* PNNKDTree<T>::insert() */
295 : /************************************************************************/
296 :
297 : template <class T>
298 60 : int PNNKDTree<T>::insert(std::vector<BucketItem<T>> &&vectors, const T &ctxt)
299 : {
300 60 : assert(m_left == nullptr);
301 60 : assert(m_right == nullptr);
302 60 : assert(m_bucketItems.empty());
303 :
304 60 : int totalCount = 0;
305 155796 : for (const auto &it : vectors)
306 : {
307 155736 : totalCount += it.m_count;
308 : }
309 120 : std::vector<std::pair<ValType, int>> weightedVals;
310 120 : std::deque<std::unique_ptr<PNNKDTree>> queueNodes;
311 120 : std::vector<BucketItem<T>> vectLeft;
312 120 : std::vector<BucketItem<T>> vectRight;
313 60 : if (totalCount == 0)
314 0 : return 0;
315 60 : return insert(std::move(vectors), totalCount, weightedVals, queueNodes,
316 60 : vectLeft, vectRight, ctxt);
317 : }
318 :
319 : /************************************************************************/
320 : /* PNNKDTree<T>::insert() */
321 : /************************************************************************/
322 :
323 : template <class T>
324 130804 : int PNNKDTree<T>::insert(std::vector<BucketItem<T>> &&vectors, int totalCount,
325 : std::vector<std::pair<ValType, int>> &weightedVals,
326 : std::deque<std::unique_ptr<PNNKDTree>> &queueNodes,
327 : std::vector<BucketItem<T>> &vectLeft,
328 : std::vector<BucketItem<T>> &vectRight, const T &ctxt)
329 : {
330 : #ifdef DEBUG_INVARIANTS
331 : std::map<Vector<T>, int> mapValuesToBucketIdx;
332 : for (int i = 0; i < static_cast<int>(vectors.size()); ++i)
333 : {
334 : CPLAssert(mapValuesToBucketIdx.find(vectors[i].m_vec) ==
335 : mapValuesToBucketIdx.end());
336 : mapValuesToBucketIdx[vectors[i].m_vec] = i;
337 : }
338 : #endif
339 :
340 130804 : if (vectors.size() <= BUCKET_MAX_SIZE)
341 : {
342 65460 : m_bucketItems = std::move(vectors);
343 65460 : return static_cast<int>(m_bucketItems.size());
344 : }
345 :
346 : #ifdef KDTREE_DEBUG_TIMING
347 : struct timeval tv1, tv2;
348 : gettimeofday(&tv1, nullptr);
349 : #endif
350 :
351 : // Find dimension with maximum variance
352 65344 : double maxM2 = 0;
353 65344 : int maxM2_k = 0;
354 :
355 850138 : for (int k = 0; k < Vector<T>::DIM_COUNT; ++k)
356 : {
357 : if constexpr (Vector<T>::getReturnUInt8)
358 : {
359 784794 : constexpr int MAX_BYTE_VALUE = std::numeric_limits<uint8_t>::max();
360 784794 : bool canUseOptimization =
361 784794 : (totalCount <= std::numeric_limits<int64_t>::max() /
362 : (MAX_BYTE_VALUE * MAX_BYTE_VALUE));
363 784794 : if (canUseOptimization)
364 : {
365 784794 : int maxCountPerVector = 0;
366 59746727 : for (const auto &item : vectors)
367 : {
368 58961909 : maxCountPerVector =
369 58961909 : std::max(maxCountPerVector, item.m_count);
370 : }
371 784794 : canUseOptimization = (maxCountPerVector <=
372 784794 : std::numeric_limits<int32_t>::max() /
373 : (MAX_BYTE_VALUE * MAX_BYTE_VALUE));
374 : }
375 784794 : if (canUseOptimization)
376 : {
377 : // Do statistics computation in integer domain if possible.
378 :
379 : #if !(defined(__i386__) || defined(_M_IX86))
380 : // Below code requires more than 8 general purpose registers,
381 : // so exclude i386.
382 :
383 776376 : constexpr int VALS_AT_ONCE = 4;
384 : if constexpr ((Vector<T>::DIM_COUNT % VALS_AT_ONCE) == 0)
385 : {
386 : #ifdef KDTREE_USE_SSE2
387 776376 : constexpr int TWICE_VALS_AT_ONCE = 2 * VALS_AT_ONCE;
388 : if constexpr ((Vector<T>::DIM_COUNT % TWICE_VALS_AT_ONCE) ==
389 : 0 &&
390 : Vector<
391 : T>::hasComputeHeightSumAndSumSquareSSE2)
392 : {
393 : __m128i sum0 = _mm_setzero_si128();
394 : __m128i sumSquare0_lo = _mm_setzero_si128();
395 : __m128i sumSquare0_hi = _mm_setzero_si128();
396 : __m128i sum1 = _mm_setzero_si128();
397 : __m128i sumSquare1_lo = _mm_setzero_si128();
398 : __m128i sumSquare1_hi = _mm_setzero_si128();
399 :
400 : for (const auto &item : vectors)
401 : {
402 : item.m_vec.computeHeightSumAndSumSquareSSE2(
403 : k, ctxt, item.m_count, sum0, sumSquare0_lo,
404 : sumSquare0_hi, sum1, sumSquare1_lo,
405 : sumSquare1_hi);
406 : }
407 : int64_t sumSquares[TWICE_VALS_AT_ONCE];
408 : _mm_storeu_si128(
409 : reinterpret_cast<__m128i *>(sumSquares + 0),
410 : sumSquare0_lo);
411 : _mm_storeu_si128(
412 : reinterpret_cast<__m128i *>(sumSquares + 2),
413 : sumSquare0_hi);
414 : _mm_storeu_si128(
415 : reinterpret_cast<__m128i *>(sumSquares + 4),
416 : sumSquare1_lo);
417 : _mm_storeu_si128(
418 : reinterpret_cast<__m128i *>(sumSquares + 6),
419 : sumSquare1_hi);
420 : int sums[TWICE_VALS_AT_ONCE];
421 : _mm_storeu_si128(reinterpret_cast<__m128i *>(sums),
422 : sum0);
423 : _mm_storeu_si128(
424 : reinterpret_cast<__m128i *>(sums + VALS_AT_ONCE),
425 : sum1);
426 : for (int i = 0; i < TWICE_VALS_AT_ONCE; ++i)
427 : {
428 : const double M2 = static_cast<double>(
429 : sumSquares[i] * totalCount -
430 : static_cast<int64_t>(sums[i]) * sums[i]);
431 : if (M2 > maxM2)
432 : {
433 : maxM2 = M2;
434 : maxM2_k = k + i;
435 : }
436 : }
437 : k += TWICE_VALS_AT_ONCE - 1;
438 : }
439 : else
440 : #endif
441 : {
442 776376 : int sum0 = 0;
443 776376 : int sum1 = 0;
444 776376 : int sum2 = 0;
445 776376 : int sum3 = 0;
446 776376 : int64_t sumSquare0 = 0;
447 776376 : int64_t sumSquare1 = 0;
448 776376 : int64_t sumSquare2 = 0;
449 776376 : int64_t sumSquare3 = 0;
450 49671392 : for (const auto &item : vectors)
451 : {
452 48894980 : const int val0 = item.m_vec.get(k + 0, ctxt);
453 48894980 : const int val1 = item.m_vec.get(k + 1, ctxt);
454 48894980 : const int val2 = item.m_vec.get(k + 2, ctxt);
455 48894980 : const int val3 = item.m_vec.get(k + 3, ctxt);
456 48894980 : const int val0MulCount = val0 * item.m_count;
457 48894980 : const int val1MulCount = val1 * item.m_count;
458 48894980 : const int val2MulCount = val2 * item.m_count;
459 48894980 : const int val3MulCount = val3 * item.m_count;
460 48894980 : sum0 += val0MulCount;
461 48894980 : sum1 += val1MulCount;
462 48894980 : sum2 += val2MulCount;
463 48894980 : sum3 += val3MulCount;
464 : // It's fine to cast to int64 after multiplication
465 48894980 : sumSquare0 +=
466 48894980 : cpl::fits_on<int64_t>(val0 * val0MulCount);
467 48894980 : sumSquare1 +=
468 48894980 : cpl::fits_on<int64_t>(val1 * val1MulCount);
469 48894980 : sumSquare2 +=
470 48894980 : cpl::fits_on<int64_t>(val2 * val2MulCount);
471 48894980 : sumSquare3 +=
472 48894980 : cpl::fits_on<int64_t>(val3 * val3MulCount);
473 : }
474 :
475 776376 : const double M2[] = {
476 776376 : static_cast<double>(sumSquare0 * totalCount -
477 776376 : static_cast<int64_t>(sum0) *
478 776376 : sum0),
479 776376 : static_cast<double>(sumSquare1 * totalCount -
480 776376 : static_cast<int64_t>(sum1) *
481 776376 : sum1),
482 776376 : static_cast<double>(sumSquare2 * totalCount -
483 776376 : static_cast<int64_t>(sum2) *
484 776376 : sum2),
485 776376 : static_cast<double>(sumSquare3 * totalCount -
486 776376 : static_cast<int64_t>(sum3) *
487 776376 : sum3)};
488 3881880 : for (int i = 0; i < VALS_AT_ONCE; ++i)
489 : {
490 3105508 : if (M2[i] > maxM2)
491 : {
492 301946 : maxM2 = M2[i];
493 301946 : maxM2_k = k + i;
494 : }
495 : }
496 776376 : k += VALS_AT_ONCE - 1;
497 : }
498 : }
499 : else
500 : #endif
501 : {
502 0 : int sum = 0;
503 0 : int64_t sumSquare = 0;
504 0 : for (const auto &item : vectors)
505 : {
506 0 : const int val = item.m_vec.get(k, ctxt);
507 0 : const int valMulCount = val * item.m_count;
508 0 : sum += valMulCount;
509 : // It's fine to cast to int64 after multiplication
510 0 : sumSquare += cpl::fits_on<int64_t>(val * valMulCount);
511 : }
512 0 : const double M2 =
513 0 : static_cast<double>(sumSquare * totalCount -
514 0 : static_cast<int64_t>(sum) * sum);
515 0 : if (M2 > maxM2)
516 : {
517 0 : maxM2 = M2;
518 0 : maxM2_k = k;
519 : }
520 : }
521 776376 : continue;
522 : }
523 : }
524 :
525 : // Generic code path:
526 :
527 : // First pass to compute mean value along k(th) dimension
528 8418 : double sum = 0;
529 10075335 : for (const auto &item : vectors)
530 : {
531 10066929 : sum += static_cast<double>(item.m_vec.get(k, ctxt)) * item.m_count;
532 : }
533 8418 : const double mean = sum / totalCount;
534 : // Second pass to compute M2 value (n * variance) along k(th) dimension
535 8418 : double M2 = 0;
536 10075335 : for (const auto &item : vectors)
537 : {
538 10066929 : const double delta =
539 10066929 : static_cast<double>(item.m_vec.get(k, ctxt)) - mean;
540 10066929 : M2 += delta * delta * item.m_count;
541 : }
542 8418 : if (M2 > maxM2)
543 : {
544 1858 : maxM2 = M2;
545 1858 : maxM2_k = k;
546 : }
547 : }
548 :
549 : #ifdef KDTREE_DEBUG_TIMING
550 : gettimeofday(&tv2, nullptr);
551 : totalTimeStats +=
552 : (tv2.tv_sec + tv2.tv_usec * 1e-6) - (tv1.tv_sec + tv1.tv_usec * 1e-6);
553 : #endif
554 :
555 : // Find median value along that dimension
556 65344 : weightedVals.reserve(vectors.size());
557 65344 : weightedVals.clear();
558 4363381 : for (const auto &item : vectors)
559 : {
560 4298038 : const auto d = item.m_vec.get(maxM2_k, ctxt);
561 4298038 : weightedVals.emplace_back(d, item.m_count);
562 : }
563 :
564 65344 : std::sort(weightedVals.begin(), weightedVals.end(),
565 33055123 : [](const auto &a, const auto &b) { return a.first < b.first; });
566 :
567 65344 : auto median = weightedVals[0].first;
568 65344 : int cumulativeCount = 0;
569 65344 : const int targetCount = totalCount / 2;
570 2176236 : for (const auto &[value, count] : weightedVals)
571 : {
572 2176236 : cumulativeCount += count;
573 2176236 : if (cumulativeCount > targetCount)
574 : {
575 65344 : median = value;
576 65344 : break;
577 : }
578 : }
579 :
580 : // Split the original vectors in a "left" half that is below or equal to
581 : // the median and a "right" half that is above.
582 65344 : vectLeft.clear();
583 65344 : vectLeft.reserve(weightedVals.size() / 2);
584 65344 : vectRight.clear();
585 65344 : vectRight.reserve(weightedVals.size() / 2);
586 65344 : int countLeft = 0;
587 65344 : int countRight = 0;
588 4363381 : for (auto &item : vectors)
589 : {
590 4298038 : if (item.m_vec.get(maxM2_k, ctxt) <= median)
591 : {
592 2764333 : countLeft += item.m_count;
593 2764333 : vectLeft.push_back(std::move(item));
594 : }
595 : else
596 : {
597 1533705 : countRight += item.m_count;
598 1533705 : vectRight.push_back(std::move(item));
599 : }
600 : }
601 :
602 : // In some cases, the median can actually be the maximum value
603 : // Then, retry but exclusing the median itself.
604 65344 : if (vectLeft.empty() || vectRight.empty())
605 : {
606 15498 : if (!vectLeft.empty())
607 15498 : vectors = std::move(vectLeft);
608 : else
609 0 : vectors = std::move(vectRight);
610 15498 : vectLeft.clear();
611 15498 : vectRight.clear();
612 15498 : countLeft = 0;
613 15498 : countRight = 0;
614 498442 : for (auto &item : vectors)
615 : {
616 482944 : if (item.m_vec.get(maxM2_k, ctxt) < median)
617 : {
618 239646 : countLeft += item.m_count;
619 239646 : vectLeft.push_back(std::move(item));
620 : }
621 : else
622 : {
623 243298 : countRight += item.m_count;
624 243298 : vectRight.push_back(std::move(item));
625 : }
626 : }
627 :
628 : // Normally we shouldn't reach that point, unless the initial samples
629 : // where all identical, and thus clustering wasn't needed.
630 15498 : if (vectLeft.empty() || vectRight.empty())
631 : {
632 0 : CPLError(CE_Failure, CPLE_AppDefined,
633 : "Unexpected situation in %s:%d\n", __FILE__, __LINE__);
634 0 : return 0;
635 : }
636 : }
637 65344 : vectors.clear();
638 :
639 : // Allocate (or recycle) left and right nodes
640 65344 : if (!queueNodes.empty())
641 : {
642 34518 : m_left = std::move(queueNodes.back());
643 34518 : queueNodes.pop_back();
644 : }
645 : else
646 30826 : m_left = std::make_unique<PNNKDTree<T>>();
647 :
648 65344 : if (!queueNodes.empty())
649 : {
650 34518 : m_right = std::move(queueNodes.back());
651 34518 : queueNodes.pop_back();
652 : }
653 : else
654 30826 : m_right = std::make_unique<PNNKDTree<T>>();
655 :
656 : // Recursively insert vectLeft in m_left and vectRight in m_right
657 65344 : std::vector<BucketItem<T>> vectTmp;
658 : // Sort for replicability of results across platforms
659 34517420 : const auto sortFunc = [](const BucketItem<T> &a, const BucketItem<T> &b)
660 34517426 : { return a.m_vec < b.m_vec; };
661 65344 : std::sort(vectLeft.begin(), vectLeft.end(), sortFunc);
662 65344 : std::sort(vectRight.begin(), vectRight.end(), sortFunc);
663 65344 : int retLeft = m_left->insert(std::move(vectLeft), countLeft, weightedVals,
664 : queueNodes, vectors, vectTmp, ctxt);
665 130688 : int retRight =
666 65344 : m_right->insert(std::move(vectRight), countRight, weightedVals,
667 : queueNodes, vectors, vectTmp, ctxt);
668 65344 : vectLeft = std::vector<BucketItem<T>>();
669 65344 : vectRight = std::vector<BucketItem<T>>();
670 65344 : return (retLeft == 0 || retRight == 0) ? 0 : retLeft + retRight;
671 : }
672 :
673 : /************************************************************************/
674 : /* PNNKDTree<T>::iterateOverLeaves() */
675 : /************************************************************************/
676 :
677 : template <class T>
678 358846 : void PNNKDTree<T>::iterateOverLeaves(const std::function<void(PNNKDTree &)> &f)
679 : {
680 358846 : if (m_left && m_right)
681 : {
682 179284 : m_left->iterateOverLeaves(f);
683 179284 : m_right->iterateOverLeaves(f);
684 : }
685 : else
686 : {
687 179562 : f(*this);
688 : }
689 358846 : }
690 :
691 : /************************************************************************/
692 : /* PNNKDTree<T>::cluster() */
693 : /************************************************************************/
694 :
695 : template <class T>
696 30 : int PNNKDTree<T>::cluster(int initialBucketCount, int targetCount,
697 : const T &ctxt)
698 : {
699 30 : int curBucketCount = initialBucketCount;
700 :
701 60 : std::vector<BucketItem<T>> newLeaves;
702 30 : newLeaves.reserve(initialBucketCount);
703 60 : std::deque<std::unique_ptr<PNNKDTree>> queueNodes;
704 :
705 : struct TupleInfo
706 : {
707 : PNNKDTree *bucket;
708 : int i;
709 : int j;
710 : double increasedDistortion;
711 : };
712 :
713 30 : std::vector<TupleInfo> distCollector;
714 30 : distCollector.reserve(curBucketCount);
715 :
716 30 : int iter = 0;
717 : #ifdef DEBUG_INVARIANTS
718 : std::map<Vector<T>, int> mapValuesToBucketIdx;
719 : #endif
720 166 : while (curBucketCount > targetCount)
721 : {
722 : /* For each bucket (leaf node), compute the increase in distortion
723 : * that would result in merging each (i,j) vector it contains.
724 : */
725 136 : distCollector.clear();
726 136 : iterateOverLeaves(
727 114015 : [&distCollector, &ctxt](PNNKDTree &bucket)
728 : {
729 114015 : const int itemsCount =
730 114015 : static_cast<int>(bucket.m_bucketItems.size());
731 519267 : for (int i = 0; i < itemsCount - 1; ++i)
732 : {
733 405252 : const auto &itemI = bucket.m_bucketItems[i];
734 4068 : int j = i + 1;
735 : if constexpr (Vector<T>::hasComputeFourSquaredDistances)
736 : {
737 401214 : constexpr int CHUNK_SIZE = 4;
738 401214 : if (j + CHUNK_SIZE <= itemsCount)
739 : {
740 : std::array<const Vector<T> *const, CHUNK_SIZE>
741 494164 : otherVectors = {
742 123541 : &(bucket.m_bucketItems[j + 0].m_vec),
743 123541 : &(bucket.m_bucketItems[j + 1].m_vec),
744 123541 : &(bucket.m_bucketItems[j + 2].m_vec),
745 123541 : &(bucket.m_bucketItems[j + 3].m_vec)};
746 : std::array<int, CHUNK_SIZE> tabSquaredDist;
747 123541 : itemI.m_vec.compute_four_squared_distances(
748 : otherVectors, tabSquaredDist, ctxt);
749 617705 : for (int subj = 0; subj < CHUNK_SIZE; ++subj)
750 : {
751 : const auto &itemJ =
752 494164 : bucket.m_bucketItems[j + subj];
753 494164 : const double increasedDistortion =
754 494164 : static_cast<double>(itemI.m_count) *
755 494164 : itemJ.m_count * tabSquaredDist[subj] /
756 494164 : (itemI.m_count + itemJ.m_count);
757 : TupleInfo ti;
758 494164 : ti.bucket = &bucket;
759 494164 : ti.i = i;
760 494164 : ti.j = j + subj;
761 494164 : ti.increasedDistortion = increasedDistortion;
762 494164 : distCollector.push_back(std::move(ti));
763 : }
764 :
765 123541 : j += CHUNK_SIZE;
766 : }
767 : }
768 1049518 : for (; j < itemsCount; ++j)
769 : {
770 644268 : const auto &itemJ = bucket.m_bucketItems[j];
771 644268 : const double increasedDistortion =
772 1288536 : static_cast<double>(itemI.m_count) * itemJ.m_count *
773 644268 : itemI.m_vec.squared_distance(itemJ.m_vec, ctxt) /
774 644268 : (itemI.m_count + itemJ.m_count);
775 : TupleInfo ti;
776 644268 : ti.bucket = &bucket;
777 644268 : ti.i = i;
778 644268 : ti.j = j;
779 644268 : ti.increasedDistortion = increasedDistortion;
780 644268 : distCollector.push_back(std::move(ti));
781 : }
782 : }
783 : });
784 :
785 : /** Identify the median of the increased distortion */
786 136 : const int bucketCountToMerge =
787 272 : std::min(static_cast<int>(distCollector.size() / 2),
788 136 : curBucketCount - targetCount);
789 7794665 : const auto sortFunc = [](const TupleInfo &a, const TupleInfo &b)
790 : {
791 11637183 : return a.increasedDistortion < b.increasedDistortion ||
792 3843494 : (a.increasedDistortion == b.increasedDistortion &&
793 1779205 : (a.bucket->m_bucketItems[0].m_vec <
794 1779645 : b.bucket->m_bucketItems[0].m_vec ||
795 683271 : (a.bucket->m_bucketItems[0].m_vec ==
796 683271 : b.bucket->m_bucketItems[0].m_vec &&
797 7826682 : (a.i < b.i || (a.i == b.i && a.j < b.j)))));
798 : };
799 136 : const auto median_iter = distCollector.begin() + bucketCountToMerge;
800 136 : std::nth_element(distCollector.begin(), median_iter,
801 : distCollector.end(), sortFunc);
802 : /** Sort elements by increasing increasedDistortion, but only for
803 : * the first half of the array
804 : */
805 136 : std::sort(distCollector.begin(), median_iter, sortFunc);
806 :
807 : static_assert(BUCKET_MAX_SIZE <= sizeof(uint32_t) * 8);
808 136 : std::map<PNNKDTree *, uint32_t> invalidatedClusters;
809 :
810 136 : int expectedBucketCount = curBucketCount;
811 :
812 : // Merge all the tuple of vectors whose increasd distortion is lower
813 : // than the median.
814 291486 : for (auto oIterCollector = distCollector.begin();
815 582836 : oIterCollector != median_iter; ++oIterCollector)
816 : {
817 291350 : const auto &tupleInfo = *oIterCollector;
818 : // assert( tupleInfo.increasedDistortion <= median_iter->increasedDistortion );
819 291350 : auto oIter = invalidatedClusters.find(tupleInfo.bucket);
820 291350 : if (oIter != invalidatedClusters.end())
821 : {
822 : // Be careful not to merge a (i,j) tuple whose at least one of
823 : // the element has already been merged.
824 : // (this aspect is not covered by the Equitz's paper)
825 221820 : if ((oIter->second &
826 221820 : ((1U << tupleInfo.i) | (1U << tupleInfo.j))) != 0)
827 : {
828 183936 : continue;
829 : }
830 : }
831 : else
832 : {
833 69530 : oIter =
834 0 : invalidatedClusters.insert(std::pair(tupleInfo.bucket, 0))
835 69530 : .first;
836 : }
837 :
838 107414 : auto &bucketItemI = tupleInfo.bucket->m_bucketItems[tupleInfo.i];
839 : const auto &bucketItemJ =
840 107414 : tupleInfo.bucket->m_bucketItems[tupleInfo.j];
841 107414 : auto origVectorIndices = std::move(bucketItemI.m_origVectorIndices);
842 107414 : origVectorIndices.insert(origVectorIndices.end(),
843 : bucketItemJ.m_origVectorIndices.begin(),
844 : bucketItemJ.m_origVectorIndices.end());
845 :
846 : #ifdef KDTREE_DEBUG_TIMING
847 : struct timeval tv1, tv2;
848 : gettimeofday(&tv1, nullptr);
849 : #endif
850 107414 : auto newVal = Vector<T>::centroid(
851 107414 : bucketItemI.m_vec, bucketItemI.m_count, bucketItemJ.m_vec,
852 107414 : bucketItemJ.m_count, ctxt);
853 :
854 : #ifdef KDTREE_DEBUG_TIMING
855 : gettimeofday(&tv2, nullptr);
856 : totalTimeCentroid += (tv2.tv_sec + tv2.tv_usec * 1e-6) -
857 : (tv1.tv_sec + tv1.tv_usec * 1e-6);
858 : #endif
859 :
860 : // Look if there is an existing item in the bucket with the new
861 : // vector value
862 107414 : int bucketItemIdx = -1;
863 481110 : for (int i = 0;
864 481110 : i < static_cast<int>(tupleInfo.bucket->m_bucketItems.size());
865 : ++i)
866 : {
867 820620 : if ((oIter->second & (1U << i)) == 0 &&
868 376917 : tupleInfo.bucket->m_bucketItems[i].m_vec == newVal)
869 : {
870 70007 : bucketItemIdx = i;
871 70007 : break;
872 : }
873 : }
874 107414 : oIter->second |= ((1U << tupleInfo.i) | (1U << tupleInfo.j));
875 107414 : int newCount = bucketItemI.m_count + bucketItemJ.m_count;
876 107414 : if (bucketItemIdx >= 0 && bucketItemIdx != tupleInfo.i &&
877 33406 : bucketItemIdx != tupleInfo.j)
878 : {
879 204 : oIter->second |= ((1U << bucketItemIdx));
880 : auto &existingItem =
881 204 : tupleInfo.bucket->m_bucketItems[bucketItemIdx];
882 204 : newCount += existingItem.m_count;
883 204 : origVectorIndices.insert(
884 : origVectorIndices.end(),
885 : std::make_move_iterator(
886 : existingItem.m_origVectorIndices.begin()),
887 : std::make_move_iterator(
888 : existingItem.m_origVectorIndices.end()));
889 : }
890 : // Insert the new bucket item
891 107414 : tupleInfo.bucket->m_bucketItems.emplace_back(
892 107414 : std::move(newVal), newCount, std::move(origVectorIndices));
893 :
894 107414 : --expectedBucketCount;
895 : }
896 :
897 : // Remove items that have been merged
898 69666 : for (auto [node, indices] : invalidatedClusters)
899 : {
900 : // Inside a same bucket, be careful to remove from the end so that
901 : // noted indices are still valid...
902 522421 : for (int i = static_cast<int>(node->m_bucketItems.size()) - 1;
903 522421 : i >= 0; --i)
904 : {
905 452891 : if ((indices & (1U << i)) != 0)
906 : {
907 215032 : node->m_bucketItems.erase(node->m_bucketItems.begin() + i);
908 : }
909 : }
910 :
911 : #ifdef DEBUG_INVARIANTS
912 : mapValuesToBucketIdx.clear();
913 : for (int i = 0; i < static_cast<int>(node->m_bucketItems.size());
914 : ++i)
915 : {
916 : CPLAssert(
917 : mapValuesToBucketIdx.find(node->m_bucketItems[i].m_vec) ==
918 : mapValuesToBucketIdx.end());
919 : mapValuesToBucketIdx[node->m_bucketItems[i].m_vec] = i;
920 : }
921 : #endif
922 : }
923 :
924 : // Rebalance the tree only half of the time, to speed up things a bit
925 : // This is quite arbitrary. Systematic rebalancing could result in
926 : // slightly better results.
927 136 : if ((iter % 2) == 0)
928 79 : curBucketCount = expectedBucketCount;
929 : else
930 57 : curBucketCount = rebalance(ctxt, newLeaves, queueNodes);
931 136 : ++iter;
932 : }
933 :
934 60 : return curBucketCount;
935 : }
936 :
937 : /************************************************************************/
938 : /* PNNKDTree<T>::freeAndMoveToQueue() */
939 : /************************************************************************/
940 :
941 : template <class T>
942 111418 : void PNNKDTree<T>::freeAndMoveToQueue(
943 : std::deque<std::unique_ptr<PNNKDTree>> &queueNodes)
944 : {
945 111418 : m_bucketItems.clear();
946 111418 : if (m_left)
947 : {
948 55681 : m_left->freeAndMoveToQueue(queueNodes);
949 55681 : queueNodes.push_back(std::move(m_left));
950 : }
951 111418 : if (m_right)
952 : {
953 55681 : m_right->freeAndMoveToQueue(queueNodes);
954 55681 : queueNodes.push_back(std::move(m_right));
955 : }
956 111418 : }
957 :
958 : /************************************************************************/
959 : /* PNNKDTree<T>::rebalance() */
960 : /************************************************************************/
961 :
962 : template <class T>
963 57 : int PNNKDTree<T>::rebalance(const T &ctxt,
964 : std::vector<BucketItem<T>> &newLeaves,
965 : std::deque<std::unique_ptr<PNNKDTree>> &queueNodes)
966 : {
967 57 : if (m_left && m_right)
968 : {
969 : #ifdef KDTREE_DEBUG_TIMING
970 : struct timeval tv1, tv2;
971 : gettimeofday(&tv1, nullptr);
972 : #endif
973 : std::map<Vector<T>,
974 : std::pair<int, std::vector<typename BucketItem<T>::IdxType>>>
975 112 : mapVectors;
976 56 : int totalCount = 0;
977 : // Rebuild a new map of vector values -> (count, indices)
978 : // This needs to be a map as we cannot guarantee the uniqueness
979 : // of vector values after the clustering pass
980 56 : iterateOverLeaves(
981 55737 : [&mapVectors, &totalCount](PNNKDTree &bucket)
982 : {
983 230664 : for (auto &item : bucket.m_bucketItems)
984 : {
985 174927 : totalCount += item.m_count;
986 174927 : auto oIter = mapVectors.find(item.m_vec);
987 174927 : if (oIter == mapVectors.end())
988 : {
989 174872 : mapVectors[item.m_vec] = std::make_pair(
990 174872 : item.m_count, std::move(item.m_origVectorIndices));
991 : }
992 : else
993 : {
994 55 : oIter->second.first += item.m_count;
995 110 : oIter->second.second.insert(
996 55 : oIter->second.second.end(),
997 : std::make_move_iterator(
998 : item.m_origVectorIndices.begin()),
999 : std::make_move_iterator(
1000 : item.m_origVectorIndices.end()));
1001 : }
1002 : }
1003 : });
1004 :
1005 56 : freeAndMoveToQueue(queueNodes);
1006 :
1007 : // Convert the map to an array
1008 56 : newLeaves.clear();
1009 174928 : for (auto &[key, value] : mapVectors)
1010 : {
1011 174872 : newLeaves.emplace_back(std::move(key), value.first,
1012 174872 : std::move(value.second));
1013 : }
1014 :
1015 112 : std::vector<std::pair<ValType, int>> weightedVals;
1016 112 : std::vector<BucketItem<T>> vectLeft;
1017 56 : std::vector<BucketItem<T>> vectRight;
1018 56 : const int ret = insert(std::move(newLeaves), totalCount, weightedVals,
1019 : queueNodes, vectLeft, vectRight, ctxt);
1020 : #ifdef KDTREE_DEBUG_TIMING
1021 : gettimeofday(&tv2, nullptr);
1022 : totalTimeRebalancing += (tv2.tv_sec + tv2.tv_usec * 1e-6) -
1023 : (tv1.tv_sec + tv1.tv_usec * 1e-6);
1024 : #endif
1025 56 : newLeaves = std::vector<BucketItem<T>>();
1026 56 : return ret;
1027 : }
1028 : else
1029 : {
1030 1 : return static_cast<int>(m_bucketItems.size());
1031 : }
1032 : }
1033 :
1034 : #endif // KDTREE_INCLUDED
|