Line data Source code
1 : /******************************************************************************
2 : *
3 : * Project: GDAL
4 : * Purpose: gdal "mdim compare" subcommand
5 : * Author: Even Rouault <even dot rouault at spatialys.com>
6 : *
7 : ******************************************************************************
8 : * Copyright (c) 2026, Even Rouault <even dot rouault at spatialys.com>
9 : *
10 : * SPDX-License-Identifier: MIT
11 : ****************************************************************************/
12 :
13 : #include "gdalalg_mdim_compare.h"
14 :
15 : #include "gdal_dataset.h"
16 : #include "gdal_multidim.h"
17 :
18 : #include <algorithm>
19 : #include <cinttypes>
20 : #include <cmath>
21 : #include <iterator>
22 : #include <set>
23 : #include <utility>
24 :
25 : //! @cond Doxygen_Suppress
26 :
27 : #ifndef _
28 : #define _(x) (x)
29 : #endif
30 :
31 : /************************************************************************/
32 : /* GDALMdimCompareAlgorithm::GDALMdimCompareAlgorithm() */
33 : /************************************************************************/
34 :
35 53 : GDALMdimCompareAlgorithm::GDALMdimCompareAlgorithm(bool standaloneStep)
36 : : GDALMdimPipelineStepAlgorithm(NAME, DESCRIPTION, HELP_URL,
37 0 : ConstructorOptions()
38 53 : .SetStandaloneStep(standaloneStep)
39 53 : .SetInputDatasetMaxCount(1)
40 106 : .SetAddDefaultArguments(false))
41 : {
42 53 : if (standaloneStep)
43 : {
44 51 : AddProgressArg();
45 : }
46 : else
47 : {
48 2 : AddMdimHiddenInputDatasetArg();
49 : }
50 :
51 : auto &referenceDatasetArg =
52 : AddArg("reference", 0, _("Reference dataset"), &m_referenceDataset,
53 106 : GDAL_OF_MULTIDIM_RASTER)
54 53 : .SetPositional()
55 53 : .SetRequired();
56 :
57 53 : SetAutoCompleteFunctionForFilename(referenceDatasetArg,
58 : GDAL_OF_MULTIDIM_RASTER);
59 :
60 53 : if (standaloneStep)
61 : {
62 51 : AddMdimInputArgs(/* openForMixedMdimVector = */ false,
63 : /* hiddenForCLI = */ false,
64 : /* acceptRaster = */ false);
65 : }
66 :
67 106 : AddArg("metric", 0, _("Comparison metric(s)"), &m_metrics)
68 : .SetChoices(METRIC_ALL, METRIC_NONE, METRIC_DIFF, METRIC_RMSD,
69 53 : METRIC_PSNR)
70 53 : .SetDefault(METRIC_DEFAULT);
71 :
72 53 : AddArrayNameArg(&m_array, _("Name of array(s) to compare"));
73 :
74 53 : AddOutputStringArg(&m_output);
75 :
76 53 : AddArg("skip-binary", 0, _("Skip binary file comparison"), &m_skipBinary);
77 :
78 106 : AddArg("return-code", 0, _("Return code"), &m_retCode)
79 53 : .SetHiddenForCLI()
80 53 : .SetIsInput(false)
81 53 : .SetIsOutput(true);
82 53 : }
83 :
84 : /************************************************************************/
85 : /* GetPixelCount() */
86 : /************************************************************************/
87 :
88 26 : static uint64_t GetPixelCount(const std::shared_ptr<GDALMDArray> &array)
89 : {
90 26 : uint64_t nPixels = 1;
91 62 : for (const auto &poDim : array->GetDimensions())
92 : {
93 36 : nPixels *= poDim->GetSize();
94 : }
95 26 : return nPixels;
96 : }
97 :
98 : /************************************************************************/
99 : /* GDALMdimCompareAlgorithm::RunStep() */
100 : /************************************************************************/
101 :
102 10 : bool GDALMdimCompareAlgorithm::RunStep(GDALPipelineStepRunContext &ctxt)
103 : {
104 10 : auto poRefDS = m_referenceDataset.GetDatasetRef();
105 10 : CPLAssert(poRefDS);
106 :
107 10 : CPLAssert(m_inputDataset.size() == 1);
108 10 : auto poInputDS = m_inputDataset[0].GetDatasetRef();
109 10 : CPLAssert(poInputDS);
110 :
111 20 : std::vector<std::string> aosReport;
112 :
113 10 : if (!m_skipBinary)
114 : {
115 1 : if (BinaryComparison(this, aosReport, poRefDS, poInputDS))
116 : {
117 1 : return true;
118 : }
119 : }
120 :
121 18 : auto poRefRootGroup = poRefDS->GetRootGroup();
122 9 : if (!poRefRootGroup)
123 : {
124 0 : ReportError(CE_Failure, CPLE_AppDefined,
125 : "Cannot get root group on reference dataset");
126 0 : return false;
127 : }
128 :
129 18 : auto poInputRootGroup = poInputDS->GetRootGroup();
130 9 : if (!poInputRootGroup)
131 : {
132 0 : ReportError(CE_Failure, CPLE_AppDefined,
133 : "Cannot get root group on input dataset");
134 0 : return false;
135 : }
136 :
137 18 : auto refArrays = poRefRootGroup->GetMDArrayFullNamesRecursive();
138 18 : auto inputArrays = poInputRootGroup->GetMDArrayFullNamesRecursive();
139 9 : if (m_array.empty())
140 : {
141 6 : std::sort(refArrays.begin(), refArrays.end());
142 6 : std::sort(inputArrays.begin(), inputArrays.end());
143 :
144 : {
145 12 : std::vector<std::string> missing;
146 : std::set_difference(refArrays.begin(), refArrays.end(),
147 : inputArrays.begin(), inputArrays.end(),
148 6 : std::back_inserter(missing));
149 6 : if (!missing.empty())
150 : {
151 : std::string line =
152 : "The following arrays are found in the reference dataset, "
153 2 : "but missing in the input one: ";
154 1 : bool first = true;
155 4 : for (const auto &name : missing)
156 : {
157 3 : if (!first)
158 2 : line += ", ";
159 3 : first = false;
160 3 : line += name;
161 : }
162 1 : aosReport.push_back(std::move(line));
163 : }
164 : }
165 :
166 : {
167 12 : std::vector<std::string> missing;
168 : std::set_difference(inputArrays.begin(), inputArrays.end(),
169 : refArrays.begin(), refArrays.end(),
170 6 : std::back_inserter(missing));
171 6 : if (!missing.empty())
172 : {
173 : std::string line =
174 : "The following arrays are found in the input dataset, but "
175 2 : "missing in the reference one: ";
176 1 : bool first = true;
177 4 : for (const auto &name : missing)
178 : {
179 3 : if (!first)
180 2 : line += ", ";
181 3 : first = false;
182 3 : line += name;
183 : }
184 1 : aosReport.push_back(std::move(line));
185 : }
186 : }
187 :
188 : std::set_intersection(refArrays.begin(), refArrays.end(),
189 : inputArrays.begin(), inputArrays.end(),
190 6 : std::back_inserter(m_array));
191 : }
192 : else
193 : {
194 3 : std::set<std::string> newArrays;
195 4 : for (const auto &name : m_array)
196 : {
197 : const auto ExistsIn =
198 5 : [&name,
199 : &newArrays](const std::vector<std::string> &aosExitingArrays,
200 49 : bool insert)
201 : {
202 5 : bool ret = false;
203 17 : for (const auto &exitingArray : aosExitingArrays)
204 : {
205 24 : if (!name.empty() &&
206 12 : ((name[0] == '/' && name == exitingArray) ||
207 17 : (name[0] != '/' &&
208 6 : name == CPLGetFilename(exitingArray.c_str()))))
209 : {
210 3 : if (insert)
211 2 : newArrays.insert(exitingArray);
212 3 : ret = true;
213 : }
214 : }
215 5 : return ret;
216 3 : };
217 :
218 3 : if (!ExistsIn(refArrays, true))
219 : {
220 1 : ReportError(CE_Failure, CPLE_AppDefined,
221 : "Array '%s' does not exist in reference dataset",
222 : name.c_str());
223 2 : return false;
224 : }
225 2 : if (!ExistsIn(inputArrays, false))
226 : {
227 1 : ReportError(CE_Failure, CPLE_AppDefined,
228 : "Array '%s' does not exist in input dataset",
229 : name.c_str());
230 1 : return false;
231 : }
232 : }
233 1 : m_array.clear();
234 1 : m_array.insert(m_array.end(), newArrays.begin(), newArrays.end());
235 : }
236 :
237 : std::vector<
238 : std::pair<std::shared_ptr<GDALMDArray>, std::shared_ptr<GDALMDArray>>>
239 14 : arrayPairs;
240 7 : uint64_t nTotalPixels = 0;
241 20 : for (const auto &array : m_array)
242 : {
243 13 : auto poRefArray = poRefRootGroup->OpenMDArrayFromFullname(array);
244 13 : if (!poRefArray)
245 : {
246 0 : ReportError(CE_Failure, CPLE_AppDefined,
247 : "Array '%s' cannot be opened in reference dataset",
248 : array.c_str());
249 0 : return false;
250 : }
251 :
252 13 : auto poInputArray = poInputRootGroup->OpenMDArrayFromFullname(array);
253 13 : if (!poInputArray)
254 : {
255 0 : ReportError(CE_Failure, CPLE_AppDefined,
256 : "Array '%s' cannot be opened in input dataset",
257 : array.c_str());
258 0 : return false;
259 : }
260 :
261 13 : nTotalPixels += GetPixelCount(poRefArray);
262 13 : arrayPairs.emplace_back(std::move(poRefArray), std::move(poInputArray));
263 : }
264 :
265 7 : uint64_t nCurPixels = 0;
266 20 : for (const auto &[poRefArray, poInputArray] : arrayPairs)
267 : {
268 13 : const uint64_t nThisArrayPixels = GetPixelCount(poRefArray);
269 13 : const double dfMinPct =
270 13 : static_cast<double>(nCurPixels) / static_cast<double>(nTotalPixels);
271 13 : const double dfMaxPct =
272 13 : static_cast<double>(nCurPixels + nThisArrayPixels) /
273 13 : static_cast<double>(nTotalPixels);
274 : std::unique_ptr<void, decltype(&GDALDestroyScaledProgress)>
275 : pScaledProgress(GDALCreateScaledProgress(dfMinPct, dfMaxPct,
276 : ctxt.m_pfnProgress,
277 : ctxt.m_pProgressData),
278 13 : GDALDestroyScaledProgress);
279 26 : CompareArray(aosReport, poRefArray, poInputArray,
280 13 : pScaledProgress ? GDALScaledProgress : nullptr,
281 : pScaledProgress.get());
282 13 : nCurPixels += nThisArrayPixels;
283 : }
284 :
285 7 : if (ctxt.m_pfnProgress)
286 1 : ctxt.m_pfnProgress(1.0, "", ctxt.m_pProgressData);
287 :
288 19 : for (const auto &s : aosReport)
289 : {
290 12 : m_output += s;
291 12 : m_output += '\n';
292 : }
293 :
294 7 : m_retCode = static_cast<int>(aosReport.size());
295 :
296 7 : return true;
297 : }
298 :
299 : /************************************************************************/
300 : /* NonZeroValueIterator */
301 : /************************************************************************/
302 :
303 : namespace
304 : {
305 : struct NonZeroValueIterator
306 : {
307 : std::vector<double> adfValues{};
308 : uint64_t nNonZero = 0;
309 :
310 3 : static bool func(GDALAbstractMDArray *array,
311 : const GUInt64 *chunkArrayStartIdx,
312 : const size_t *chunkCount, GUInt64 /* iCurChunk */,
313 : GUInt64 /* nChunkCount */, void *pUserData)
314 : {
315 3 : auto self = static_cast<NonZeroValueIterator *>(pUserData);
316 3 : const size_t nDims = array->GetDimensionCount();
317 3 : size_t nElts = 1;
318 9 : for (size_t i = 0; i < nDims; ++i)
319 6 : nElts *= chunkCount[i];
320 3 : if (self->adfValues.size() < nElts)
321 : {
322 : try
323 : {
324 3 : self->adfValues.resize(nElts);
325 : }
326 0 : catch (const std::exception &)
327 : {
328 0 : CPLError(CE_Failure, CPLE_OutOfMemory, "Out of memory");
329 0 : return false;
330 : }
331 : }
332 3 : if (!array->Read(chunkArrayStartIdx, chunkCount,
333 : /* step = */ nullptr, /* stride = */ nullptr,
334 6 : GDALExtendedDataType::Create(GDT_Float64),
335 3 : self->adfValues.data()))
336 : {
337 0 : return false;
338 : }
339 1203 : for (size_t i = 0; i < nElts; ++i)
340 : {
341 1200 : if (self->adfValues[i] != 0)
342 1197 : self->nNonZero++;
343 : }
344 :
345 3 : return true;
346 : }
347 : };
348 : } // namespace
349 :
350 : /************************************************************************/
351 : /* GetDataTypeName() */
352 : /************************************************************************/
353 :
354 2 : static std::string GetDataTypeName(const GDALExtendedDataType &dt)
355 : {
356 2 : switch (dt.GetClass())
357 : {
358 2 : case GEDTC_NUMERIC:
359 2 : return GDALGetDataTypeName(dt.GetNumericDataType());
360 0 : case GEDTC_STRING:
361 0 : return "String";
362 0 : case GEDTC_COMPOUND:
363 0 : break;
364 : }
365 0 : return "Compound";
366 : }
367 :
368 : /************************************************************************/
369 : /* GDALMdimCompareAlgorithm::CompareArray() */
370 : /************************************************************************/
371 :
372 13 : void GDALMdimCompareAlgorithm::CompareArray(
373 : std::vector<std::string> &aosReport,
374 : const std::shared_ptr<GDALMDArray> &poRefArray,
375 : const std::shared_ptr<GDALMDArray> &poInputArray,
376 : GDALProgressFunc pfnProgress, void *pProgressData)
377 : {
378 13 : const auto &osName = poRefArray->GetFullName();
379 13 : const auto nDims = poRefArray->GetDimensionCount();
380 13 : if (nDims != poInputArray->GetDimensionCount())
381 : {
382 0 : aosReport.push_back(
383 : CPLSPrintf("Array %s: dimension count in reference is %d, whereas "
384 : "it is %d in input",
385 : osName.c_str(), static_cast<int>(nDims),
386 0 : static_cast<int>(poInputArray->GetDimensionCount())));
387 0 : return;
388 : }
389 :
390 13 : if (!poRefArray->HasSameShapeAs(*poInputArray))
391 : {
392 0 : const auto ShapeToString = [](const std::shared_ptr<GDALMDArray> &array)
393 : {
394 0 : std::string s("[");
395 0 : for (const auto &poDim : array->GetDimensions())
396 : {
397 0 : if (s.size() > 1)
398 0 : s += ',';
399 0 : s += std::to_string(poDim->GetSize());
400 : }
401 0 : s += ']';
402 0 : return s;
403 : };
404 :
405 0 : aosReport.push_back(CPLSPrintf(
406 : "Array %s: shape in reference is %s, whereas it is %s in input",
407 0 : osName.c_str(), ShapeToString(poRefArray).c_str(),
408 0 : ShapeToString(poInputArray).c_str()));
409 0 : return;
410 : }
411 :
412 13 : if (poRefArray->GetDataType() != poInputArray->GetDataType())
413 : {
414 4 : aosReport.push_back(CPLSPrintf(
415 : "Array %s: data type in reference is %s, whereas it is %s in input",
416 2 : osName.c_str(), GetDataTypeName(poRefArray->GetDataType()).c_str(),
417 2 : GetDataTypeName(poInputArray->GetDataType()).c_str()));
418 : }
419 26 : if (poRefArray->GetDataType().GetClass() != GEDTC_NUMERIC ||
420 13 : poInputArray->GetDataType().GetClass() != GEDTC_NUMERIC)
421 : {
422 0 : CPLError(CE_Warning, CPLE_AppDefined,
423 : "Array %s: not compared, as comparison of non-numeric data "
424 : "types is not currently supported",
425 : osName.c_str());
426 0 : return;
427 : }
428 :
429 13 : auto diffArray = (*poRefArray) - poInputArray;
430 13 : if (!diffArray)
431 : {
432 : // Given above checks, this shouldn't happen.
433 0 : aosReport.push_back(CPLSPrintf(
434 : "Array %s: cannot compute array of differences", osName.c_str()));
435 0 : return;
436 : }
437 :
438 13 : int nCountMetrics = 0;
439 13 : if (HasMetric(METRIC_DIFF))
440 7 : ++nCountMetrics;
441 13 : if (HasMetric(METRIC_RMSD) || HasMetric(METRIC_PSNR))
442 6 : ++nCountMetrics;
443 :
444 13 : double dfLastPct = 0;
445 13 : if (HasMetric(METRIC_DIFF))
446 : {
447 7 : double dfMin = 0, dfMax = 0;
448 :
449 7 : const double dfNewLastPct = 1.0 / nCountMetrics;
450 : std::unique_ptr<void, decltype(&GDALDestroyScaledProgress)>
451 : pScaledProgress(GDALCreateScaledProgress(
452 : 0.0, dfNewLastPct, pfnProgress, pProgressData),
453 7 : GDALDestroyScaledProgress);
454 7 : dfLastPct = dfNewLastPct;
455 :
456 14 : if (!diffArray->ComputeStatistics(
457 : /* bApproxOK =*/false, &dfMin, &dfMax, nullptr, nullptr,
458 7 : nullptr, pScaledProgress ? GDALScaledProgress : nullptr,
459 7 : pScaledProgress.get(), nullptr))
460 : {
461 0 : aosReport.push_back(CPLSPrintf(
462 : "Array %s: cannot compute statistics", osName.c_str()));
463 0 : return;
464 : }
465 7 : if (dfMin != 0 || dfMax != 0)
466 : {
467 : const double dfMaxDiff =
468 3 : std::max(std::fabs(dfMin), std::fabs(dfMax));
469 3 : aosReport.push_back(
470 : CPLSPrintf("Array %s: maximum pixel value difference: %g",
471 : osName.c_str(), dfMaxDiff));
472 :
473 3 : NonZeroValueIterator it;
474 3 : std::vector<GUInt64> arrayStartIdx(nDims, 0);
475 3 : std::vector<GUInt64> count(nDims, 0);
476 9 : for (size_t i = 0; i < nDims; ++i)
477 6 : count[i] = diffArray->GetDimensions()[i]->GetSize();
478 :
479 3 : size_t nMaxSize = 100 * 1024 * 1024;
480 3 : const GIntBig nUsableRAM = CPLGetUsablePhysicalRAM() / 10;
481 3 : if (nUsableRAM > 0)
482 3 : nMaxSize = static_cast<size_t>(nUsableRAM);
483 :
484 6 : if (!diffArray->ProcessPerChunk(
485 3 : arrayStartIdx.data(), count.data(),
486 6 : diffArray->GetProcessingChunkSize(nMaxSize).data(),
487 3 : NonZeroValueIterator::func, &it))
488 : {
489 0 : aosReport.push_back(
490 : CPLSPrintf("Array %s: diffArray->ProcessPerChunk() failed",
491 : osName.c_str()));
492 0 : return;
493 : }
494 :
495 3 : aosReport.push_back(
496 : CPLSPrintf("Array %s: pixels differing: %" PRIu64,
497 : osName.c_str(), it.nNonZero));
498 : }
499 : }
500 :
501 13 : if (HasMetric(METRIC_RMSD) || HasMetric(METRIC_PSNR))
502 : {
503 : // For PSNR on floating point image, we need to compute min and max of
504 : // reference band
505 6 : const bool bIsInteger = CPL_TO_BOOL(GDALDataTypeIsInteger(
506 6 : poRefArray->GetDataType().GetNumericDataType()));
507 : const double dfScalingProgress =
508 6 : HasMetric(METRIC_PSNR) && !bIsInteger ? 0.5 : 1;
509 : const double dfNewLastPct =
510 6 : std::min(1.0, dfLastPct + dfScalingProgress * (1.0 - dfLastPct));
511 : std::unique_ptr<void, decltype(&GDALDestroyScaledProgress)>
512 : pScaledProgress(GDALCreateScaledProgress(dfLastPct, dfNewLastPct,
513 : pfnProgress,
514 : pProgressData),
515 12 : GDALDestroyScaledProgress);
516 6 : dfLastPct = dfNewLastPct;
517 :
518 12 : auto squaredDiffArray = (*diffArray) * diffArray;
519 6 : double dfMeanSquareError = 0;
520 12 : if (squaredDiffArray->ComputeStatistics(
521 : /* bApproxOK = */ false,
522 : /* pdfMin = */ nullptr,
523 : /* pdfMax = */ nullptr, &dfMeanSquareError,
524 : /* pdfStdDev = */ nullptr, nullptr,
525 6 : pScaledProgress ? GDALScaledProgress : nullptr,
526 6 : pScaledProgress.get(), nullptr))
527 : {
528 6 : const double dfRMSD = std::sqrt(dfMeanSquareError);
529 6 : if (dfRMSD > 0)
530 : {
531 2 : if (HasMetric(METRIC_RMSD))
532 : {
533 1 : aosReport.push_back(CPLSPrintf("Array %s: RMSD: %g",
534 : osName.c_str(), dfRMSD));
535 : }
536 :
537 2 : if (HasMetric(METRIC_PSNR))
538 : {
539 2 : if (bIsInteger)
540 : {
541 : double dfMaxAmplitude =
542 1 : std::pow(2.0, GDALGetDataTypeSizeBits(
543 1 : poRefArray->GetDataType()
544 : .GetNumericDataType())) -
545 1 : 1;
546 :
547 : const double dfPSNR_dB =
548 1 : 20 * std::log10(dfMaxAmplitude / dfRMSD);
549 1 : aosReport.push_back(
550 : CPLSPrintf("Array %s: PSNR (dB): %g",
551 : osName.c_str(), dfPSNR_dB));
552 : }
553 : else
554 : {
555 1 : double dfMin = 0;
556 1 : double dfMax = 0;
557 1 : const char *const apszOptions[] = {
558 : "SET_STATISTICS=FALSE", nullptr};
559 :
560 1 : pScaledProgress.reset(GDALCreateScaledProgress(
561 : dfLastPct, 1.0, pfnProgress, pProgressData));
562 :
563 2 : if (poRefArray->ComputeStatistics(
564 : /* bApproxOK = */ false, &dfMin, &dfMax,
565 : nullptr, nullptr, nullptr,
566 1 : pScaledProgress ? GDALScaledProgress : nullptr,
567 1 : pScaledProgress.get(), apszOptions))
568 : {
569 : const double dfPSNR_dB =
570 1 : 20 * std::log10((dfMax - dfMin) / dfRMSD);
571 1 : aosReport.push_back(
572 : CPLSPrintf("Array %s: PSNR (dB): %g",
573 : osName.c_str(), dfPSNR_dB));
574 : }
575 : else
576 : {
577 0 : aosReport.push_back(
578 0 : std::string("Error during PSNR computation: ")
579 0 : .append(CPLGetLastErrorMsg()));
580 : }
581 : }
582 : }
583 : }
584 : }
585 : else
586 : {
587 0 : aosReport.push_back(
588 0 : std::string("Error during RMSD/PSNR computation: ")
589 0 : .append(CPLGetLastErrorMsg()));
590 : }
591 : }
592 : }
593 :
594 : /************************************************************************/
595 : /* ~GDALMdimCompareAlgorithmStandalone() */
596 : /************************************************************************/
597 :
598 : GDALMdimCompareAlgorithmStandalone::~GDALMdimCompareAlgorithmStandalone() =
599 : default;
600 :
601 : //! @endcond
|