LCOV - code coverage report
Current view: top level - apps - gdalalg_raster_neighbors.cpp (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 207 209 99.0 %
Date: 2025-10-01 17:07:58 Functions: 9 9 100.0 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  GDAL
       4             :  * Purpose:  "neighbors" step of "raster pipeline"
       5             :  * Author:   Even Rouault <even dot rouault at spatialys.com>
       6             :  *
       7             :  ******************************************************************************
       8             :  * Copyright (c) 2025, Even Rouault <even dot rouault at spatialys.com>
       9             :  *
      10             :  * SPDX-License-Identifier: MIT
      11             :  ****************************************************************************/
      12             : 
      13             : #include "gdalalg_raster_neighbors.h"
      14             : 
      15             : #include "gdal_priv.h"
      16             : #include "gdal_priv_templates.hpp"
      17             : #include "vrtdataset.h"
      18             : 
      19             : #include <algorithm>
      20             : #include <limits>
      21             : #include <map>
      22             : #include <optional>
      23             : #include <set>
      24             : #include <utility>
      25             : #include <vector>
      26             : 
      27             : //! @cond Doxygen_Suppress
      28             : 
      29             : #ifndef _
      30             : #define _(x) (x)
      31             : #endif
      32             : 
      33             : static const std::set<std::string> oSetKernelNames = {
      34             :     "u",     "v",       "equal",    "edge1",
      35             :     "edge2", "sharpen", "gaussian", "unsharp-masking"};
      36             : 
      37             : namespace
      38             : {
      39             : struct KernelDef
      40             : {
      41             :     int size = 0;
      42             :     std::vector<double> adfCoefficients{};
      43             : };
      44             : }  // namespace
      45             : 
      46             : // clang-format off
      47             : // Cf https://en.wikipedia.org/wiki/Kernel_(image_processing)
      48             : static const std::map<std::string, std::pair<int, std::vector<int>>> oMapKernelNameToMatrix = {
      49             :     { "u", { 3, {  0,  0,  0,
      50             :                   -1,  0,  1,
      51             :                    0,  0,  0 } } },
      52             :     { "v", { 3, {  0, -1,  0,
      53             :                    0,  0,  0,
      54             :                    0,  1,  0 } } },
      55             :     { "edge1", { 3, {  0, -1,  0,
      56             :                       -1,  4, -1,
      57             :                        0, -1,  0 } } },
      58             :     { "edge2", { 3, { -1, -1, -1,
      59             :                       -1,  8, -1,
      60             :                       -1, -1, -1 } } },
      61             :     { "sharpen",  { 3, { 0, -1,  0,
      62             :                         -1,  5, -1,
      63             :                          0, -1,  0 } } },
      64             :     { "gaussian-3x3", { 3, { 1, 2, 1,
      65             :                                   2, 4, 2,
      66             :                                   1, 2, 1 } } },
      67             :     { "gaussian-5x5", { 5, { 1, 4,   6,  4, 1,
      68             :                                   4, 16, 24, 16, 4,
      69             :                                   6, 24, 36, 24, 6,
      70             :                                   4, 16, 24, 16, 4,
      71             :                                   1, 4,   6,  4, 1, } } },
      72             :     { "unsharp-masking-5x5", { 5, { 1, 4,     6,  4, 1,
      73             :                                     4, 16,   24, 16, 4,
      74             :                                     6, 24, -476, 24, 6,
      75             :                                     4, 16,   24, 16, 4,
      76             :                                     1, 4,     6,  4, 1, } } },
      77             : };
      78             : 
      79             : // clang-format on
      80             : 
      81             : /************************************************************************/
      82             : /*                        CreateDerivedBandXML()                        */
      83             : /************************************************************************/
      84             : 
      85          25 : static bool CreateDerivedBandXML(VRTDataset *poVRTDS, GDALRasterBand *poSrcBand,
      86             :                                  GDALDataType eType, const std::string &noData,
      87             :                                  const std::string &method,
      88             :                                  const KernelDef &kernelDef)
      89             : {
      90          25 :     poVRTDS->AddBand(eType, nullptr);
      91             : 
      92          25 :     std::optional<double> dstNoData;
      93          25 :     bool autoSelectNoDataValue = false;
      94          25 :     if (noData.empty())
      95             :     {
      96          22 :         autoSelectNoDataValue = true;
      97             :     }
      98           3 :     else if (noData != "none")
      99             :     {
     100             :         // Already validated to be numeric by the validation action
     101           2 :         dstNoData = CPLAtof(noData.c_str());
     102             :     }
     103             : 
     104          25 :     auto poVRTBand = cpl::down_cast<VRTSourcedRasterBand *>(
     105             :         poVRTDS->GetRasterBand(poVRTDS->GetRasterCount()));
     106             : 
     107          50 :     auto poSource = std::make_unique<VRTKernelFilteredSource>();
     108          25 :     poSrcBand->GetDataset()->Reference();
     109          25 :     poSource->SetSrcBand(poSrcBand);
     110          25 :     poSource->SetKernel(kernelDef.size, /* separable = */ false,
     111          25 :                         kernelDef.adfCoefficients);
     112          25 :     poSource->SetNormalized(method != "sum");
     113          25 :     if (method != "sum" && method != "mean")
     114          10 :         poSource->SetFunction(method.c_str());
     115             : 
     116          25 :     int bSrcHasNoData = false;
     117          25 :     const double dfNoDataValue = poSrcBand->GetNoDataValue(&bSrcHasNoData);
     118          25 :     if (bSrcHasNoData)
     119             :     {
     120           4 :         poSource->SetNoDataValue(dfNoDataValue);
     121           4 :         if (autoSelectNoDataValue && !dstNoData.has_value())
     122             :         {
     123           2 :             dstNoData = dfNoDataValue;
     124             :         }
     125             :     }
     126             : 
     127          25 :     if (dstNoData.has_value())
     128             :     {
     129           4 :         if (!GDALIsValueExactAs(dstNoData.value(), eType))
     130             :         {
     131           1 :             CPLError(CE_Failure, CPLE_AppDefined,
     132             :                      "Band output type %s cannot represent NoData value %g",
     133           1 :                      GDALGetDataTypeName(eType), dstNoData.value());
     134           1 :             return false;
     135             :         }
     136             : 
     137           3 :         poVRTBand->SetNoDataValue(dstNoData.value());
     138             :     }
     139             : 
     140          24 :     poVRTBand->AddSource(std::move(poSource));
     141             : 
     142          24 :     return true;
     143             : }
     144             : 
     145             : /************************************************************************/
     146             : /*                      GDALNeighborsCreateVRTDerived()                 */
     147             : /************************************************************************/
     148             : 
     149             : static std::unique_ptr<GDALDataset>
     150          23 : GDALNeighborsCreateVRTDerived(GDALDataset *poSrcDS, int nBand,
     151             :                               GDALDataType eType, const std::string &noData,
     152             :                               const std::vector<std::string> &methods,
     153             :                               const std::vector<KernelDef> &aKernelDefs)
     154             : {
     155          23 :     CPLAssert(methods.size() == aKernelDefs.size());
     156             : 
     157           0 :     auto ds = std::make_unique<VRTDataset>(poSrcDS->GetRasterXSize(),
     158          46 :                                            poSrcDS->GetRasterYSize());
     159          23 :     GDALGeoTransform gt;
     160          23 :     if (poSrcDS->GetGeoTransform(gt) == CE_None)
     161          10 :         ds->SetGeoTransform(gt);
     162          23 :     if (const OGRSpatialReference *poSRS = poSrcDS->GetSpatialRef())
     163             :     {
     164          10 :         ds->SetSpatialRef(poSRS);
     165             :     }
     166             : 
     167          23 :     bool ret = true;
     168          48 :     for (size_t i = 0; i < aKernelDefs.size() && ret; ++i)
     169             :     {
     170          25 :         ret = CreateDerivedBandXML(ds.get(), poSrcDS->GetRasterBand(nBand),
     171          25 :                                    eType, noData, methods[i], aKernelDefs[i]);
     172             :     }
     173          23 :     if (!ret)
     174           1 :         ds.reset();
     175          46 :     return ds;
     176             : }
     177             : 
     178             : /************************************************************************/
     179             : /*       GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm()   */
     180             : /************************************************************************/
     181             : 
     182          76 : GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm(
     183          76 :     bool standaloneStep) noexcept
     184             :     : GDALRasterPipelineStepAlgorithm(
     185             :           NAME, DESCRIPTION, HELP_URL,
     186          76 :           ConstructorOptions().SetStandaloneStep(standaloneStep))
     187             : {
     188          76 :     AddBandArg(&m_band);
     189             : 
     190         152 :     AddArg("method", 0, _("Method to combine weighed source pixels"), &m_method)
     191          76 :         .SetChoices("mean", "sum", "min", "max", "stddev", "median", "mode");
     192             : 
     193         152 :     AddArg("size", 0, _("Neighborhood size"), &m_size)
     194          76 :         .SetMinValueIncluded(3)
     195          76 :         .SetMaxValueIncluded(99)
     196             :         .AddValidationAction(
     197          12 :             [this]()
     198             :             {
     199          11 :                 if ((m_size % 2) != 1)
     200             :                 {
     201           1 :                     ReportError(CE_Failure, CPLE_IllegalArg,
     202             :                                 "The value of 'size' must be an odd number.");
     203           1 :                     return false;
     204             :                 }
     205          10 :                 return true;
     206          76 :             });
     207             : 
     208         152 :     AddArg("kernel", 0, _("Convolution kernel(s) to apply"), &m_kernel)
     209          76 :         .SetPackedValuesAllowed(false)
     210          76 :         .SetMinCount(1)
     211          76 :         .SetMinCharCount(1)
     212          76 :         .SetRequired()
     213             :         .SetAutoCompleteFunction(
     214           2 :             [](const std::string &v)
     215             :             {
     216           2 :                 std::vector<std::string> ret;
     217           2 :                 if (v.empty() || v.front() != '[')
     218             :                 {
     219           1 :                     ret.insert(ret.end(), oSetKernelNames.begin(),
     220           2 :                                oSetKernelNames.end());
     221           1 :                     ret.push_back(
     222             :                         "[[val00,val10,...,valN0],...,[val0N,val1N,...valNN]]");
     223             :                 }
     224           2 :                 return ret;
     225         152 :             })
     226             :         .AddValidationAction(
     227          66 :             [this]()
     228             :             {
     229         124 :                 for (const std::string &kernel : m_kernel)
     230             :                 {
     231          66 :                     if (kernel.front() == '[' && kernel.back() == ']')
     232             :                     {
     233             :                         const CPLStringList aosValues(CSLTokenizeString2(
     234             :                             kernel.c_str(), "[] ,",
     235          11 :                             CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     236             :                         const double dfSize =
     237          11 :                             static_cast<double>(aosValues.size());
     238          11 :                         const int nSqrt = static_cast<int>(sqrt(dfSize) + 0.5);
     239          20 :                         if (!((aosValues.size() % 2) == 1 &&
     240           9 :                               nSqrt * nSqrt == aosValues.size()))
     241             :                         {
     242           2 :                             ReportError(
     243             :                                 CE_Failure, CPLE_IllegalArg,
     244             :                                 "The number of values in the 'kernel' "
     245             :                                 "argument must be an odd square number.");
     246           2 :                             return false;
     247             :                         }
     248          85 :                         for (int i = 0; i < aosValues.size(); ++i)
     249             :                         {
     250          77 :                             if (CPLGetValueType(aosValues[i]) ==
     251             :                                 CPL_VALUE_STRING)
     252             :                             {
     253           1 :                                 ReportError(CE_Failure, CPLE_IllegalArg,
     254             :                                             "Non-numeric value found in the "
     255             :                                             "'kernel' argument");
     256           1 :                                 return false;
     257             :                             }
     258             :                         }
     259             :                     }
     260          55 :                     else if (!cpl::contains(oSetKernelNames, kernel))
     261             :                     {
     262             :                         std::string osMsg =
     263           1 :                             "Valid values for 'kernel' argument are: ";
     264           1 :                         bool bFirst = true;
     265           9 :                         for (const auto &name : oSetKernelNames)
     266             :                         {
     267           8 :                             if (!bFirst)
     268           7 :                                 osMsg += ", ";
     269           8 :                             bFirst = false;
     270           8 :                             osMsg += '\'';
     271           8 :                             osMsg += name;
     272           8 :                             osMsg += '\'';
     273             :                         }
     274             :                         osMsg += " or "
     275             :                                  "[[val00,val10,...,valN0],...,[val0N,val1N,..."
     276           1 :                                  "valNN]]";
     277           1 :                         ReportError(CE_Failure, CPLE_IllegalArg, "%s",
     278             :                                     osMsg.c_str());
     279           1 :                         return false;
     280             :                     }
     281             :                 }
     282          58 :                 return true;
     283          76 :             });
     284             : 
     285          76 :     AddOutputDataTypeArg(&m_type).SetDefault("Float64");
     286             : 
     287          76 :     AddNodataArg(&m_nodata, true);
     288             : 
     289          76 :     AddValidationAction(
     290         156 :         [this]()
     291             :         {
     292          29 :             if (m_method.size() > 1 && m_method.size() != m_kernel.size())
     293             :             {
     294           1 :                 ReportError(
     295             :                     CE_Failure, CPLE_AppDefined,
     296             :                     "The number of values for the 'method' argument should "
     297             :                     "be one or exactly the number of values of 'kernel'");
     298           1 :                 return false;
     299             :             }
     300             : 
     301          28 :             if (m_band == 0 && !m_inputDataset.empty())
     302             :             {
     303          24 :                 auto poDS = m_inputDataset[0].GetDatasetRef();
     304          24 :                 if (poDS && poDS->GetRasterCount() > 1)
     305             :                 {
     306           1 :                     ReportError(
     307             :                         CE_Failure, CPLE_AppDefined,
     308             :                         "'band' argument should be specified given input "
     309             :                         "dataset has several bands.");
     310           1 :                     return false;
     311             :                 }
     312             :             }
     313             : 
     314          27 :             if (m_size > 0)
     315             :             {
     316           6 :                 for (const std::string &kernel : m_kernel)
     317             :                 {
     318           5 :                     if (kernel == "gaussian")
     319             :                     {
     320           2 :                         if (m_size != 3 && m_size != 5)
     321             :                         {
     322           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     323             :                                         "Currently only size = 3 or 5 is "
     324             :                                         "supported for kernel '%s'",
     325             :                                         kernel.c_str());
     326           4 :                             return false;
     327             :                         }
     328             :                     }
     329           3 :                     else if (kernel == "unsharp-masking")
     330             :                     {
     331           1 :                         if (m_size != 5)
     332             :                         {
     333           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     334             :                                         "Currently only size = 5 is supported "
     335             :                                         "for kernel '%s'",
     336             :                                         kernel.c_str());
     337           1 :                             return false;
     338             :                         }
     339             :                     }
     340           2 :                     else if (kernel[0] == '[')
     341             :                     {
     342             :                         const CPLStringList aosValues(CSLTokenizeString2(
     343             :                             kernel.c_str(), "[] ,",
     344           1 :                             CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     345             :                         const double dfSize =
     346           1 :                             static_cast<double>(aosValues.size());
     347             :                         const int size =
     348           1 :                             static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
     349           1 :                         if (m_size != size)
     350             :                         {
     351           1 :                             ReportError(CE_Failure, CPLE_AppDefined,
     352             :                                         "Value of 'size' argument (%d) "
     353             :                                         "inconsistent with the one deduced "
     354             :                                         "from the kernel matrix (%d)",
     355             :                                         m_size, size);
     356           1 :                             return false;
     357             :                         }
     358             :                     }
     359           2 :                     else if (m_size != 3 && kernel != "equal" &&
     360           1 :                              kernel[0] != '[')
     361             :                     {
     362           1 :                         ReportError(CE_Failure, CPLE_AppDefined,
     363             :                                     "Currently only size = 3 is supported for "
     364             :                                     "kernel '%s'",
     365             :                                     kernel.c_str());
     366           1 :                         return false;
     367             :                     }
     368             :                 }
     369             :             }
     370             : 
     371          23 :             return true;
     372             :         });
     373          76 : }
     374             : 
     375             : /************************************************************************/
     376             : /*                               GetKernelDef()                         */
     377             : /************************************************************************/
     378             : 
     379          10 : static KernelDef GetKernelDef(const std::string &name, bool normalizeCoefs,
     380             :                               double weightIfNotNormalized)
     381             : {
     382          10 :     auto it = oMapKernelNameToMatrix.find(name);
     383          10 :     CPLAssert(it != oMapKernelNameToMatrix.end());
     384          10 :     KernelDef def;
     385          10 :     def.size = it->second.first;
     386          10 :     int nSum = 0;
     387         132 :     for (const int nVal : it->second.second)
     388         122 :         nSum += nVal;
     389             :     const double dfWeight = normalizeCoefs
     390          13 :                                 ? 1.0 / (static_cast<double>(nSum) +
     391           3 :                                          std::numeric_limits<double>::min())
     392          10 :                                 : weightIfNotNormalized;
     393         132 :     for (const int nVal : it->second.second)
     394             :     {
     395         122 :         def.adfCoefficients.push_back(nVal * dfWeight);
     396             :     }
     397          20 :     return def;
     398             : }
     399             : 
     400             : /************************************************************************/
     401             : /*                GDALRasterNeighborsAlgorithm::RunStep()               */
     402             : /************************************************************************/
     403             : 
     404          23 : bool GDALRasterNeighborsAlgorithm::RunStep(GDALPipelineStepRunContext &)
     405             : {
     406          23 :     auto poSrcDS = m_inputDataset[0].GetDatasetRef();
     407          23 :     CPLAssert(!m_outputDataset.GetDatasetRef());
     408             : 
     409          23 :     if (m_band == 0)
     410          23 :         m_band = 1;
     411          23 :     CPLAssert(m_band <= poSrcDS->GetRasterCount());
     412             : 
     413          23 :     auto eType = GDALGetDataTypeByName(m_type.c_str());
     414          23 :     if (eType == GDT_Unknown)
     415             :     {
     416           0 :         eType = GDT_Float64;
     417             :     }
     418             : 
     419          23 :     if (m_method.size() <= 1)
     420             :     {
     421          36 :         while (m_method.size() < m_kernel.size())
     422             :         {
     423          13 :             m_method.push_back(
     424          13 :                 m_method.empty()
     425          43 :                     ? ((m_kernel[0] == "u" || m_kernel[0] == "v" ||
     426          15 :                         m_kernel[0] == "edge1" || m_kernel[0] == "edge2")
     427             :                            ? "sum"
     428             :                            : "mean")
     429           2 :                     : m_method.back());
     430             :         }
     431             :     }
     432             : 
     433          23 :     if (m_size == 0 && m_kernel[0][0] != '[')
     434          20 :         m_size = m_kernel[0] == "unsharp-masking" ? 5 : 3;
     435             : 
     436          46 :     std::vector<KernelDef> aKernelDefs;
     437          23 :     size_t i = 0;
     438          48 :     for (std::string &kernel : m_kernel)
     439             :     {
     440          25 :         KernelDef def;
     441          25 :         if (kernel == "edge1" || kernel == "edge2" || kernel == "sharpen")
     442             :         {
     443           3 :             CPLAssert(m_size == 3);
     444           3 :             def = GetKernelDef(kernel, false, 1.0);
     445             :         }
     446          22 :         else if (kernel == "u" || kernel == "v")
     447             :         {
     448           4 :             CPLAssert(m_size == 3);
     449           4 :             def = GetKernelDef(kernel, false, 0.5);
     450             :         }
     451          18 :         else if (kernel == "equal")
     452             :         {
     453          12 :             def.size = m_size;
     454             :             const double dfWeight =
     455          12 :                 m_method[i] == "mean"
     456          13 :                     ? 1.0 / (static_cast<double>(m_size) * m_size +
     457           1 :                              std::numeric_limits<double>::min())
     458          12 :                     : 1.0;
     459          12 :             def.adfCoefficients.resize(static_cast<size_t>(m_size) * m_size,
     460             :                                        dfWeight);
     461             :         }
     462           6 :         else if (kernel == "gaussian")
     463             :         {
     464           2 :             CPLAssert(m_size == 3 || m_size == 5);
     465           4 :             def = GetKernelDef(m_size == 3 ? "gaussian-3x3" : "gaussian-5x5",
     466           2 :                                true, 0.0);
     467             :         }
     468           4 :         else if (kernel == "unsharp-masking")
     469             :         {
     470           1 :             CPLAssert(m_size == 5);
     471           1 :             def = GetKernelDef("unsharp-masking-5x5", true, 0.0);
     472             :         }
     473             :         else
     474             :         {
     475           3 :             CPLAssert(kernel.front() == '[');
     476             :             const CPLStringList aosValues(
     477             :                 CSLTokenizeString2(kernel.c_str(), "[] ,",
     478           6 :                                    CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     479           3 :             const double dfSize = static_cast<double>(aosValues.size());
     480           3 :             def.size = static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
     481          30 :             for (const char *pszC : cpl::Iterate(aosValues))
     482             :             {
     483             :                 // Already validated to be numeric by the validation action
     484          27 :                 def.adfCoefficients.push_back(CPLAtof(pszC));
     485             :             }
     486             :         }
     487          25 :         aKernelDefs.push_back(std::move(def));
     488             : 
     489          25 :         ++i;
     490             :     }
     491             : 
     492          23 :     auto vrt = GDALNeighborsCreateVRTDerived(poSrcDS, m_band, eType, m_nodata,
     493          23 :                                              m_method, aKernelDefs);
     494          23 :     const bool ret = vrt != nullptr;
     495          23 :     if (vrt)
     496             :     {
     497          22 :         m_outputDataset.Set(std::move(vrt));
     498             :     }
     499          46 :     return ret;
     500             : }
     501             : 
     502             : GDALRasterNeighborsAlgorithmStandalone::
     503             :     ~GDALRasterNeighborsAlgorithmStandalone() = default;
     504             : 
     505             : //! @endcond

Generated by: LCOV version 1.14