Line data Source code
1 : /******************************************************************************
2 : *
3 : * Project: NITF Read/Write Library
4 : * Purpose: Specialization of KDTree for CADRG VQ compression
5 : * Author: Even Rouault, even dot rouault at spatialys dot com
6 : *
7 : **********************************************************************
8 : * Copyright (c) 2026, T-Kartor
9 : *
10 : * SPDX-License-Identifier: MIT
11 : ****************************************************************************/
12 :
13 : #ifndef KDTREE_VQ_CADRG_INCLUDED
14 : #define KDTREE_VQ_CADRG_INCLUDED
15 :
16 : #include "kdtree.h"
17 :
18 : #include <array>
19 : #include <limits>
20 :
21 : #ifdef __GNUC__
22 : #pragma GCC diagnostic push
23 : #pragma GCC diagnostic ignored "-Wold-style-cast"
24 : #pragma GCC diagnostic ignored "-Weffc++"
25 : #endif
26 : #include "../../third_party/libdivide/libdivide.h"
27 : #ifdef __GNUC__
28 : #pragma GCC diagnostic pop
29 : #endif
30 :
31 : #ifdef KDTREE_USE_SSE2
32 :
33 : #include <emmintrin.h>
34 : #ifdef __SSE4_1__
35 : #include <smmintrin.h>
36 : #endif
37 : #ifdef __AVX2__
38 : #include <immintrin.h>
39 : #endif
40 :
41 : namespace
42 : {
/** Select between two 128-bit vectors under control of a mask.
 *
 * With SSE4.1 this maps to _mm_blendv_epi8(). Otherwise it is emulated with
 * and/andnot/or, which picks bits of b where mask bits are set and bits of a
 * elsewhere. Both forms agree when every mask byte is 0x00 or 0xFF, which is
 * what the callers in this file provide.
 */
inline __m128i blendv_epi8(__m128i a, __m128i b, __m128i mask)
{
#ifdef __SSE4_1__
    return _mm_blendv_epi8(a, b, mask);
#else
    const __m128i keptFromA = _mm_andnot_si128(mask, a);
    const __m128i takenFromB = _mm_and_si128(mask, b);
    return _mm_or_si128(keptFromA, takenFromB);
#endif
}
51 : } // namespace
52 : #endif
53 :
54 : namespace
55 : {
/** Return x multiplied by itself, as a value of type T. */
template <class T> T square(T x)
{
    const T product = x * x;
    return product;
}
60 :
61 : /************************************************************************/
62 : /* filled_array() */
63 : /************************************************************************/
64 :
/** Build a std::array<T, N> whose N elements are all copies of value.
 *
 * The local array is value-initialized before being filled. A
 * default-initialized local would prevent this function from being evaluated
 * in a constant expression before C++20, and is flagged by static analyzers
 * as an uninitialized variable. Every element is overwritten by the loop, so
 * the initial contents never reach the caller.
 *
 * @param value Value copied into every element.
 * @return Array of N copies of value.
 */
template <typename T, std::size_t N>
constexpr std::array<T, N> filled_array(const T &value)
{
    std::array<T, N> a{};
    for (auto &x : a)
        x = value;
    return a;
}
77 : } // namespace
78 :
79 : /************************************************************************/
80 : /* ColorTableBased4x4Pixels */
81 : /************************************************************************/
82 :
83 : struct ColorTableBased4x4Pixels
84 : {
85 : static constexpr int COMP_COUNT = 3;
86 35 : explicit ColorTableBased4x4Pixels(const std::vector<GByte> &R,
87 : const std::vector<GByte> &G,
88 : const std::vector<GByte> &B)
89 35 : : m_R(R), m_G(G), m_B(B), m_RGB({&m_R, &m_G, &m_B})
90 : #if defined(KDTREE_USE_SSE2) && defined(__AVX2__)
91 : ,
92 : m_RGB32({&m_R32, &m_G32, &m_B32})
93 : #endif
94 : {
95 : #if defined(KDTREE_USE_SSE2) && defined(__AVX2__)
96 : for (size_t i = 0; i < R.size(); ++i)
97 : {
98 : m_R32.push_back(R[i]);
99 : m_G32.push_back(G[i]);
100 : m_B32.push_back(B[i]);
101 : }
102 : #endif
103 35 : }
104 :
105 : const std::vector<GByte> &m_R;
106 : const std::vector<GByte> &m_G;
107 : const std::vector<GByte> &m_B;
108 : std::array<const std::vector<GByte> *const, COMP_COUNT> m_RGB;
109 : #if defined(KDTREE_USE_SSE2) && defined(__AVX2__)
110 : std::vector<int32_t> m_R32{}, m_G32{}, m_B32{};
111 : std::array<const std::vector<int32_t> *const, COMP_COUNT> m_RGB32;
112 : #endif
113 : };
114 :
115 : /************************************************************************/
116 : /* Vector<ColorTableBased4x4Pixels> */
117 : /************************************************************************/
118 :
119 : template <> class Vector<ColorTableBased4x4Pixels>
120 : {
121 : public:
122 : static constexpr int PIX_COUNT = 4 * 4;
123 :
124 : private:
125 : static constexpr int COMP_COUNT = ColorTableBased4x4Pixels::COMP_COUNT;
126 :
127 : std::array<GByte, PIX_COUNT> m_vals;
128 :
129 : // cppcheck-suppress uninitMemberVarPrivate
130 : Vector() = default;
131 :
132 : public:
133 1409109 : explicit Vector(const std::array<GByte, PIX_COUNT> &vals) : m_vals(vals)
134 : {
135 1409109 : }
136 :
137 : static constexpr int DIM_COUNT /* specialize */ = COMP_COUNT * PIX_COUNT;
138 :
139 2097152 : inline GByte val(int i) const
140 : {
141 2097152 : return m_vals[i];
142 : }
143 :
144 1 : inline GByte *vals()
145 : {
146 1 : return m_vals.data();
147 : }
148 :
149 135140 : inline const std::array<GByte, PIX_COUNT> &vals() const
150 : {
151 135140 : return m_vals;
152 : }
153 :
154 : static constexpr bool getReturnUInt8 /* specialize */ = true;
155 :
156 224675756 : inline int get(int i,
157 : const ColorTableBased4x4Pixels &ctxt) const /* specialize */
158 : {
159 224675756 : return (*ctxt.m_RGB[i / PIX_COUNT])[m_vals[i % PIX_COUNT]];
160 : }
161 :
162 : #if defined(KDTREE_USE_SSE2)
163 : static constexpr bool hasComputeFourSquaredDistances /* specialize */ =
164 : true;
165 :
166 : #if defined(__SSE4_1__) && defined(__GNUC__)
167 : static constexpr bool hasComputeHeightSumAndSumSquareSSE2 /* specialize */ =
168 : true;
169 :
170 : /************************************************************************/
171 : /* computeHeightSumAndSumSquareSSE2() */
172 : /************************************************************************/
173 :
174 : inline void computeHeightSumAndSumSquareSSE2(
175 : int k, const ColorTableBased4x4Pixels &ctxt, int count, __m128i &sum0,
176 : __m128i &sumSquare0_lo, __m128i &sumSquare0_hi, __m128i &sum1,
177 : __m128i &sumSquare1_lo, __m128i &sumSquare1_hi) const
178 : {
179 : #if defined(__AVX2__)
180 : const int32_t *comp_data = ctxt.m_RGB32[k / PIX_COUNT]->data();
181 : const GByte *pindices = m_vals.data() + (k % PIX_COUNT);
182 : const auto idx = _mm256_cvtepu8_epi32(
183 : _mm_loadl_epi64(reinterpret_cast<const __m128i *>(pindices)));
184 : constexpr int SCALE = static_cast<int>(sizeof(*comp_data));
185 : const auto vals = _mm256_i32gather_epi32(comp_data, idx, SCALE);
186 : const auto vcount = _mm256_set1_epi32(count);
187 : const auto vals_mul_count = _mm256_mullo_epi32(vals, vcount);
188 : sum0 = _mm_add_epi32(sum0, _mm256_castsi256_si128(vals_mul_count));
189 : sum1 = _mm_add_epi32(sum1, _mm256_extracti128_si256(vals_mul_count, 1));
190 : const auto vals_sq_mul_count = _mm256_mullo_epi32(vals, vals_mul_count);
191 : const auto vals0_sq_mul_count =
192 : _mm256_castsi256_si128(vals_sq_mul_count);
193 : const auto vals1_sq_mul_count =
194 : _mm256_extracti128_si256(vals_sq_mul_count, 1);
195 : #else
196 : const GByte *comp_data = ctxt.m_RGB[k / PIX_COUNT]->data();
197 : const GByte *pindices = m_vals.data() + (k % PIX_COUNT);
198 : const auto i32_from_epu8_gather_epu8 =
199 : [](const GByte *base_addr, const GByte *pindices)
200 : {
201 : return _mm_setr_epi32(
202 : base_addr[pindices[0]], base_addr[pindices[1]],
203 : base_addr[pindices[2]], base_addr[pindices[3]]);
204 : };
205 : const auto vals0 = i32_from_epu8_gather_epu8(comp_data, pindices + 0);
206 : const auto vals1 = i32_from_epu8_gather_epu8(comp_data, pindices + 4);
207 : const auto vcount = _mm_set1_epi32(count);
208 : const auto vals0_mul_count = _mm_mullo_epi32(vals0, vcount);
209 : const auto vals1_mul_count = _mm_mullo_epi32(vals1, vcount);
210 : sum0 = _mm_add_epi32(sum0, vals0_mul_count);
211 : sum1 = _mm_add_epi32(sum1, vals1_mul_count);
212 : const auto vals0_sq_mul_count = _mm_mullo_epi32(vals0, vals0_mul_count);
213 : const auto vals1_sq_mul_count = _mm_mullo_epi32(vals1, vals1_mul_count);
214 : #endif
215 : sumSquare0_lo = _mm_add_epi64(
216 : sumSquare0_lo,
217 : _mm_unpacklo_epi32(vals0_sq_mul_count, _mm_setzero_si128()));
218 : sumSquare0_hi = _mm_add_epi64(
219 : sumSquare0_hi,
220 : _mm_unpackhi_epi32(vals0_sq_mul_count, _mm_setzero_si128()));
221 : sumSquare1_lo = _mm_add_epi64(
222 : sumSquare1_lo,
223 : _mm_unpacklo_epi32(vals1_sq_mul_count, _mm_setzero_si128()));
224 : sumSquare1_hi = _mm_add_epi64(
225 : sumSquare1_hi,
226 : _mm_unpackhi_epi32(vals1_sq_mul_count, _mm_setzero_si128()));
227 : }
228 :
229 : #else
230 : static constexpr bool hasComputeHeightSumAndSumSquareSSE2 /* specialize */ =
231 : false;
232 : #endif
233 :
234 : private:
235 : /************************************************************************/
236 : /* gatherRGB_epi16() */
237 : /************************************************************************/
238 :
239 3766864 : static inline void gatherRGB_epi16(const GByte *indices,
240 : const ColorTableBased4x4Pixels &ctxt,
241 : __m128i &r, __m128i &g, __m128i &b)
242 : {
243 3766864 : const uint8_t i0 = indices[0];
244 3766864 : const uint8_t i1 = indices[1];
245 3766864 : const uint8_t i2 = indices[2];
246 3766864 : const uint8_t i3 = indices[3];
247 3766864 : const uint8_t i4 = indices[4];
248 3766864 : const uint8_t i5 = indices[5];
249 3766864 : const uint8_t i6 = indices[6];
250 3766864 : const uint8_t i7 = indices[7];
251 :
252 3766864 : r = _mm_setr_epi16(ctxt.m_R[i0], ctxt.m_R[i1], ctxt.m_R[i2],
253 3766864 : ctxt.m_R[i3], ctxt.m_R[i4], ctxt.m_R[i5],
254 3766864 : ctxt.m_R[i6], ctxt.m_R[i7]);
255 :
256 3766864 : g = _mm_setr_epi16(ctxt.m_G[i0], ctxt.m_G[i1], ctxt.m_G[i2],
257 3766864 : ctxt.m_G[i3], ctxt.m_G[i4], ctxt.m_G[i5],
258 3766864 : ctxt.m_G[i6], ctxt.m_G[i7]);
259 :
260 3766864 : b = _mm_setr_epi16(ctxt.m_B[i0], ctxt.m_B[i1], ctxt.m_B[i2],
261 3766864 : ctxt.m_B[i3], ctxt.m_B[i4], ctxt.m_B[i5],
262 3766864 : ctxt.m_B[i6], ctxt.m_B[i7]);
263 3766864 : }
264 :
265 : /************************************************************************/
266 : /* updateSums() */
267 : /************************************************************************/
268 :
269 988328 : static inline void updateSums(const Vector *other, int i,
270 : const ColorTableBased4x4Pixels &ctxt,
271 : __m128i rA, __m128i gA, __m128i bA,
272 : __m128i &acc)
273 : {
274 : __m128i rB, gB, bB;
275 988328 : gatherRGB_epi16(other->m_vals.data() + i, ctxt, rB, gB, bB);
276 :
277 : // Compute signed differences
278 988328 : const auto diffR = _mm_sub_epi16(rA, rB);
279 988328 : const auto diffG = _mm_sub_epi16(gA, gB);
280 1976652 : const auto diffB = _mm_sub_epi16(bA, bB);
281 :
282 : // Square differences
283 988328 : const auto sqR = _mm_mullo_epi16(diffR, diffR);
284 988328 : const auto sqG = _mm_mullo_epi16(diffG, diffG);
285 988328 : const auto sqB = _mm_mullo_epi16(diffB, diffB);
286 :
287 : // Extend to 32 bit before summing R,G,B
288 1976652 : const auto sqR_lo = _mm_unpacklo_epi16(sqR, _mm_setzero_si128());
289 1976652 : const auto sqR_hi = _mm_unpackhi_epi16(sqR, _mm_setzero_si128());
290 1976652 : const auto sqG_lo = _mm_unpacklo_epi16(sqG, _mm_setzero_si128());
291 1976652 : const auto sqG_hi = _mm_unpackhi_epi16(sqG, _mm_setzero_si128());
292 1976652 : const auto sqB_lo = _mm_unpacklo_epi16(sqB, _mm_setzero_si128());
293 988328 : const auto sqB_hi = _mm_unpackhi_epi16(sqB, _mm_setzero_si128());
294 :
295 : // Sum RGB
296 988328 : acc = _mm_add_epi32(acc, sqR_lo);
297 988328 : acc = _mm_add_epi32(acc, sqR_hi);
298 988328 : acc = _mm_add_epi32(acc, sqG_lo);
299 988328 : acc = _mm_add_epi32(acc, sqG_hi);
300 988328 : acc = _mm_add_epi32(acc, sqB_lo);
301 988328 : acc = _mm_add_epi32(acc, sqB_hi);
302 988328 : }
303 :
304 : public:
305 : /************************************************************************/
306 : /* compute_four_squared_distances() */
307 : /************************************************************************/
308 :
309 123541 : void compute_four_squared_distances(
310 : const std::array<const Vector *const, 4> &others,
311 : std::array<int, 4> & /* out */ tabSquaredDist,
312 : const ColorTableBased4x4Pixels &ctxt) const
313 : {
314 123541 : __m128i acc_0 = _mm_setzero_si128();
315 123541 : __m128i acc_1 = _mm_setzero_si128();
316 123541 : __m128i acc_2 = _mm_setzero_si128();
317 123541 : __m128i acc_3 = _mm_setzero_si128();
318 :
319 370623 : for (int i = 0; i < 16; i += 8)
320 : {
321 : __m128i rA, gA, bA;
322 247082 : gatherRGB_epi16(m_vals.data() + i, ctxt, rA, gA, bA);
323 :
324 247082 : updateSums(others[0], i, ctxt, rA, gA, bA, acc_0);
325 247082 : updateSums(others[1], i, ctxt, rA, gA, bA, acc_1);
326 247082 : updateSums(others[2], i, ctxt, rA, gA, bA, acc_2);
327 247082 : updateSums(others[3], i, ctxt, rA, gA, bA, acc_3);
328 : }
329 :
330 494164 : const auto horizontalSum = [](__m128i acc)
331 : {
332 : // Horizontal reduction 4 => 1
333 494164 : auto tmp = _mm_shuffle_epi32(acc, _MM_SHUFFLE(1, 0, 3, 2));
334 494164 : acc = _mm_add_epi32(acc, tmp);
335 494164 : tmp = _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1));
336 494164 : acc = _mm_add_epi32(acc, tmp);
337 494164 : return _mm_cvtsi128_si32(acc);
338 : };
339 :
340 123541 : tabSquaredDist[0] = horizontalSum(acc_0);
341 123541 : tabSquaredDist[1] = horizontalSum(acc_1);
342 123541 : tabSquaredDist[2] = horizontalSum(acc_2);
343 123541 : tabSquaredDist[3] = horizontalSum(acc_3);
344 123541 : }
345 :
346 : #else
347 : static constexpr bool hasComputeFourSquaredDistances /* specialize */ =
348 : false;
349 : #endif
350 :
351 : /************************************************************************/
352 : /* squared_distance() */
353 : /************************************************************************/
354 :
355 632864 : int squared_distance(
356 : const Vector &other,
357 : const ColorTableBased4x4Pixels &ctxt) const /* specialize */
358 : {
359 : #if defined(KDTREE_USE_SSE2) && !defined(__AVX2__)
360 632864 : __m128i acc0 = _mm_setzero_si128();
361 632864 : __m128i acc1 = _mm_setzero_si128();
362 :
363 1898593 : for (int i = 0; i < 2; ++i)
364 : {
365 : __m128i rA, gA, bA;
366 1265732 : gatherRGB_epi16(m_vals.data() + i * 8, ctxt, rA, gA, bA);
367 :
368 : __m128i rB, gB, bB;
369 1265732 : gatherRGB_epi16(other.m_vals.data() + i * 8, ctxt, rB, gB, bB);
370 :
371 : // Compute signed differences
372 1265732 : const auto diffR = _mm_sub_epi16(rA, rB);
373 1265732 : const auto diffG = _mm_sub_epi16(gA, gB);
374 2531454 : const auto diffB = _mm_sub_epi16(bA, bB);
375 :
376 : // Square differences
377 1265732 : const auto sqR = _mm_mullo_epi16(diffR, diffR);
378 1265732 : const auto sqG = _mm_mullo_epi16(diffG, diffG);
379 1265732 : const auto sqB = _mm_mullo_epi16(diffB, diffB);
380 :
381 : // Extend to 32 bit before summing R,G,B
382 2531454 : const auto sqR_lo = _mm_unpacklo_epi16(sqR, _mm_setzero_si128());
383 2531454 : const auto sqR_hi = _mm_unpackhi_epi16(sqR, _mm_setzero_si128());
384 2531454 : const auto sqG_lo = _mm_unpacklo_epi16(sqG, _mm_setzero_si128());
385 2531454 : const auto sqG_hi = _mm_unpackhi_epi16(sqG, _mm_setzero_si128());
386 2531454 : const auto sqB_lo = _mm_unpacklo_epi16(sqB, _mm_setzero_si128());
387 2531454 : const auto sqB_hi = _mm_unpackhi_epi16(sqB, _mm_setzero_si128());
388 :
389 : // Sum RGB
390 1265732 : acc0 = _mm_add_epi32(acc0, sqR_lo);
391 1265732 : acc1 = _mm_add_epi32(acc1, sqR_hi);
392 1265732 : acc0 = _mm_add_epi32(acc0, sqG_lo);
393 1265732 : acc1 = _mm_add_epi32(acc1, sqG_hi);
394 1265732 : acc0 = _mm_add_epi32(acc0, sqB_lo);
395 1265732 : acc1 = _mm_add_epi32(acc1, sqB_hi);
396 : }
397 :
398 : // Horizontal reduction 4 => 1
399 632864 : auto acc = _mm_add_epi32(acc0, acc1);
400 632864 : auto tmp = _mm_shuffle_epi32(acc, _MM_SHUFFLE(1, 0, 3, 2));
401 632864 : acc = _mm_add_epi32(acc, tmp);
402 632864 : tmp = _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1));
403 632864 : acc = _mm_add_epi32(acc, tmp);
404 632864 : return _mm_cvtsi128_si32(acc);
405 :
406 : #elif defined(KDTREE_USE_SSE2) && defined(__AVX2__)
407 : const auto idxA =
408 : _mm_loadu_si128(reinterpret_cast<const __m128i *>(m_vals.data()));
409 : const auto idxB = _mm_loadu_si128(
410 : reinterpret_cast<const __m128i *>(other.m_vals.data()));
411 :
412 : // Convert from 16 uint8_t values into to 2 vectors of 8 int32_t
413 : const auto idxA_lo = _mm256_cvtepu8_epi32(idxA);
414 : const auto idxB_lo = _mm256_cvtepu8_epi32(idxB);
415 : const auto idxA_hi = _mm256_cvtepu8_epi32(_mm_srli_si128(idxA, 8));
416 : const auto idxB_hi = _mm256_cvtepu8_epi32(_mm_srli_si128(idxB, 8));
417 :
418 : // Gather R, G, B for A and B (8 at a time)
419 : const auto gather_epi32 = [](const std::vector<int> &v, __m256i idx)
420 : {
421 : constexpr int SCALE = static_cast<int>(sizeof(int32_t));
422 : return _mm256_i32gather_epi32(v.data(), idx, SCALE);
423 : };
424 : const auto rA_lo = gather_epi32(ctxt.m_R32, idxA_lo);
425 : const auto rB_lo = gather_epi32(ctxt.m_R32, idxB_lo);
426 : const auto gA_lo = gather_epi32(ctxt.m_G32, idxA_lo);
427 : const auto gB_lo = gather_epi32(ctxt.m_G32, idxB_lo);
428 : const auto bA_lo = gather_epi32(ctxt.m_B32, idxA_lo);
429 : const auto bB_lo = gather_epi32(ctxt.m_B32, idxB_lo);
430 :
431 : const auto rA_hi = gather_epi32(ctxt.m_R32, idxA_hi);
432 : const auto rB_hi = gather_epi32(ctxt.m_R32, idxB_hi);
433 : const auto gA_hi = gather_epi32(ctxt.m_G32, idxA_hi);
434 : const auto gB_hi = gather_epi32(ctxt.m_G32, idxB_hi);
435 : const auto bA_hi = gather_epi32(ctxt.m_B32, idxA_hi);
436 : const auto bB_hi = gather_epi32(ctxt.m_B32, idxB_hi);
437 :
438 : // Compute square of differences
439 : const auto square_epi32 = [](__m256i x)
440 : { return _mm256_mullo_epi32(x, x); };
441 :
442 : const auto dr_lo = square_epi32(_mm256_sub_epi32(rA_lo, rB_lo));
443 : const auto dg_lo = square_epi32(_mm256_sub_epi32(gA_lo, gB_lo));
444 : const auto db_lo = square_epi32(_mm256_sub_epi32(bA_lo, bB_lo));
445 :
446 : const auto dr_hi = square_epi32(_mm256_sub_epi32(rA_hi, rB_hi));
447 : const auto dg_hi = square_epi32(_mm256_sub_epi32(gA_hi, gB_hi));
448 : const auto db_hi = square_epi32(_mm256_sub_epi32(bA_hi, bB_hi));
449 :
450 : // Sum RGB
451 : const auto sum_lo =
452 : _mm256_add_epi32(_mm256_add_epi32(dr_lo, dg_lo), db_lo);
453 : const auto sum_hi =
454 : _mm256_add_epi32(_mm256_add_epi32(dr_hi, dg_hi), db_hi);
455 :
456 : // Horizontal reduction 16 => 8
457 : const auto sum8 = _mm256_add_epi32(sum_lo, sum_hi);
458 :
459 : // Horizontal reduction 8 => 4
460 : const auto sum8_lo = _mm256_castsi256_si128(sum8);
461 : const auto sum8_hi = _mm256_extracti128_si256(sum8, 1);
462 : auto sum = _mm_add_epi32(sum8_lo, sum8_hi);
463 :
464 : // Horizontal reduction 4 => 1
465 : auto tmp = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
466 : sum = _mm_add_epi32(sum, tmp);
467 : tmp = _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1));
468 : sum = _mm_add_epi32(sum, tmp);
469 :
470 : return _mm_cvtsi128_si32(sum);
471 :
472 : #else
473 : int nSqDist1 = 0;
474 : int nSqDist2 = 0;
475 : int nSqDist3 = 0;
476 : for (int i = 0; i < PIX_COUNT; ++i)
477 : {
478 : const int aEntry = m_vals[i];
479 : const int bEntry = other.m_vals[i];
480 : nSqDist1 += square(ctxt.m_R[aEntry] - ctxt.m_R[bEntry]);
481 : nSqDist2 += square(ctxt.m_G[aEntry] - ctxt.m_G[bEntry]);
482 : nSqDist3 += square(ctxt.m_B[aEntry] - ctxt.m_B[bEntry]);
483 : }
484 : return nSqDist1 + nSqDist2 + nSqDist3;
485 : #endif
486 : }
487 :
488 : /************************************************************************/
489 : /* centroid() */
490 : /************************************************************************/
491 :
492 : static Vector
493 106111 : centroid(const Vector &a, int nA, const Vector &b, int nB,
494 : const ColorTableBased4x4Pixels &ctxt) /* specialize */
495 : {
496 106111 : auto minSqDist = filled_array<int, PIX_COUNT>(256 * 256 * COMP_COUNT);
497 : Vector res;
498 106111 : libdivide::divider<uint32_t> divisor(static_cast<uint32_t>(nA + nB));
499 1803891 : for (int k = 0; k < PIX_COUNT; ++k)
500 : {
501 1697778 : const int aEntry = a.m_vals[k];
502 1697778 : const int bEntry = b.m_vals[k];
503 : const int meanR =
504 1697778 : static_cast<uint32_t>(ctxt.m_R[aEntry] * nA +
505 1697778 : ctxt.m_R[bEntry] * nB + (nA + nB) / 2) /
506 1697778 : divisor;
507 : const int meanG =
508 1697778 : static_cast<uint32_t>(ctxt.m_G[aEntry] * nA +
509 1697778 : ctxt.m_G[bEntry] * nB + (nA + nB) / 2) /
510 1697778 : divisor;
511 : const int meanB =
512 1697778 : static_cast<uint32_t>(ctxt.m_B[aEntry] * nA +
513 1697778 : ctxt.m_B[bEntry] * nB + (nA + nB) / 2) /
514 1697778 : divisor;
515 :
516 1697778 : assert(meanR <= 255);
517 1697778 : assert(meanG <= 255);
518 1697778 : assert(meanB <= 255);
519 :
520 : #ifdef PRECISE_DISTANCE_COMPUTATION
521 : constexpr int BIT_SHIFT = 0;
522 : #else
523 : // Minimum value to avoid int16 overflow when adding 3 squares of
524 : // uint8, because 3 * ((255 * 255) >> 3) = 24384 < INT16_MAX
525 1697778 : constexpr int BIT_SHIFT = 3;
526 : #endif
527 :
528 1697778 : int i = 0;
529 : #if defined(KDTREE_USE_SSE2)
530 1697778 : const auto targetR = _mm_set1_epi16(static_cast<short>(meanR));
531 1697778 : const auto targetG = _mm_set1_epi16(static_cast<short>(meanG));
532 1697778 : const auto targetB = _mm_set1_epi16(static_cast<short>(meanB));
533 :
534 : // Initialize min distance vector with max int32 values
535 : #ifdef PRECISE_DISTANCE_COMPUTATION
536 : auto minDistVec0 = _mm_set1_epi32(std::numeric_limits<int>::max());
537 : auto minDistVec1 = _mm_set1_epi32(std::numeric_limits<int>::max());
538 : auto minDistVec2 = _mm_set1_epi32(std::numeric_limits<int>::max());
539 : auto minDistVec3 = _mm_set1_epi32(std::numeric_limits<int>::max());
540 : #else
541 : auto minDistVec0 =
542 1697778 : _mm_set1_epi16(std::numeric_limits<short>::max());
543 : auto minDistVec1 =
544 3395556 : _mm_set1_epi16(std::numeric_limits<short>::max());
545 : #endif
546 :
547 : // Initialize index vectors for tracking best index per lane
548 1697778 : auto idx = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
549 : 13, 14, 15);
550 1697778 : const auto zero = _mm_setzero_si128();
551 1697778 : auto idxMin = zero;
552 :
553 132411248 : const auto square_lo_epi16 = [](__m128i x)
554 132411248 : { return _mm_mullo_epi16(x, x); };
555 :
556 1697778 : constexpr int VALS_AT_ONCE = 16;
557 23766416 : for (; i + VALS_AT_ONCE <= static_cast<int>(ctxt.m_R.size());
558 22068608 : i += VALS_AT_ONCE)
559 : {
560 : // Load 16 color components
561 22068608 : const auto valsR = _mm_loadu_si128(
562 22068608 : reinterpret_cast<const __m128i *>(ctxt.m_R.data() + i));
563 22068608 : const auto valsG = _mm_loadu_si128(
564 22068608 : reinterpret_cast<const __m128i *>(ctxt.m_G.data() + i));
565 22068608 : const auto valsB = _mm_loadu_si128(
566 22068608 : reinterpret_cast<const __m128i *>(ctxt.m_B.data() + i));
567 :
568 22068608 : const auto valsR_lo = _mm_unpacklo_epi8(valsR, zero);
569 22068608 : const auto valsR_hi = _mm_unpackhi_epi8(valsR, zero);
570 22068608 : const auto valsG_lo = _mm_unpacklo_epi8(valsG, zero);
571 22068608 : const auto valsG_hi = _mm_unpackhi_epi8(valsG, zero);
572 22068608 : const auto valsB_lo = _mm_unpacklo_epi8(valsB, zero);
573 22068608 : const auto valsB_hi = _mm_unpackhi_epi8(valsB, zero);
574 :
575 : // Compute signed differences
576 22068608 : const auto diffR_lo = _mm_sub_epi16(valsR_lo, targetR);
577 22068608 : const auto diffR_hi = _mm_sub_epi16(valsR_hi, targetR);
578 22068608 : const auto diffG_lo = _mm_sub_epi16(valsG_lo, targetG);
579 22068608 : const auto diffG_hi = _mm_sub_epi16(valsG_hi, targetG);
580 22068608 : const auto diffB_lo = _mm_sub_epi16(valsB_lo, targetB);
581 22068608 : const auto diffB_hi = _mm_sub_epi16(valsB_hi, targetB);
582 :
583 : // Square differences
584 22068608 : const auto sqR_lo = square_lo_epi16(diffR_lo);
585 22068608 : const auto sqR_hi = square_lo_epi16(diffR_hi);
586 22068608 : const auto sqG_lo = square_lo_epi16(diffG_lo);
587 22068608 : const auto sqG_hi = square_lo_epi16(diffG_hi);
588 22068608 : const auto sqB_lo = square_lo_epi16(diffB_lo);
589 22068608 : const auto sqB_hi = square_lo_epi16(diffB_hi);
590 :
591 : #ifdef PRECISE_DISTANCE_COMPUTATION
592 : // Convert squares from 16-bit to 32-bit integers
593 : const auto sqR0 = _mm_unpacklo_epi16(sqR_lo, zero);
594 : const auto sqR1 = _mm_unpackhi_epi16(sqR_lo, zero);
595 : const auto sqR2 = _mm_unpacklo_epi16(sqR_hi, zero);
596 : const auto sqR3 = _mm_unpackhi_epi16(sqR_hi, zero);
597 :
598 : const auto sqG0 = _mm_unpacklo_epi16(sqG_lo, zero);
599 : const auto sqG1 = _mm_unpackhi_epi16(sqG_lo, zero);
600 : const auto sqG2 = _mm_unpacklo_epi16(sqG_hi, zero);
601 : const auto sqG3 = _mm_unpackhi_epi16(sqG_hi, zero);
602 :
603 : const auto sqB0 = _mm_unpacklo_epi16(sqB_lo, zero);
604 : const auto sqB1 = _mm_unpackhi_epi16(sqB_lo, zero);
605 : const auto sqB2 = _mm_unpacklo_epi16(sqB_hi, zero);
606 : const auto sqB3 = _mm_unpackhi_epi16(sqB_hi, zero);
607 :
608 : // Sum squared differences for each 32-bit lane: (R + G + B)
609 : const auto dist0 =
610 : _mm_add_epi32(_mm_add_epi32(sqR0, sqG0), sqB0);
611 : const auto dist1 =
612 : _mm_add_epi32(_mm_add_epi32(sqR1, sqG1), sqB1);
613 : const auto dist2 =
614 : _mm_add_epi32(_mm_add_epi32(sqR2, sqG2), sqB2);
615 : const auto dist3 =
616 : _mm_add_epi32(_mm_add_epi32(sqR3, sqG3), sqB3);
617 :
618 : // Compare with current minimum distances
619 : auto mask0 = _mm_cmplt_epi32(dist0, minDistVec0);
620 : auto mask1 = _mm_cmplt_epi32(dist1, minDistVec1);
621 : auto mask2 = _mm_cmplt_epi32(dist2, minDistVec2);
622 : auto mask3 = _mm_cmplt_epi32(dist3, minDistVec3);
623 :
624 : // Update minimum distances
625 : minDistVec0 = blendv_epi8(minDistVec0, dist0, mask0);
626 : minDistVec1 = blendv_epi8(minDistVec1, dist1, mask1);
627 : minDistVec2 = blendv_epi8(minDistVec2, dist2, mask2);
628 : minDistVec3 = blendv_epi8(minDistVec3, dist3, mask3);
629 :
630 : // Merge the 4 masks of 4 x uint32_t into
631 : // a single mask 16 x 1 uint8_t mask
632 : mask0 = _mm_srli_epi32(mask0, 24);
633 : mask1 = _mm_srli_epi32(mask1, 24);
634 : mask2 = _mm_srli_epi32(mask2, 24);
635 : mask3 = _mm_srli_epi32(mask3, 24);
636 : const auto mask_merged =
637 : _mm_packus_epi16(_mm_packs_epi32(mask0, mask1),
638 : _mm_packs_epi32(mask2, mask3));
639 :
640 : // Update indices
641 : idxMin = blendv_epi8(idxMin, idx, mask_merged);
642 :
643 : #else
644 : // Sum squared differences, by removing a few LSB bits to avoid
645 : // overflows.
646 110343040 : const auto dist0 = _mm_add_epi16(
647 : _mm_add_epi16(_mm_srli_epi16(sqR_lo, BIT_SHIFT),
648 : _mm_srli_epi16(sqG_lo, BIT_SHIFT)),
649 : _mm_srli_epi16(sqB_lo, BIT_SHIFT));
650 110343040 : const auto dist1 = _mm_add_epi16(
651 : _mm_add_epi16(_mm_srli_epi16(sqR_hi, BIT_SHIFT),
652 : _mm_srli_epi16(sqG_hi, BIT_SHIFT)),
653 : _mm_srli_epi16(sqB_hi, BIT_SHIFT));
654 :
655 : // Compare with current minimum distances
656 22068608 : auto mask0 = _mm_cmplt_epi16(dist0, minDistVec0);
657 22068608 : auto mask1 = _mm_cmplt_epi16(dist1, minDistVec1);
658 :
659 : // Update minimum distances
660 22068608 : minDistVec0 = blendv_epi8(minDistVec0, dist0, mask0);
661 22068608 : minDistVec1 = blendv_epi8(minDistVec1, dist1, mask1);
662 :
663 : // Merge the 2 masks of 8 x uint16_t into
664 : // a single mask 16 x 1 uint8_t mask
665 22068608 : mask0 = _mm_srli_epi16(mask0, 8);
666 22068608 : mask1 = _mm_srli_epi16(mask1, 8);
667 22068608 : const auto mask_merged = _mm_packus_epi16(mask0, mask1);
668 :
669 : // Update indices
670 22068608 : idxMin = blendv_epi8(idxMin, idx, mask_merged);
671 : #endif
672 :
673 44137216 : idx = _mm_add_epi8(idx, _mm_set1_epi8(VALS_AT_ONCE));
674 : }
675 :
676 : // Horizontal update
677 : #ifdef PRECISE_DISTANCE_COMPUTATION
678 : int minDistVals[VALS_AT_ONCE];
679 : _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 0),
680 : minDistVec0);
681 : _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 4),
682 : minDistVec1);
683 : _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 8),
684 : minDistVec2);
685 : _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 12),
686 : minDistVec3);
687 : #else
688 : short minDistVals[VALS_AT_ONCE];
689 : _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 0),
690 : minDistVec0);
691 1697778 : _mm_storeu_si128(reinterpret_cast<__m128i *>(minDistVals + 8),
692 : minDistVec1);
693 : #endif
694 :
695 : GByte minIdxVals[VALS_AT_ONCE];
696 : _mm_storeu_si128(reinterpret_cast<__m128i *>(minIdxVals), idxMin);
697 :
698 28862236 : for (int j = 0; j < VALS_AT_ONCE; ++j)
699 : {
700 48786996 : if (minDistVals[j] < minSqDist[k] ||
701 21622568 : (minDistVals[j] == minSqDist[k] &&
702 740012 : minIdxVals[j] < res.m_vals[k]))
703 : {
704 5872790 : minSqDist[k] = minDistVals[j];
705 5872790 : res.m_vals[k] = minIdxVals[j];
706 : }
707 : }
708 : #endif
709 :
710 : // Generic/scalar code
711 15278308 : for (; i < static_cast<int>(ctxt.m_R.size()); ++i)
712 : {
713 13580500 : const int sqDist = (square(meanR - ctxt.m_R[i]) >> BIT_SHIFT) +
714 13580500 : (square(meanG - ctxt.m_G[i]) >> BIT_SHIFT) +
715 13580500 : (square(meanB - ctxt.m_B[i]) >> BIT_SHIFT);
716 13580500 : if (sqDist < minSqDist[k])
717 : {
718 60127 : minSqDist[k] = sqDist;
719 60127 : res.m_vals[k] = static_cast<GByte>(i);
720 : }
721 : }
722 : }
723 106111 : return res;
724 : }
725 :
726 : /************************************************************************/
727 : /* operator == () */
728 : /************************************************************************/
729 :
730 1053204 : inline bool operator==(const Vector &other) const
731 : {
732 1053204 : return m_vals == other.m_vals;
733 : }
734 :
735 : /************************************************************************/
736 : /* operator < () */
737 : /************************************************************************/
738 :
739 : // Purely arbitrary for the purpose of distinguishing a vector from
740 : // another one
741 56072279 : inline bool operator<(const Vector &other) const
742 : {
743 56072279 : return m_vals < other.m_vals;
744 : }
745 : };
746 :
747 : #endif // KDTREE_VQ_CADRG_INCLUDED
|