LCOV - code coverage report
Current view: top level - apps - gdalalg_raster_neighbors.cpp (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 202 207 97.6 %
Date: 2025-10-21 22:35:35 Functions: 9 9 100.0 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  GDAL
       4             :  * Purpose:  "neighbors" step of "raster pipeline"
       5             :  * Author:   Even Rouault <even dot rouault at spatialys.com>
       6             :  *
       7             :  ******************************************************************************
       8             :  * Copyright (c) 2025, Even Rouault <even dot rouault at spatialys.com>
       9             :  *
      10             :  * SPDX-License-Identifier: MIT
      11             :  ****************************************************************************/
      12             : 
      13             : #include "gdalalg_raster_neighbors.h"
      14             : 
      15             : #include "gdal_priv.h"
      16             : #include "gdal_priv_templates.hpp"
      17             : #include "vrtdataset.h"
      18             : 
      19             : #include <algorithm>
      20             : #include <limits>
      21             : #include <map>
      22             : #include <optional>
      23             : #include <set>
      24             : #include <utility>
      25             : #include <vector>
      26             : 
      27             : //! @cond Doxygen_Suppress
      28             : 
      29             : #ifndef _
      30             : #define _(x) (x)
      31             : #endif
      32             : 
      33             : static const std::set<std::string> oSetKernelNames = {
      34             :     "u",     "v",       "equal",    "edge1",
      35             :     "edge2", "sharpen", "gaussian", "unsharp-masking"};
      36             : 
      37             : namespace
      38             : {
      39             : struct KernelDef
      40             : {
      41             :     int size = 0;
      42             :     std::vector<double> adfCoefficients{};
      43             : };
      44             : }  // namespace
      45             : 
      46             : // clang-format off
      47             : // Cf https://en.wikipedia.org/wiki/Kernel_(image_processing)
      48             : static const std::map<std::string, std::pair<int, std::vector<int>>> oMapKernelNameToMatrix = {
      49             :     { "u", { 3, {  0,  0,  0,
      50             :                   -1,  0,  1,
      51             :                    0,  0,  0 } } },
      52             :     { "v", { 3, {  0, -1,  0,
      53             :                    0,  0,  0,
      54             :                    0,  1,  0 } } },
      55             :     { "edge1", { 3, {  0, -1,  0,
      56             :                       -1,  4, -1,
      57             :                        0, -1,  0 } } },
      58             :     { "edge2", { 3, { -1, -1, -1,
      59             :                       -1,  8, -1,
      60             :                       -1, -1, -1 } } },
      61             :     { "sharpen",  { 3, { 0, -1,  0,
      62             :                         -1,  5, -1,
      63             :                          0, -1,  0 } } },
      64             :     { "gaussian-3x3", { 3, { 1, 2, 1,
      65             :                                   2, 4, 2,
      66             :                                   1, 2, 1 } } },
      67             :     { "gaussian-5x5", { 5, { 1, 4,   6,  4, 1,
      68             :                                   4, 16, 24, 16, 4,
      69             :                                   6, 24, 36, 24, 6,
      70             :                                   4, 16, 24, 16, 4,
      71             :                                   1, 4,   6,  4, 1, } } },
      72             :     { "unsharp-masking-5x5", { 5, { 1, 4,     6,  4, 1,
      73             :                                     4, 16,   24, 16, 4,
      74             :                                     6, 24, -476, 24, 6,
      75             :                                     4, 16,   24, 16, 4,
      76             :                                     1, 4,     6,  4, 1, } } },
      77             : };
      78             : 
      79             : // clang-format on
      80             : 
      81             : /************************************************************************/
      82             : /*                        CreateDerivedBandXML()                        */
      83             : /************************************************************************/
      84             : 
      85          28 : static bool CreateDerivedBandXML(VRTDataset *poVRTDS, GDALRasterBand *poSrcBand,
      86             :                                  GDALDataType eType, const std::string &noData,
      87             :                                  const std::string &method,
      88             :                                  const KernelDef &kernelDef)
      89             : {
      90          28 :     poVRTDS->AddBand(eType, nullptr);
      91             : 
      92          28 :     std::optional<double> dstNoData;
      93          28 :     bool autoSelectNoDataValue = false;
      94          28 :     if (noData.empty())
      95             :     {
      96          25 :         autoSelectNoDataValue = true;
      97             :     }
      98           3 :     else if (noData != "none")
      99             :     {
     100             :         // Already validated to be numeric by the validation action
     101           2 :         dstNoData = CPLAtof(noData.c_str());
     102             :     }
     103             : 
     104          28 :     auto poVRTBand = cpl::down_cast<VRTSourcedRasterBand *>(
     105             :         poVRTDS->GetRasterBand(poVRTDS->GetRasterCount()));
     106             : 
     107          56 :     auto poSource = std::make_unique<VRTKernelFilteredSource>();
     108          28 :     poSrcBand->GetDataset()->Reference();
     109          28 :     poSource->SetSrcBand(poSrcBand);
     110          28 :     poSource->SetKernel(kernelDef.size, /* separable = */ false,
     111          28 :                         kernelDef.adfCoefficients);
     112          28 :     poSource->SetNormalized(method != "sum");
     113          28 :     if (method != "sum" && method != "mean")
     114          10 :         poSource->SetFunction(method.c_str());
     115             : 
     116          28 :     int bSrcHasNoData = false;
     117          28 :     const double dfNoDataValue = poSrcBand->GetNoDataValue(&bSrcHasNoData);
     118          28 :     if (bSrcHasNoData)
     119             :     {
     120           4 :         poSource->SetNoDataValue(dfNoDataValue);
     121           4 :         if (autoSelectNoDataValue && !dstNoData.has_value())
     122             :         {
     123           2 :             dstNoData = dfNoDataValue;
     124             :         }
     125             :     }
     126             : 
     127          28 :     if (dstNoData.has_value())
     128             :     {
     129           4 :         if (!GDALIsValueExactAs(dstNoData.value(), eType))
     130             :         {
     131           1 :             CPLError(CE_Failure, CPLE_AppDefined,
     132             :                      "Band output type %s cannot represent NoData value %g",
     133           1 :                      GDALGetDataTypeName(eType), dstNoData.value());
     134           1 :             return false;
     135             :         }
     136             : 
     137           3 :         poVRTBand->SetNoDataValue(dstNoData.value());
     138             :     }
     139             : 
     140          27 :     poVRTBand->AddSource(std::move(poSource));
     141             : 
     142          27 :     return true;
     143             : }
     144             : 
     145             : /************************************************************************/
     146             : /*                      GDALNeighborsCreateVRTDerived()                 */
     147             : /************************************************************************/
     148             : 
     149             : static std::unique_ptr<GDALDataset>
     150          24 : GDALNeighborsCreateVRTDerived(GDALDataset *poSrcDS, int nBand,
     151             :                               GDALDataType eType, const std::string &noData,
     152             :                               const std::vector<std::string> &methods,
     153             :                               const std::vector<KernelDef> &aKernelDefs)
     154             : {
     155          24 :     CPLAssert(methods.size() == aKernelDefs.size());
     156             : 
     157           0 :     auto ds = std::make_unique<VRTDataset>(poSrcDS->GetRasterXSize(),
     158          48 :                                            poSrcDS->GetRasterYSize());
     159          24 :     GDALGeoTransform gt;
     160          24 :     if (poSrcDS->GetGeoTransform(gt) == CE_None)
     161          11 :         ds->SetGeoTransform(gt);
     162          24 :     if (const OGRSpatialReference *poSRS = poSrcDS->GetSpatialRef())
     163             :     {
     164          11 :         ds->SetSpatialRef(poSRS);
     165             :     }
     166             : 
     167          24 :     bool ret = true;
     168          24 :     if (nBand != 0)
     169             :     {
     170           0 :         for (size_t i = 0; i < aKernelDefs.size() && ret; ++i)
     171             :         {
     172             :             ret =
     173           0 :                 CreateDerivedBandXML(ds.get(), poSrcDS->GetRasterBand(nBand),
     174           0 :                                      eType, noData, methods[i], aKernelDefs[i]);
     175             :         }
     176             :     }
     177             :     else
     178             :     {
     179          50 :         for (int iBand = 1; iBand <= poSrcDS->GetRasterCount(); ++iBand)
     180             :         {
     181          54 :             for (size_t i = 0; i < aKernelDefs.size() && ret; ++i)
     182             :             {
     183          28 :                 ret = CreateDerivedBandXML(ds.get(),
     184             :                                            poSrcDS->GetRasterBand(iBand), eType,
     185          28 :                                            noData, methods[i], aKernelDefs[i]);
     186             :             }
     187             :         }
     188             :     }
     189          24 :     if (!ret)
     190           1 :         ds.reset();
     191          48 :     return ds;
     192             : }
     193             : 
     194             : /************************************************************************/
     195             : /*       GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm()   */
     196             : /************************************************************************/
     197             : 
     198          76 : GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm(
     199          76 :     bool standaloneStep) noexcept
     200             :     : GDALRasterPipelineStepAlgorithm(
     201             :           NAME, DESCRIPTION, HELP_URL,
     202          76 :           ConstructorOptions().SetStandaloneStep(standaloneStep))
     203             : {
     204          76 :     AddBandArg(&m_band);
     205             : 
     206         152 :     AddArg("method", 0, _("Method to combine weighed source pixels"), &m_method)
     207          76 :         .SetChoices("mean", "sum", "min", "max", "stddev", "median", "mode");
     208             : 
     209         152 :     AddArg("size", 0, _("Neighborhood size"), &m_size)
     210          76 :         .SetMinValueIncluded(3)
     211          76 :         .SetMaxValueIncluded(99)
     212             :         .AddValidationAction(
     213          12 :             [this]()
     214             :             {
     215          11 :                 if ((m_size % 2) != 1)
     216             :                 {
     217           1 :                     ReportError(CE_Failure, CPLE_IllegalArg,
     218             :                                 "The value of 'size' must be an odd number.");
     219           1 :                     return false;
     220             :                 }
     221          10 :                 return true;
     222          76 :             });
     223             : 
     224         152 :     AddArg("kernel", 0, _("Convolution kernel(s) to apply"), &m_kernel)
     225          76 :         .SetPackedValuesAllowed(false)
     226          76 :         .SetMinCount(1)
     227          76 :         .SetMinCharCount(1)
     228          76 :         .SetRequired()
     229             :         .SetAutoCompleteFunction(
     230           2 :             [](const std::string &v)
     231             :             {
     232           2 :                 std::vector<std::string> ret;
     233           2 :                 if (v.empty() || v.front() != '[')
     234             :                 {
     235           1 :                     ret.insert(ret.end(), oSetKernelNames.begin(),
     236           2 :                                oSetKernelNames.end());
     237           1 :                     ret.push_back(
     238             :                         "[[val00,val10,...,valN0],...,[val0N,val1N,...valNN]]");
     239             :                 }
     240           2 :                 return ret;
     241         152 :             })
     242             :         .AddValidationAction(
     243          66 :             [this]()
     244             :             {
     245         124 :                 for (const std::string &kernel : m_kernel)
     246             :                 {
     247          66 :                     if (kernel.front() == '[' && kernel.back() == ']')
     248             :                     {
     249             :                         const CPLStringList aosValues(CSLTokenizeString2(
     250             :                             kernel.c_str(), "[] ,",
     251          11 :                             CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     252             :                         const double dfSize =
     253          11 :                             static_cast<double>(aosValues.size());
     254          11 :                         const int nSqrt = static_cast<int>(sqrt(dfSize) + 0.5);
     255          20 :                         if (!((aosValues.size() % 2) == 1 &&
     256           9 :                               nSqrt * nSqrt == aosValues.size()))
     257             :                         {
     258           2 :                             ReportError(
     259             :                                 CE_Failure, CPLE_IllegalArg,
     260             :                                 "The number of values in the 'kernel' "
     261             :                                 "argument must be an odd square number.");
     262           2 :                             return false;
     263             :                         }
     264          85 :                         for (int i = 0; i < aosValues.size(); ++i)
     265             :                         {
     266          77 :                             if (CPLGetValueType(aosValues[i]) ==
     267             :                                 CPL_VALUE_STRING)
     268             :                             {
     269           1 :                                 ReportError(CE_Failure, CPLE_IllegalArg,
     270             :                                             "Non-numeric value found in the "
     271             :                                             "'kernel' argument");
     272           1 :                                 return false;
     273             :                             }
     274             :                         }
     275             :                     }
     276          55 :                     else if (!cpl::contains(oSetKernelNames, kernel))
     277             :                     {
     278             :                         std::string osMsg =
     279           1 :                             "Valid values for 'kernel' argument are: ";
     280           1 :                         bool bFirst = true;
     281           9 :                         for (const auto &name : oSetKernelNames)
     282             :                         {
     283           8 :                             if (!bFirst)
     284           7 :                                 osMsg += ", ";
     285           8 :                             bFirst = false;
     286           8 :                             osMsg += '\'';
     287           8 :                             osMsg += name;
     288           8 :                             osMsg += '\'';
     289             :                         }
     290             :                         osMsg += " or "
     291             :                                  "[[val00,val10,...,valN0],...,[val0N,val1N,..."
     292           1 :                                  "valNN]]";
     293           1 :                         ReportError(CE_Failure, CPLE_IllegalArg, "%s",
     294             :                                     osMsg.c_str());
     295           1 :                         return false;
     296             :                     }
     297             :                 }
     298          58 :                 return true;
     299          76 :             });
     300             : 
     301          76 :     AddOutputDataTypeArg(&m_type).SetDefault("Float64");
     302             : 
     303          76 :     AddNodataArg(&m_nodata, true);
     304             : 
     305          76 :     AddValidationAction(
     306          76 :         [this]()
     307             :         {
     308          29 :             if (m_method.size() > 1 && m_method.size() != m_kernel.size())
     309             :             {
     310           1 :                 ReportError(
     311             :                     CE_Failure, CPLE_AppDefined,
     312             :                     "The number of values for the 'method' argument should "
     313             :                     "be one or exactly the number of values of 'kernel'");
     314           1 :                 return false;
     315             :             }
     316             : 
     317          28 :             if (m_size > 0)
     318             :             {
     319           6 :                 for (const std::string &kernel : m_kernel)
     320             :                 {
     321           5 :                     if (kernel == "gaussian")
     322             :                     {
     323           2 :                         if (m_size != 3 && m_size != 5)
     324             :                         {
     325           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     326             :                                         "Currently only size = 3 or 5 is "
     327             :                                         "supported for kernel '%s'",
     328             :                                         kernel.c_str());
     329           4 :                             return false;
     330             :                         }
     331             :                     }
     332           3 :                     else if (kernel == "unsharp-masking")
     333             :                     {
     334           1 :                         if (m_size != 5)
     335             :                         {
     336           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     337             :                                         "Currently only size = 5 is supported "
     338             :                                         "for kernel '%s'",
     339             :                                         kernel.c_str());
     340           1 :                             return false;
     341             :                         }
     342             :                     }
     343           2 :                     else if (kernel[0] == '[')
     344             :                     {
     345             :                         const CPLStringList aosValues(CSLTokenizeString2(
     346             :                             kernel.c_str(), "[] ,",
     347           1 :                             CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     348             :                         const double dfSize =
     349           1 :                             static_cast<double>(aosValues.size());
     350             :                         const int size =
     351           1 :                             static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
     352           1 :                         if (m_size != size)
     353             :                         {
     354           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     355             :                                         "Value of 'size' argument (%d) "
     356             :                                         "inconsistent with the one deduced "
     357             :                                         "from the kernel matrix (%d)",
     358             :                                         m_size, size);
     359           1 :                             return false;
     360             :                         }
     361             :                     }
     362           2 :                     else if (m_size != 3 && kernel != "equal" &&
     363           1 :                              kernel[0] != '[')
     364             :                     {
     365           1 :                         ReportError(CE_Failure, CPLE_AppDefined,
     366             :                                     "Currently only size = 3 is supported for "
     367             :                                     "kernel '%s'",
     368             :                                     kernel.c_str());
     369           1 :                         return false;
     370             :                     }
     371             :                 }
     372             :             }
     373             : 
     374          24 :             return true;
     375             :         });
     376          76 : }
     377             : 
     378             : /************************************************************************/
     379             : /*                               GetKernelDef()                         */
     380             : /************************************************************************/
     381             : 
     382          11 : static KernelDef GetKernelDef(const std::string &name, bool normalizeCoefs,
     383             :                               double weightIfNotNormalized)
     384             : {
     385          11 :     auto it = oMapKernelNameToMatrix.find(name);
     386          11 :     CPLAssert(it != oMapKernelNameToMatrix.end());
     387          11 :     KernelDef def;
     388          11 :     def.size = it->second.first;
     389          11 :     int nSum = 0;
     390         142 :     for (const int nVal : it->second.second)
     391         131 :         nSum += nVal;
     392             :     const double dfWeight = normalizeCoefs
     393          14 :                                 ? 1.0 / (static_cast<double>(nSum) +
     394           3 :                                          std::numeric_limits<double>::min())
     395          11 :                                 : weightIfNotNormalized;
     396         142 :     for (const int nVal : it->second.second)
     397             :     {
     398         131 :         def.adfCoefficients.push_back(nVal * dfWeight);
     399             :     }
     400          22 :     return def;
     401             : }
     402             : 
     403             : /************************************************************************/
     404             : /*                GDALRasterNeighborsAlgorithm::RunStep()               */
     405             : /************************************************************************/
     406             : 
     407          24 : bool GDALRasterNeighborsAlgorithm::RunStep(GDALPipelineStepRunContext &)
     408             : {
     409          24 :     auto poSrcDS = m_inputDataset[0].GetDatasetRef();
     410          24 :     CPLAssert(!m_outputDataset.GetDatasetRef());
     411             : 
     412          24 :     CPLAssert(m_band <= poSrcDS->GetRasterCount());
     413             : 
     414          24 :     auto eType = GDALGetDataTypeByName(m_type.c_str());
     415          24 :     if (eType == GDT_Unknown)
     416             :     {
     417           0 :         eType = GDT_Float64;
     418             :     }
     419             : 
     420          24 :     if (m_method.size() <= 1)
     421             :     {
     422          38 :         while (m_method.size() < m_kernel.size())
     423             :         {
     424          14 :             m_method.push_back(
     425          14 :                 m_method.empty()
     426          47 :                     ? ((m_kernel[0] == "u" || m_kernel[0] == "v" ||
     427          17 :                         m_kernel[0] == "edge1" || m_kernel[0] == "edge2")
     428             :                            ? "sum"
     429             :                            : "mean")
     430           2 :                     : m_method.back());
     431             :         }
     432             :     }
     433             : 
     434          24 :     if (m_size == 0 && m_kernel[0][0] != '[')
     435          21 :         m_size = m_kernel[0] == "unsharp-masking" ? 5 : 3;
     436             : 
     437          48 :     std::vector<KernelDef> aKernelDefs;
     438          24 :     size_t i = 0;
     439          50 :     for (std::string &kernel : m_kernel)
     440             :     {
     441          26 :         KernelDef def;
     442          26 :         if (kernel == "edge1" || kernel == "edge2" || kernel == "sharpen")
     443             :         {
     444           4 :             CPLAssert(m_size == 3);
     445           4 :             def = GetKernelDef(kernel, false, 1.0);
     446             :         }
     447          22 :         else if (kernel == "u" || kernel == "v")
     448             :         {
     449           4 :             CPLAssert(m_size == 3);
     450           4 :             def = GetKernelDef(kernel, false, 0.5);
     451             :         }
     452          18 :         else if (kernel == "equal")
     453             :         {
     454          12 :             def.size = m_size;
     455             :             const double dfWeight =
     456          12 :                 m_method[i] == "mean"
     457          13 :                     ? 1.0 / (static_cast<double>(m_size) * m_size +
     458           1 :                              std::numeric_limits<double>::min())
     459          12 :                     : 1.0;
     460          12 :             def.adfCoefficients.resize(static_cast<size_t>(m_size) * m_size,
     461             :                                        dfWeight);
     462             :         }
     463           6 :         else if (kernel == "gaussian")
     464             :         {
     465           2 :             CPLAssert(m_size == 3 || m_size == 5);
     466           4 :             def = GetKernelDef(m_size == 3 ? "gaussian-3x3" : "gaussian-5x5",
     467           2 :                                true, 0.0);
     468             :         }
     469           4 :         else if (kernel == "unsharp-masking")
     470             :         {
     471           1 :             CPLAssert(m_size == 5);
     472           1 :             def = GetKernelDef("unsharp-masking-5x5", true, 0.0);
     473             :         }
     474             :         else
     475             :         {
     476           3 :             CPLAssert(kernel.front() == '[');
     477             :             const CPLStringList aosValues(
     478             :                 CSLTokenizeString2(kernel.c_str(), "[] ,",
     479           6 :                                    CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     480           3 :             const double dfSize = static_cast<double>(aosValues.size());
     481           3 :             def.size = static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
     482          30 :             for (const char *pszC : cpl::Iterate(aosValues))
     483             :             {
     484             :                 // Already validated to be numeric by the validation action
     485          27 :                 def.adfCoefficients.push_back(CPLAtof(pszC));
     486             :             }
     487             :         }
     488          26 :         aKernelDefs.push_back(std::move(def));
     489             : 
     490          26 :         ++i;
     491             :     }
     492             : 
     493          24 :     auto vrt = GDALNeighborsCreateVRTDerived(poSrcDS, m_band, eType, m_nodata,
     494          24 :                                              m_method, aKernelDefs);
     495          24 :     const bool ret = vrt != nullptr;
     496          24 :     if (vrt)
     497             :     {
     498          23 :         m_outputDataset.Set(std::move(vrt));
     499             :     }
     500          48 :     return ret;
     501             : }
     502             : 
     503             : GDALRasterNeighborsAlgorithmStandalone::
     504             :     ~GDALRasterNeighborsAlgorithmStandalone() = default;
     505             : 
     506             : //! @endcond

Generated by: LCOV version 1.14