LCOV - code coverage report
Current view: top level - apps - gdalalg_raster_neighbors.cpp (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 217 223 97.3 %
Date: 2025-11-11 01:53:33 Functions: 9 9 100.0 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  GDAL
       4             :  * Purpose:  "neighbors" step of "raster pipeline"
       5             :  * Author:   Even Rouault <even dot rouault at spatialys.com>
       6             :  *
       7             :  ******************************************************************************
       8             :  * Copyright (c) 2025, Even Rouault <even dot rouault at spatialys.com>
       9             :  *
      10             :  * SPDX-License-Identifier: MIT
      11             :  ****************************************************************************/
      12             : 
      13             : #include "gdalalg_raster_neighbors.h"
      14             : 
      15             : #include "gdal_priv.h"
      16             : #include "gdal_priv_templates.hpp"
      17             : #include "vrtdataset.h"
      18             : 
      19             : #include <algorithm>
      20             : #include <limits>
      21             : #include <map>
      22             : #include <optional>
      23             : #include <set>
      24             : #include <utility>
      25             : #include <vector>
      26             : 
      27             : //! @cond Doxygen_Suppress
      28             : 
      29             : #ifndef _
      30             : #define _(x) (x)
      31             : #endif
      32             : 
      33             : static const std::set<std::string> oSetKernelNames = {
      34             :     "u",     "v",       "equal",    "edge1",
      35             :     "edge2", "sharpen", "gaussian", "unsharp-masking"};
      36             : 
      37             : namespace
      38             : {
      39             : struct KernelDef
      40             : {
      41             :     int size = 0;
      42             :     std::vector<double> adfCoefficients{};
      43             : };
      44             : }  // namespace
      45             : 
      46             : // clang-format off
      47             : // Cf https://en.wikipedia.org/wiki/Kernel_(image_processing)
      48             : static const std::map<std::string, std::pair<int, std::vector<int>>> oMapKernelNameToMatrix = {
      49             :     { "u", { 3, {  0,  0,  0,
      50             :                   -1,  0,  1,
      51             :                    0,  0,  0 } } },
      52             :     { "v", { 3, {  0, -1,  0,
      53             :                    0,  0,  0,
      54             :                    0,  1,  0 } } },
      55             :     { "edge1", { 3, {  0, -1,  0,
      56             :                       -1,  4, -1,
      57             :                        0, -1,  0 } } },
      58             :     { "edge2", { 3, { -1, -1, -1,
      59             :                       -1,  8, -1,
      60             :                       -1, -1, -1 } } },
      61             :     { "sharpen",  { 3, { 0, -1,  0,
      62             :                         -1,  5, -1,
      63             :                          0, -1,  0 } } },
      64             :     { "gaussian-3x3", { 3, { 1, 2, 1,
      65             :                                   2, 4, 2,
      66             :                                   1, 2, 1 } } },
      67             :     { "gaussian-5x5", { 5, { 1, 4,   6,  4, 1,
      68             :                                   4, 16, 24, 16, 4,
      69             :                                   6, 24, 36, 24, 6,
      70             :                                   4, 16, 24, 16, 4,
      71             :                                   1, 4,   6,  4, 1, } } },
      72             :     { "unsharp-masking-5x5", { 5, { 1, 4,     6,  4, 1,
      73             :                                     4, 16,   24, 16, 4,
      74             :                                     6, 24, -476, 24, 6,
      75             :                                     4, 16,   24, 16, 4,
      76             :                                     1, 4,     6,  4, 1, } } },
      77             : };
      78             : 
      79             : // clang-format on
      80             : 
      81             : /************************************************************************/
      82             : /*                        CreateDerivedBandXML()                        */
      83             : /************************************************************************/
      84             : 
      85          29 : static bool CreateDerivedBandXML(VRTDataset *poVRTDS, GDALRasterBand *poSrcBand,
      86             :                                  GDALDataType eType, const std::string &noData,
      87             :                                  const std::string &method,
      88             :                                  const KernelDef &kernelDef)
      89             : {
      90          29 :     poVRTDS->AddBand(eType, nullptr);
      91             : 
      92          29 :     std::optional<double> dstNoData;
      93          29 :     bool autoSelectNoDataValue = false;
      94          29 :     if (noData.empty())
      95             :     {
      96          26 :         autoSelectNoDataValue = true;
      97             :     }
      98           3 :     else if (noData != "none")
      99             :     {
     100             :         // Already validated to be numeric by the validation action
     101           2 :         dstNoData = CPLAtof(noData.c_str());
     102             :     }
     103             : 
     104          29 :     auto poVRTBand = cpl::down_cast<VRTSourcedRasterBand *>(
     105             :         poVRTDS->GetRasterBand(poVRTDS->GetRasterCount()));
     106             : 
     107          58 :     auto poSource = std::make_unique<VRTKernelFilteredSource>();
     108          29 :     poSrcBand->GetDataset()->Reference();
     109          29 :     poSource->SetSrcBand(poSrcBand);
     110          29 :     poSource->SetKernel(kernelDef.size, /* separable = */ false,
     111          29 :                         kernelDef.adfCoefficients);
     112          29 :     poSource->SetNormalized(method != "sum");
     113          29 :     if (method != "sum" && method != "mean")
     114          10 :         poSource->SetFunction(method.c_str());
     115             : 
     116          29 :     int bSrcHasNoData = false;
     117          29 :     const double dfNoDataValue = poSrcBand->GetNoDataValue(&bSrcHasNoData);
     118          29 :     if (bSrcHasNoData)
     119             :     {
     120           4 :         poSource->SetNoDataValue(dfNoDataValue);
     121           4 :         if (autoSelectNoDataValue && !dstNoData.has_value())
     122             :         {
     123           2 :             dstNoData = dfNoDataValue;
     124             :         }
     125             :     }
     126             : 
     127          29 :     if (dstNoData.has_value())
     128             :     {
     129           4 :         if (!GDALIsValueExactAs(dstNoData.value(), eType))
     130             :         {
     131           1 :             CPLError(CE_Failure, CPLE_AppDefined,
     132             :                      "Band output type %s cannot represent NoData value %g",
     133           1 :                      GDALGetDataTypeName(eType), dstNoData.value());
     134           1 :             return false;
     135             :         }
     136             : 
     137           3 :         poVRTBand->SetNoDataValue(dstNoData.value());
     138             :     }
     139             : 
     140          28 :     poVRTBand->AddSource(std::move(poSource));
     141             : 
     142          28 :     return true;
     143             : }
     144             : 
     145             : /************************************************************************/
     146             : /*                      GDALNeighborsCreateVRTDerived()                 */
     147             : /************************************************************************/
     148             : 
     149             : static std::unique_ptr<GDALDataset>
     150          25 : GDALNeighborsCreateVRTDerived(GDALDataset *poSrcDS, int nBand,
     151             :                               GDALDataType eType, const std::string &noData,
     152             :                               const std::vector<std::string> &methods,
     153             :                               const std::vector<KernelDef> &aKernelDefs)
     154             : {
     155          25 :     CPLAssert(methods.size() == aKernelDefs.size());
     156             : 
     157           0 :     auto ds = std::make_unique<VRTDataset>(poSrcDS->GetRasterXSize(),
     158          50 :                                            poSrcDS->GetRasterYSize());
     159          25 :     GDALGeoTransform gt;
     160          25 :     if (poSrcDS->GetGeoTransform(gt) == CE_None)
     161          12 :         ds->SetGeoTransform(gt);
     162          25 :     if (const OGRSpatialReference *poSRS = poSrcDS->GetSpatialRef())
     163             :     {
     164          12 :         ds->SetSpatialRef(poSRS);
     165             :     }
     166             : 
     167          25 :     bool ret = true;
     168          25 :     if (nBand != 0)
     169             :     {
     170           0 :         for (size_t i = 0; i < aKernelDefs.size() && ret; ++i)
     171             :         {
     172             :             ret =
     173           0 :                 CreateDerivedBandXML(ds.get(), poSrcDS->GetRasterBand(nBand),
     174           0 :                                      eType, noData, methods[i], aKernelDefs[i]);
     175             :         }
     176             :     }
     177             :     else
     178             :     {
     179          52 :         for (int iBand = 1; iBand <= poSrcDS->GetRasterCount(); ++iBand)
     180             :         {
     181          56 :             for (size_t i = 0; i < aKernelDefs.size() && ret; ++i)
     182             :             {
     183          29 :                 ret = CreateDerivedBandXML(ds.get(),
     184             :                                            poSrcDS->GetRasterBand(iBand), eType,
     185          29 :                                            noData, methods[i], aKernelDefs[i]);
     186             :             }
     187             :         }
     188             :     }
     189          25 :     if (!ret)
     190           1 :         ds.reset();
     191          50 :     return ds;
     192             : }
     193             : 
     194             : /************************************************************************/
     195             : /*       GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm()   */
     196             : /************************************************************************/
     197             : 
     198          81 : GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm(
     199          81 :     bool standaloneStep) noexcept
     200             :     : GDALRasterPipelineStepAlgorithm(
     201             :           NAME, DESCRIPTION, HELP_URL,
     202          81 :           ConstructorOptions().SetStandaloneStep(standaloneStep))
     203             : {
     204          81 :     AddBandArg(&m_band);
     205             : 
     206         162 :     AddArg("method", 0, _("Method to combine weighed source pixels"), &m_method)
     207          81 :         .SetChoices("mean", "sum", "min", "max", "stddev", "median", "mode");
     208             : 
     209         162 :     AddArg("size", 0, _("Neighborhood size"), &m_size)
     210          81 :         .SetMinValueIncluded(3)
     211          81 :         .SetMaxValueIncluded(99)
     212             :         .AddValidationAction(
     213          12 :             [this]()
     214             :             {
     215          11 :                 if ((m_size % 2) != 1)
     216             :                 {
     217           1 :                     ReportError(CE_Failure, CPLE_IllegalArg,
     218             :                                 "The value of 'size' must be an odd number.");
     219           1 :                     return false;
     220             :                 }
     221          10 :                 return true;
     222          81 :             });
     223             : 
     224         162 :     AddArg("kernel", 0, _("Convolution kernel(s) to apply"), &m_kernel)
     225          81 :         .SetPackedValuesAllowed(false)
     226          81 :         .SetMinCount(1)
     227          81 :         .SetMinCharCount(1)
     228          81 :         .SetRequired()
     229             :         .SetAutoCompleteFunction(
     230           2 :             [](const std::string &v)
     231             :             {
     232           2 :                 std::vector<std::string> ret;
     233           2 :                 if (v.empty() || v.front() != '[')
     234             :                 {
     235           1 :                     ret.insert(ret.end(), oSetKernelNames.begin(),
     236           2 :                                oSetKernelNames.end());
     237           1 :                     ret.push_back(
     238             :                         "[[val00,val10,...,valN0],...,[val0N,val1N,...valNN]]");
     239             :                 }
     240           2 :                 return ret;
     241         162 :             })
     242             :         .AddValidationAction(
     243          70 :             [this]()
     244             :             {
     245         132 :                 for (const std::string &kernel : m_kernel)
     246             :                 {
     247          70 :                     if (kernel.front() == '[' && kernel.back() == ']')
     248             :                     {
     249             :                         const CPLStringList aosValues(CSLTokenizeString2(
     250             :                             kernel.c_str(), "[] ,",
     251          15 :                             CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     252             :                         const double dfSize =
     253          15 :                             static_cast<double>(aosValues.size());
     254          15 :                         const int nSqrt = static_cast<int>(sqrt(dfSize) + 0.5);
     255          28 :                         if (!((aosValues.size() % 2) == 1 &&
     256          13 :                               nSqrt * nSqrt == aosValues.size()))
     257             :                         {
     258           2 :                             ReportError(
     259             :                                 CE_Failure, CPLE_IllegalArg,
     260             :                                 "The number of values in the 'kernel' "
     261             :                                 "argument must be an odd square number.");
     262           2 :                             return false;
     263             :                         }
     264         125 :                         for (int i = 0; i < aosValues.size(); ++i)
     265             :                         {
     266         113 :                             if (CPLGetValueType(aosValues[i]) ==
     267             :                                 CPL_VALUE_STRING)
     268             :                             {
     269           1 :                                 ReportError(CE_Failure, CPLE_IllegalArg,
     270             :                                             "Non-numeric value found in the "
     271             :                                             "'kernel' argument");
     272           1 :                                 return false;
     273             :                             }
     274             :                         }
     275             :                     }
     276          55 :                     else if (!cpl::contains(oSetKernelNames, kernel))
     277             :                     {
     278             :                         std::string osMsg =
     279           1 :                             "Valid values for 'kernel' argument are: ";
     280           1 :                         bool bFirst = true;
     281           9 :                         for (const auto &name : oSetKernelNames)
     282             :                         {
     283           8 :                             if (!bFirst)
     284           7 :                                 osMsg += ", ";
     285           8 :                             bFirst = false;
     286           8 :                             osMsg += '\'';
     287           8 :                             osMsg += name;
     288           8 :                             osMsg += '\'';
     289             :                         }
     290             :                         osMsg += " or "
     291             :                                  "[[val00,val10,...,valN0],...,[val0N,val1N,..."
     292           1 :                                  "valNN]]";
     293           1 :                         ReportError(CE_Failure, CPLE_IllegalArg, "%s",
     294             :                                     osMsg.c_str());
     295           1 :                         return false;
     296             :                     }
     297             :                 }
     298          62 :                 return true;
     299          81 :             });
     300             : 
     301          81 :     AddOutputDataTypeArg(&m_type).SetDefault("Float64");
     302             : 
     303          81 :     AddNodataArg(&m_nodata, true);
     304             : 
     305          81 :     AddValidationAction(
     306          80 :         [this]()
     307             :         {
     308          31 :             if (m_method.size() > 1 && m_method.size() != m_kernel.size())
     309             :             {
     310           1 :                 ReportError(
     311             :                     CE_Failure, CPLE_AppDefined,
     312             :                     "The number of values for the 'method' argument should "
     313             :                     "be one or exactly the number of values of 'kernel'");
     314           1 :                 return false;
     315             :             }
     316             : 
     317          30 :             if (m_size > 0)
     318             :             {
     319           6 :                 for (const std::string &kernel : m_kernel)
     320             :                 {
     321           5 :                     if (kernel == "gaussian")
     322             :                     {
     323           2 :                         if (m_size != 3 && m_size != 5)
     324             :                         {
     325           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     326             :                                         "Currently only size = 3 or 5 is "
     327             :                                         "supported for kernel '%s'",
     328             :                                         kernel.c_str());
     329           4 :                             return false;
     330             :                         }
     331             :                     }
     332           3 :                     else if (kernel == "unsharp-masking")
     333             :                     {
     334           1 :                         if (m_size != 5)
     335             :                         {
     336           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     337             :                                         "Currently only size = 5 is supported "
     338             :                                         "for kernel '%s'",
     339             :                                         kernel.c_str());
     340           1 :                             return false;
     341             :                         }
     342             :                     }
     343           2 :                     else if (kernel[0] == '[')
     344             :                     {
     345             :                         const CPLStringList aosValues(CSLTokenizeString2(
     346             :                             kernel.c_str(), "[] ,",
     347           1 :                             CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     348             :                         const double dfSize =
     349           1 :                             static_cast<double>(aosValues.size());
     350             :                         const int size =
     351           1 :                             static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
     352           1 :                         if (m_size != size)
     353             :                         {
     354           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     355             :                                         "Value of 'size' argument (%d) "
     356             :                                         "inconsistent with the one deduced "
     357             :                                         "from the kernel matrix (%d)",
     358             :                                         m_size, size);
     359           1 :                             return false;
     360             :                         }
     361             :                     }
     362           2 :                     else if (m_size != 3 && kernel != "equal" &&
     363           1 :                              kernel[0] != '[')
     364             :                     {
     365           1 :                         ReportError(CE_Failure, CPLE_AppDefined,
     366             :                                     "Currently only size = 3 is supported for "
     367             :                                     "kernel '%s'",
     368             :                                     kernel.c_str());
     369           1 :                         return false;
     370             :                     }
     371             :                 }
     372             :             }
     373             : 
     374          26 :             return true;
     375             :         });
     376          81 : }
     377             : 
     378             : /************************************************************************/
     379             : /*                               GetKernelDef()                         */
     380             : /************************************************************************/
     381             : 
     382          11 : static KernelDef GetKernelDef(const std::string &name, bool normalizeCoefs,
     383             :                               double weightIfNotNormalized)
     384             : {
     385          11 :     auto it = oMapKernelNameToMatrix.find(name);
     386          11 :     CPLAssert(it != oMapKernelNameToMatrix.end());
     387          11 :     KernelDef def;
     388          11 :     def.size = it->second.first;
     389          11 :     int nSum = 0;
     390         142 :     for (const int nVal : it->second.second)
     391         131 :         nSum += nVal;
     392             :     const double dfWeight = normalizeCoefs
     393          14 :                                 ? 1.0 / (static_cast<double>(nSum) +
     394           3 :                                          std::numeric_limits<double>::min())
     395          11 :                                 : weightIfNotNormalized;
     396         142 :     for (const int nVal : it->second.second)
     397             :     {
     398         131 :         def.adfCoefficients.push_back(nVal * dfWeight);
     399             :     }
     400          22 :     return def;
     401             : }
     402             : 
     403             : /************************************************************************/
     404             : /*                GDALRasterNeighborsAlgorithm::RunStep()               */
     405             : /************************************************************************/
     406             : 
     407          26 : bool GDALRasterNeighborsAlgorithm::RunStep(GDALPipelineStepRunContext &)
     408             : {
     409          26 :     auto poSrcDS = m_inputDataset[0].GetDatasetRef();
     410          26 :     CPLAssert(!m_outputDataset.GetDatasetRef());
     411             : 
     412          26 :     CPLAssert(m_band <= poSrcDS->GetRasterCount());
     413             : 
     414          26 :     auto eType = GDALGetDataTypeByName(m_type.c_str());
     415          26 :     if (eType == GDT_Unknown)
     416             :     {
     417           0 :         eType = GDT_Float64;
     418             :     }
     419          52 :     std::vector<KernelDef> aKernelDefs(m_kernel.size());
     420          52 :     std::vector<bool> abNullCoefficientSum(m_kernel.size());
     421          53 :     for (size_t i = 0; i < m_kernel.size(); ++i)
     422             :     {
     423          28 :         const std::string &kernel = m_kernel[i];
     424          28 :         if (!kernel.empty() && kernel[0] == '[')
     425             :         {
     426             :             const CPLStringList aosValues(
     427             :                 CSLTokenizeString2(kernel.c_str(), "[] ,",
     428           5 :                                    CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     429           5 :             const double dfSize = static_cast<double>(aosValues.size());
     430           5 :             KernelDef def;
     431           5 :             def.size = static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
     432           5 :             double dfSum = 0;
     433          50 :             for (const char *pszC : cpl::Iterate(aosValues))
     434             :             {
     435          45 :                 const double dfV = CPLAtof(pszC);
     436          45 :                 dfSum += dfV;
     437             :                 // Already validated to be numeric by the validation action
     438          45 :                 def.adfCoefficients.push_back(dfV);
     439             :             }
     440           5 :             abNullCoefficientSum[i] = std::fabs(dfSum) < 1e-10;
     441           6 :             if (abNullCoefficientSum[i] && m_method.size() == m_kernel.size() &&
     442           1 :                 m_method[i] == "mean")
     443             :             {
     444           1 :                 ReportError(
     445             :                     CE_Failure, CPLE_AppDefined,
     446             :                     "Specifying method = 'mean' for a kernel whose sum of "
     447             :                     "coefficients is zero is not allowed. Use 'sum' instead");
     448           1 :                 return false;
     449             :             }
     450           4 :             aKernelDefs[i] = std::move(def);
     451             :         }
     452             :     }
     453             : 
     454          25 :     if (m_method.empty())
     455             :     {
     456          28 :         for (size_t i = 0; i < m_kernel.size(); ++i)
     457             :         {
     458             :             const bool bIsZeroSumKernel =
     459          39 :                 m_kernel[i] == "u" || m_kernel[i] == "v" ||
     460          48 :                 m_kernel[i] == "edge1" || m_kernel[i] == "edge2" ||
     461           9 :                 abNullCoefficientSum[i];
     462          15 :             m_method.push_back(bIsZeroSumKernel ? "sum" : "mean");
     463             :         }
     464             :     }
     465          12 :     else if (m_method.size() == 1)
     466             :     {
     467          24 :         const std::string lastValue(m_method.back());
     468          12 :         m_method.resize(m_kernel.size(), lastValue);
     469             :     }
     470             : 
     471          25 :     if (m_size == 0 && m_kernel[0][0] != '[')
     472          21 :         m_size = m_kernel[0] == "unsharp-masking" ? 5 : 3;
     473             : 
     474          52 :     for (size_t i = 0; i < m_kernel.size(); ++i)
     475             :     {
     476          27 :         const std::string &kernel = m_kernel[i];
     477          27 :         if (aKernelDefs[i].adfCoefficients.empty())
     478             :         {
     479          46 :             KernelDef def;
     480          23 :             if (kernel == "edge1" || kernel == "edge2" || kernel == "sharpen")
     481             :             {
     482           4 :                 CPLAssert(m_size == 3);
     483           4 :                 def = GetKernelDef(kernel, false, 1.0);
     484             :             }
     485          19 :             else if (kernel == "u" || kernel == "v")
     486             :             {
     487           4 :                 CPLAssert(m_size == 3);
     488           4 :                 def = GetKernelDef(kernel, false, 0.5);
     489             :             }
     490          15 :             else if (kernel == "equal")
     491             :             {
     492          12 :                 def.size = m_size;
     493             :                 const double dfWeight =
     494          12 :                     m_method[i] == "mean"
     495          13 :                         ? 1.0 / (static_cast<double>(m_size) * m_size +
     496           1 :                                  std::numeric_limits<double>::min())
     497          12 :                         : 1.0;
     498          12 :                 def.adfCoefficients.resize(static_cast<size_t>(m_size) * m_size,
     499             :                                            dfWeight);
     500             :             }
     501           3 :             else if (kernel == "gaussian")
     502             :             {
     503           2 :                 CPLAssert(m_size == 3 || m_size == 5);
     504           6 :                 def = GetKernelDef(
     505           4 :                     m_size == 3 ? "gaussian-3x3" : "gaussian-5x5", true, 0.0);
     506             :             }
     507           1 :             else if (kernel == "unsharp-masking")
     508             :             {
     509           1 :                 CPLAssert(m_size == 5);
     510           1 :                 def = GetKernelDef("unsharp-masking-5x5", true, 0.0);
     511             :             }
     512             :             else
     513             :             {
     514           0 :                 CPLAssert(false);
     515             :             }
     516          23 :             aKernelDefs[i] = std::move(def);
     517             :         }
     518             :     }
     519             : 
     520          25 :     auto vrt = GDALNeighborsCreateVRTDerived(poSrcDS, m_band, eType, m_nodata,
     521          25 :                                              m_method, aKernelDefs);
     522          25 :     const bool ret = vrt != nullptr;
     523          25 :     if (vrt)
     524             :     {
     525          24 :         m_outputDataset.Set(std::move(vrt));
     526             :     }
     527          25 :     return ret;
     528             : }
     529             : 
     530             : GDALRasterNeighborsAlgorithmStandalone::
     531             :     ~GDALRasterNeighborsAlgorithmStandalone() = default;
     532             : 
     533             : //! @endcond

Generated by: LCOV version 1.14