Line data Source code
1 : /******************************************************************************
2 : *
3 : * Project: GDAL
4 : * Purpose: "neighbors" step of "raster pipeline"
5 : * Author: Even Rouault <even dot rouault at spatialys.com>
6 : *
7 : ******************************************************************************
8 : * Copyright (c) 2025, Even Rouault <even dot rouault at spatialys.com>
9 : *
10 : * SPDX-License-Identifier: MIT
11 : ****************************************************************************/
12 :
13 : #include "gdalalg_raster_neighbors.h"
14 :
15 : #include "gdal_priv.h"
16 : #include "gdal_priv_templates.hpp"
17 : #include "vrtdataset.h"
18 :
19 : #include <algorithm>
20 : #include <limits>
21 : #include <map>
22 : #include <optional>
23 : #include <set>
24 : #include <utility>
25 : #include <vector>
26 :
27 : //! @cond Doxygen_Suppress
28 :
29 : #ifndef _
30 : #define _(x) (x)
31 : #endif
32 :
33 : static const std::set<std::string> oSetKernelNames = {
34 : "u", "v", "equal", "edge1",
35 : "edge2", "sharpen", "gaussian", "unsharp-masking"};
36 :
37 : namespace
38 : {
39 : struct KernelDef
40 : {
41 : int size = 0;
42 : std::vector<double> adfCoefficients{};
43 : };
44 : } // namespace
45 :
46 : // clang-format off
47 : // Cf https://en.wikipedia.org/wiki/Kernel_(image_processing)
48 : static const std::map<std::string, std::pair<int, std::vector<int>>> oMapKernelNameToMatrix = {
49 : { "u", { 3, { 0, 0, 0,
50 : -1, 0, 1,
51 : 0, 0, 0 } } },
52 : { "v", { 3, { 0, -1, 0,
53 : 0, 0, 0,
54 : 0, 1, 0 } } },
55 : { "edge1", { 3, { 0, -1, 0,
56 : -1, 4, -1,
57 : 0, -1, 0 } } },
58 : { "edge2", { 3, { -1, -1, -1,
59 : -1, 8, -1,
60 : -1, -1, -1 } } },
61 : { "sharpen", { 3, { 0, -1, 0,
62 : -1, 5, -1,
63 : 0, -1, 0 } } },
64 : { "gaussian-3x3", { 3, { 1, 2, 1,
65 : 2, 4, 2,
66 : 1, 2, 1 } } },
67 : { "gaussian-5x5", { 5, { 1, 4, 6, 4, 1,
68 : 4, 16, 24, 16, 4,
69 : 6, 24, 36, 24, 6,
70 : 4, 16, 24, 16, 4,
71 : 1, 4, 6, 4, 1, } } },
72 : { "unsharp-masking-5x5", { 5, { 1, 4, 6, 4, 1,
73 : 4, 16, 24, 16, 4,
74 : 6, 24, -476, 24, 6,
75 : 4, 16, 24, 16, 4,
76 : 1, 4, 6, 4, 1, } } },
77 : };
78 :
79 : // clang-format on
80 :
81 : /************************************************************************/
82 : /* CreateDerivedBandXML() */
83 : /************************************************************************/
84 :
85 28 : static bool CreateDerivedBandXML(VRTDataset *poVRTDS, GDALRasterBand *poSrcBand,
86 : GDALDataType eType, const std::string &noData,
87 : const std::string &method,
88 : const KernelDef &kernelDef)
89 : {
90 28 : poVRTDS->AddBand(eType, nullptr);
91 :
92 28 : std::optional<double> dstNoData;
93 28 : bool autoSelectNoDataValue = false;
94 28 : if (noData.empty())
95 : {
96 25 : autoSelectNoDataValue = true;
97 : }
98 3 : else if (noData != "none")
99 : {
100 : // Already validated to be numeric by the validation action
101 2 : dstNoData = CPLAtof(noData.c_str());
102 : }
103 :
104 28 : auto poVRTBand = cpl::down_cast<VRTSourcedRasterBand *>(
105 : poVRTDS->GetRasterBand(poVRTDS->GetRasterCount()));
106 :
107 56 : auto poSource = std::make_unique<VRTKernelFilteredSource>();
108 28 : poSrcBand->GetDataset()->Reference();
109 28 : poSource->SetSrcBand(poSrcBand);
110 28 : poSource->SetKernel(kernelDef.size, /* separable = */ false,
111 28 : kernelDef.adfCoefficients);
112 28 : poSource->SetNormalized(method != "sum");
113 28 : if (method != "sum" && method != "mean")
114 10 : poSource->SetFunction(method.c_str());
115 :
116 28 : int bSrcHasNoData = false;
117 28 : const double dfNoDataValue = poSrcBand->GetNoDataValue(&bSrcHasNoData);
118 28 : if (bSrcHasNoData)
119 : {
120 4 : poSource->SetNoDataValue(dfNoDataValue);
121 4 : if (autoSelectNoDataValue && !dstNoData.has_value())
122 : {
123 2 : dstNoData = dfNoDataValue;
124 : }
125 : }
126 :
127 28 : if (dstNoData.has_value())
128 : {
129 4 : if (!GDALIsValueExactAs(dstNoData.value(), eType))
130 : {
131 1 : CPLError(CE_Failure, CPLE_AppDefined,
132 : "Band output type %s cannot represent NoData value %g",
133 1 : GDALGetDataTypeName(eType), dstNoData.value());
134 1 : return false;
135 : }
136 :
137 3 : poVRTBand->SetNoDataValue(dstNoData.value());
138 : }
139 :
140 27 : poVRTBand->AddSource(std::move(poSource));
141 :
142 27 : return true;
143 : }
144 :
145 : /************************************************************************/
146 : /* GDALNeighborsCreateVRTDerived() */
147 : /************************************************************************/
148 :
149 : static std::unique_ptr<GDALDataset>
150 24 : GDALNeighborsCreateVRTDerived(GDALDataset *poSrcDS, int nBand,
151 : GDALDataType eType, const std::string &noData,
152 : const std::vector<std::string> &methods,
153 : const std::vector<KernelDef> &aKernelDefs)
154 : {
155 24 : CPLAssert(methods.size() == aKernelDefs.size());
156 :
157 0 : auto ds = std::make_unique<VRTDataset>(poSrcDS->GetRasterXSize(),
158 48 : poSrcDS->GetRasterYSize());
159 24 : GDALGeoTransform gt;
160 24 : if (poSrcDS->GetGeoTransform(gt) == CE_None)
161 11 : ds->SetGeoTransform(gt);
162 24 : if (const OGRSpatialReference *poSRS = poSrcDS->GetSpatialRef())
163 : {
164 11 : ds->SetSpatialRef(poSRS);
165 : }
166 :
167 24 : bool ret = true;
168 24 : if (nBand != 0)
169 : {
170 0 : for (size_t i = 0; i < aKernelDefs.size() && ret; ++i)
171 : {
172 : ret =
173 0 : CreateDerivedBandXML(ds.get(), poSrcDS->GetRasterBand(nBand),
174 0 : eType, noData, methods[i], aKernelDefs[i]);
175 : }
176 : }
177 : else
178 : {
179 50 : for (int iBand = 1; iBand <= poSrcDS->GetRasterCount(); ++iBand)
180 : {
181 54 : for (size_t i = 0; i < aKernelDefs.size() && ret; ++i)
182 : {
183 28 : ret = CreateDerivedBandXML(ds.get(),
184 : poSrcDS->GetRasterBand(iBand), eType,
185 28 : noData, methods[i], aKernelDefs[i]);
186 : }
187 : }
188 : }
189 24 : if (!ret)
190 1 : ds.reset();
191 48 : return ds;
192 : }
193 :
194 : /************************************************************************/
195 : /* GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm() */
196 : /************************************************************************/
197 :
198 76 : GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm(
199 76 : bool standaloneStep) noexcept
200 : : GDALRasterPipelineStepAlgorithm(
201 : NAME, DESCRIPTION, HELP_URL,
202 76 : ConstructorOptions().SetStandaloneStep(standaloneStep))
203 : {
204 76 : AddBandArg(&m_band);
205 :
206 152 : AddArg("method", 0, _("Method to combine weighed source pixels"), &m_method)
207 76 : .SetChoices("mean", "sum", "min", "max", "stddev", "median", "mode");
208 :
209 152 : AddArg("size", 0, _("Neighborhood size"), &m_size)
210 76 : .SetMinValueIncluded(3)
211 76 : .SetMaxValueIncluded(99)
212 : .AddValidationAction(
213 12 : [this]()
214 : {
215 11 : if ((m_size % 2) != 1)
216 : {
217 1 : ReportError(CE_Failure, CPLE_IllegalArg,
218 : "The value of 'size' must be an odd number.");
219 1 : return false;
220 : }
221 10 : return true;
222 76 : });
223 :
224 152 : AddArg("kernel", 0, _("Convolution kernel(s) to apply"), &m_kernel)
225 76 : .SetPackedValuesAllowed(false)
226 76 : .SetMinCount(1)
227 76 : .SetMinCharCount(1)
228 76 : .SetRequired()
229 : .SetAutoCompleteFunction(
230 2 : [](const std::string &v)
231 : {
232 2 : std::vector<std::string> ret;
233 2 : if (v.empty() || v.front() != '[')
234 : {
235 1 : ret.insert(ret.end(), oSetKernelNames.begin(),
236 2 : oSetKernelNames.end());
237 1 : ret.push_back(
238 : "[[val00,val10,...,valN0],...,[val0N,val1N,...valNN]]");
239 : }
240 2 : return ret;
241 152 : })
242 : .AddValidationAction(
243 66 : [this]()
244 : {
245 124 : for (const std::string &kernel : m_kernel)
246 : {
247 66 : if (kernel.front() == '[' && kernel.back() == ']')
248 : {
249 : const CPLStringList aosValues(CSLTokenizeString2(
250 : kernel.c_str(), "[] ,",
251 11 : CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
252 : const double dfSize =
253 11 : static_cast<double>(aosValues.size());
254 11 : const int nSqrt = static_cast<int>(sqrt(dfSize) + 0.5);
255 20 : if (!((aosValues.size() % 2) == 1 &&
256 9 : nSqrt * nSqrt == aosValues.size()))
257 : {
258 2 : ReportError(
259 : CE_Failure, CPLE_IllegalArg,
260 : "The number of values in the 'kernel' "
261 : "argument must be an odd square number.");
262 2 : return false;
263 : }
264 85 : for (int i = 0; i < aosValues.size(); ++i)
265 : {
266 77 : if (CPLGetValueType(aosValues[i]) ==
267 : CPL_VALUE_STRING)
268 : {
269 1 : ReportError(CE_Failure, CPLE_IllegalArg,
270 : "Non-numeric value found in the "
271 : "'kernel' argument");
272 1 : return false;
273 : }
274 : }
275 : }
276 55 : else if (!cpl::contains(oSetKernelNames, kernel))
277 : {
278 : std::string osMsg =
279 1 : "Valid values for 'kernel' argument are: ";
280 1 : bool bFirst = true;
281 9 : for (const auto &name : oSetKernelNames)
282 : {
283 8 : if (!bFirst)
284 7 : osMsg += ", ";
285 8 : bFirst = false;
286 8 : osMsg += '\'';
287 8 : osMsg += name;
288 8 : osMsg += '\'';
289 : }
290 : osMsg += " or "
291 : "[[val00,val10,...,valN0],...,[val0N,val1N,..."
292 1 : "valNN]]";
293 1 : ReportError(CE_Failure, CPLE_IllegalArg, "%s",
294 : osMsg.c_str());
295 1 : return false;
296 : }
297 : }
298 58 : return true;
299 76 : });
300 :
301 76 : AddOutputDataTypeArg(&m_type).SetDefault("Float64");
302 :
303 76 : AddNodataArg(&m_nodata, true);
304 :
305 76 : AddValidationAction(
306 76 : [this]()
307 : {
308 29 : if (m_method.size() > 1 && m_method.size() != m_kernel.size())
309 : {
310 1 : ReportError(
311 : CE_Failure, CPLE_AppDefined,
312 : "The number of values for the 'method' argument should "
313 : "be one or exactly the number of values of 'kernel'");
314 1 : return false;
315 : }
316 :
317 28 : if (m_size > 0)
318 : {
319 6 : for (const std::string &kernel : m_kernel)
320 : {
321 5 : if (kernel == "gaussian")
322 : {
323 2 : if (m_size != 3 && m_size != 5)
324 : {
325 1 : ReportError(CE_Failure, CPLE_AppDefined,
326 : "Currently only size = 3 or 5 is "
327 : "supported for kernel '%s'",
328 : kernel.c_str());
329 4 : return false;
330 : }
331 : }
332 3 : else if (kernel == "unsharp-masking")
333 : {
334 1 : if (m_size != 5)
335 : {
336 1 : ReportError(CE_Failure, CPLE_AppDefined,
337 : "Currently only size = 5 is supported "
338 : "for kernel '%s'",
339 : kernel.c_str());
340 1 : return false;
341 : }
342 : }
343 2 : else if (kernel[0] == '[')
344 : {
345 : const CPLStringList aosValues(CSLTokenizeString2(
346 : kernel.c_str(), "[] ,",
347 1 : CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
348 : const double dfSize =
349 1 : static_cast<double>(aosValues.size());
350 : const int size =
351 1 : static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
352 1 : if (m_size != size)
353 : {
354 1 : ReportError(CE_Failure, CPLE_AppDefined,
355 : "Value of 'size' argument (%d) "
356 : "inconsistent with the one deduced "
357 : "from the kernel matrix (%d)",
358 : m_size, size);
359 1 : return false;
360 : }
361 : }
362 2 : else if (m_size != 3 && kernel != "equal" &&
363 1 : kernel[0] != '[')
364 : {
365 1 : ReportError(CE_Failure, CPLE_AppDefined,
366 : "Currently only size = 3 is supported for "
367 : "kernel '%s'",
368 : kernel.c_str());
369 1 : return false;
370 : }
371 : }
372 : }
373 :
374 24 : return true;
375 : });
376 76 : }
377 :
378 : /************************************************************************/
379 : /* GetKernelDef() */
380 : /************************************************************************/
381 :
382 11 : static KernelDef GetKernelDef(const std::string &name, bool normalizeCoefs,
383 : double weightIfNotNormalized)
384 : {
385 11 : auto it = oMapKernelNameToMatrix.find(name);
386 11 : CPLAssert(it != oMapKernelNameToMatrix.end());
387 11 : KernelDef def;
388 11 : def.size = it->second.first;
389 11 : int nSum = 0;
390 142 : for (const int nVal : it->second.second)
391 131 : nSum += nVal;
392 : const double dfWeight = normalizeCoefs
393 14 : ? 1.0 / (static_cast<double>(nSum) +
394 3 : std::numeric_limits<double>::min())
395 11 : : weightIfNotNormalized;
396 142 : for (const int nVal : it->second.second)
397 : {
398 131 : def.adfCoefficients.push_back(nVal * dfWeight);
399 : }
400 22 : return def;
401 : }
402 :
403 : /************************************************************************/
404 : /* GDALRasterNeighborsAlgorithm::RunStep() */
405 : /************************************************************************/
406 :
407 24 : bool GDALRasterNeighborsAlgorithm::RunStep(GDALPipelineStepRunContext &)
408 : {
409 24 : auto poSrcDS = m_inputDataset[0].GetDatasetRef();
410 24 : CPLAssert(!m_outputDataset.GetDatasetRef());
411 :
412 24 : CPLAssert(m_band <= poSrcDS->GetRasterCount());
413 :
414 24 : auto eType = GDALGetDataTypeByName(m_type.c_str());
415 24 : if (eType == GDT_Unknown)
416 : {
417 0 : eType = GDT_Float64;
418 : }
419 :
420 24 : if (m_method.size() <= 1)
421 : {
422 38 : while (m_method.size() < m_kernel.size())
423 : {
424 14 : m_method.push_back(
425 14 : m_method.empty()
426 47 : ? ((m_kernel[0] == "u" || m_kernel[0] == "v" ||
427 17 : m_kernel[0] == "edge1" || m_kernel[0] == "edge2")
428 : ? "sum"
429 : : "mean")
430 2 : : m_method.back());
431 : }
432 : }
433 :
434 24 : if (m_size == 0 && m_kernel[0][0] != '[')
435 21 : m_size = m_kernel[0] == "unsharp-masking" ? 5 : 3;
436 :
437 48 : std::vector<KernelDef> aKernelDefs;
438 24 : size_t i = 0;
439 50 : for (std::string &kernel : m_kernel)
440 : {
441 26 : KernelDef def;
442 26 : if (kernel == "edge1" || kernel == "edge2" || kernel == "sharpen")
443 : {
444 4 : CPLAssert(m_size == 3);
445 4 : def = GetKernelDef(kernel, false, 1.0);
446 : }
447 22 : else if (kernel == "u" || kernel == "v")
448 : {
449 4 : CPLAssert(m_size == 3);
450 4 : def = GetKernelDef(kernel, false, 0.5);
451 : }
452 18 : else if (kernel == "equal")
453 : {
454 12 : def.size = m_size;
455 : const double dfWeight =
456 12 : m_method[i] == "mean"
457 13 : ? 1.0 / (static_cast<double>(m_size) * m_size +
458 1 : std::numeric_limits<double>::min())
459 12 : : 1.0;
460 12 : def.adfCoefficients.resize(static_cast<size_t>(m_size) * m_size,
461 : dfWeight);
462 : }
463 6 : else if (kernel == "gaussian")
464 : {
465 2 : CPLAssert(m_size == 3 || m_size == 5);
466 4 : def = GetKernelDef(m_size == 3 ? "gaussian-3x3" : "gaussian-5x5",
467 2 : true, 0.0);
468 : }
469 4 : else if (kernel == "unsharp-masking")
470 : {
471 1 : CPLAssert(m_size == 5);
472 1 : def = GetKernelDef("unsharp-masking-5x5", true, 0.0);
473 : }
474 : else
475 : {
476 3 : CPLAssert(kernel.front() == '[');
477 : const CPLStringList aosValues(
478 : CSLTokenizeString2(kernel.c_str(), "[] ,",
479 6 : CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
480 3 : const double dfSize = static_cast<double>(aosValues.size());
481 3 : def.size = static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
482 30 : for (const char *pszC : cpl::Iterate(aosValues))
483 : {
484 : // Already validated to be numeric by the validation action
485 27 : def.adfCoefficients.push_back(CPLAtof(pszC));
486 : }
487 : }
488 26 : aKernelDefs.push_back(std::move(def));
489 :
490 26 : ++i;
491 : }
492 :
493 24 : auto vrt = GDALNeighborsCreateVRTDerived(poSrcDS, m_band, eType, m_nodata,
494 24 : m_method, aKernelDefs);
495 24 : const bool ret = vrt != nullptr;
496 24 : if (vrt)
497 : {
498 23 : m_outputDataset.Set(std::move(vrt));
499 : }
500 48 : return ret;
501 : }
502 :
503 : GDALRasterNeighborsAlgorithmStandalone::
504 : ~GDALRasterNeighborsAlgorithmStandalone() = default;
505 :
506 : //! @endcond
|