Line data Source code
1 : /******************************************************************************
2 : *
3 : * Project: GDAL
4 : * Purpose: "neighbors" step of "raster pipeline"
5 : * Author: Even Rouault <even dot rouault at spatialys.com>
6 : *
7 : ******************************************************************************
8 : * Copyright (c) 2025, Even Rouault <even dot rouault at spatialys.com>
9 : *
10 : * SPDX-License-Identifier: MIT
11 : ****************************************************************************/
12 :
13 : #include "gdalalg_raster_neighbors.h"
14 :
15 : #include "gdal_priv.h"
16 : #include "gdal_priv_templates.hpp"
17 : #include "vrtdataset.h"
18 :
19 : #include <algorithm>
20 : #include <limits>
21 : #include <map>
22 : #include <optional>
23 : #include <set>
24 : #include <utility>
25 : #include <vector>
26 :
27 : //! @cond Doxygen_Suppress
28 :
29 : #ifndef _
30 : #define _(x) (x)
31 : #endif
32 :
33 : static const std::set<std::string> oSetKernelNames = {
34 : "u", "v", "equal", "edge1",
35 : "edge2", "sharpen", "gaussian", "unsharp-masking"};
36 :
37 : namespace
38 : {
39 : struct KernelDef
40 : {
41 : int size = 0;
42 : std::vector<double> adfCoefficients{};
43 : };
44 : } // namespace
45 :
46 : // clang-format off
47 : // Cf https://en.wikipedia.org/wiki/Kernel_(image_processing)
48 : static const std::map<std::string, std::pair<int, std::vector<int>>> oMapKernelNameToMatrix = {
49 : { "u", { 3, { 0, 0, 0,
50 : -1, 0, 1,
51 : 0, 0, 0 } } },
52 : { "v", { 3, { 0, -1, 0,
53 : 0, 0, 0,
54 : 0, 1, 0 } } },
55 : { "edge1", { 3, { 0, -1, 0,
56 : -1, 4, -1,
57 : 0, -1, 0 } } },
58 : { "edge2", { 3, { -1, -1, -1,
59 : -1, 8, -1,
60 : -1, -1, -1 } } },
61 : { "sharpen", { 3, { 0, -1, 0,
62 : -1, 5, -1,
63 : 0, -1, 0 } } },
64 : { "gaussian-3x3", { 3, { 1, 2, 1,
65 : 2, 4, 2,
66 : 1, 2, 1 } } },
67 : { "gaussian-5x5", { 5, { 1, 4, 6, 4, 1,
68 : 4, 16, 24, 16, 4,
69 : 6, 24, 36, 24, 6,
70 : 4, 16, 24, 16, 4,
71 : 1, 4, 6, 4, 1, } } },
72 : { "unsharp-masking-5x5", { 5, { 1, 4, 6, 4, 1,
73 : 4, 16, 24, 16, 4,
74 : 6, 24, -476, 24, 6,
75 : 4, 16, 24, 16, 4,
76 : 1, 4, 6, 4, 1, } } },
77 : };
78 :
79 : // clang-format on
80 :
81 : /************************************************************************/
82 : /* CreateDerivedBandXML() */
83 : /************************************************************************/
84 :
85 25 : static bool CreateDerivedBandXML(VRTDataset *poVRTDS, GDALRasterBand *poSrcBand,
86 : GDALDataType eType, const std::string &noData,
87 : const std::string &method,
88 : const KernelDef &kernelDef)
89 : {
90 25 : poVRTDS->AddBand(eType, nullptr);
91 :
92 25 : std::optional<double> dstNoData;
93 25 : bool autoSelectNoDataValue = false;
94 25 : if (noData.empty())
95 : {
96 22 : autoSelectNoDataValue = true;
97 : }
98 3 : else if (noData != "none")
99 : {
100 : // Already validated to be numeric by the validation action
101 2 : dstNoData = CPLAtof(noData.c_str());
102 : }
103 :
104 25 : auto poVRTBand = cpl::down_cast<VRTSourcedRasterBand *>(
105 : poVRTDS->GetRasterBand(poVRTDS->GetRasterCount()));
106 :
107 50 : auto poSource = std::make_unique<VRTKernelFilteredSource>();
108 25 : poSrcBand->GetDataset()->Reference();
109 25 : poSource->SetSrcBand(poSrcBand);
110 25 : poSource->SetKernel(kernelDef.size, /* separable = */ false,
111 25 : kernelDef.adfCoefficients);
112 25 : poSource->SetNormalized(method != "sum");
113 25 : if (method != "sum" && method != "mean")
114 10 : poSource->SetFunction(method.c_str());
115 :
116 25 : int bSrcHasNoData = false;
117 25 : const double dfNoDataValue = poSrcBand->GetNoDataValue(&bSrcHasNoData);
118 25 : if (bSrcHasNoData)
119 : {
120 4 : poSource->SetNoDataValue(dfNoDataValue);
121 4 : if (autoSelectNoDataValue && !dstNoData.has_value())
122 : {
123 2 : dstNoData = dfNoDataValue;
124 : }
125 : }
126 :
127 25 : if (dstNoData.has_value())
128 : {
129 4 : if (!GDALIsValueExactAs(dstNoData.value(), eType))
130 : {
131 1 : CPLError(CE_Failure, CPLE_AppDefined,
132 : "Band output type %s cannot represent NoData value %g",
133 1 : GDALGetDataTypeName(eType), dstNoData.value());
134 1 : return false;
135 : }
136 :
137 3 : poVRTBand->SetNoDataValue(dstNoData.value());
138 : }
139 :
140 24 : poVRTBand->AddSource(std::move(poSource));
141 :
142 24 : return true;
143 : }
144 :
145 : /************************************************************************/
146 : /* GDALNeighborsCreateVRTDerived() */
147 : /************************************************************************/
148 :
149 : static std::unique_ptr<GDALDataset>
150 23 : GDALNeighborsCreateVRTDerived(GDALDataset *poSrcDS, int nBand,
151 : GDALDataType eType, const std::string &noData,
152 : const std::vector<std::string> &methods,
153 : const std::vector<KernelDef> &aKernelDefs)
154 : {
155 23 : CPLAssert(methods.size() == aKernelDefs.size());
156 :
157 0 : auto ds = std::make_unique<VRTDataset>(poSrcDS->GetRasterXSize(),
158 46 : poSrcDS->GetRasterYSize());
159 23 : GDALGeoTransform gt;
160 23 : if (poSrcDS->GetGeoTransform(gt) == CE_None)
161 10 : ds->SetGeoTransform(gt);
162 23 : if (const OGRSpatialReference *poSRS = poSrcDS->GetSpatialRef())
163 : {
164 10 : ds->SetSpatialRef(poSRS);
165 : }
166 :
167 23 : bool ret = true;
168 48 : for (size_t i = 0; i < aKernelDefs.size() && ret; ++i)
169 : {
170 25 : ret = CreateDerivedBandXML(ds.get(), poSrcDS->GetRasterBand(nBand),
171 25 : eType, noData, methods[i], aKernelDefs[i]);
172 : }
173 23 : if (!ret)
174 1 : ds.reset();
175 46 : return ds;
176 : }
177 :
178 : /************************************************************************/
179 : /* GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm() */
180 : /************************************************************************/
181 :
182 76 : GDALRasterNeighborsAlgorithm::GDALRasterNeighborsAlgorithm(
183 76 : bool standaloneStep) noexcept
184 : : GDALRasterPipelineStepAlgorithm(
185 : NAME, DESCRIPTION, HELP_URL,
186 76 : ConstructorOptions().SetStandaloneStep(standaloneStep))
187 : {
188 76 : AddBandArg(&m_band);
189 :
190 152 : AddArg("method", 0, _("Method to combine weighed source pixels"), &m_method)
191 76 : .SetChoices("mean", "sum", "min", "max", "stddev", "median", "mode");
192 :
193 152 : AddArg("size", 0, _("Neighborhood size"), &m_size)
194 76 : .SetMinValueIncluded(3)
195 76 : .SetMaxValueIncluded(99)
196 : .AddValidationAction(
197 12 : [this]()
198 : {
199 11 : if ((m_size % 2) != 1)
200 : {
201 1 : ReportError(CE_Failure, CPLE_IllegalArg,
202 : "The value of 'size' must be an odd number.");
203 1 : return false;
204 : }
205 10 : return true;
206 76 : });
207 :
208 152 : AddArg("kernel", 0, _("Convolution kernel(s) to apply"), &m_kernel)
209 76 : .SetPackedValuesAllowed(false)
210 76 : .SetMinCount(1)
211 76 : .SetMinCharCount(1)
212 76 : .SetRequired()
213 : .SetAutoCompleteFunction(
214 2 : [](const std::string &v)
215 : {
216 2 : std::vector<std::string> ret;
217 2 : if (v.empty() || v.front() != '[')
218 : {
219 1 : ret.insert(ret.end(), oSetKernelNames.begin(),
220 2 : oSetKernelNames.end());
221 1 : ret.push_back(
222 : "[[val00,val10,...,valN0],...,[val0N,val1N,...valNN]]");
223 : }
224 2 : return ret;
225 152 : })
226 : .AddValidationAction(
227 66 : [this]()
228 : {
229 124 : for (const std::string &kernel : m_kernel)
230 : {
231 66 : if (kernel.front() == '[' && kernel.back() == ']')
232 : {
233 : const CPLStringList aosValues(CSLTokenizeString2(
234 : kernel.c_str(), "[] ,",
235 11 : CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
236 : const double dfSize =
237 11 : static_cast<double>(aosValues.size());
238 11 : const int nSqrt = static_cast<int>(sqrt(dfSize) + 0.5);
239 20 : if (!((aosValues.size() % 2) == 1 &&
240 9 : nSqrt * nSqrt == aosValues.size()))
241 : {
242 2 : ReportError(
243 : CE_Failure, CPLE_IllegalArg,
244 : "The number of values in the 'kernel' "
245 : "argument must be an odd square number.");
246 2 : return false;
247 : }
248 85 : for (int i = 0; i < aosValues.size(); ++i)
249 : {
250 77 : if (CPLGetValueType(aosValues[i]) ==
251 : CPL_VALUE_STRING)
252 : {
253 1 : ReportError(CE_Failure, CPLE_IllegalArg,
254 : "Non-numeric value found in the "
255 : "'kernel' argument");
256 1 : return false;
257 : }
258 : }
259 : }
260 55 : else if (!cpl::contains(oSetKernelNames, kernel))
261 : {
262 : std::string osMsg =
263 1 : "Valid values for 'kernel' argument are: ";
264 1 : bool bFirst = true;
265 9 : for (const auto &name : oSetKernelNames)
266 : {
267 8 : if (!bFirst)
268 7 : osMsg += ", ";
269 8 : bFirst = false;
270 8 : osMsg += '\'';
271 8 : osMsg += name;
272 8 : osMsg += '\'';
273 : }
274 : osMsg += " or "
275 : "[[val00,val10,...,valN0],...,[val0N,val1N,..."
276 1 : "valNN]]";
277 1 : ReportError(CE_Failure, CPLE_IllegalArg, "%s",
278 : osMsg.c_str());
279 1 : return false;
280 : }
281 : }
282 58 : return true;
283 76 : });
284 :
285 76 : AddOutputDataTypeArg(&m_type).SetDefault("Float64");
286 :
287 76 : AddNodataArg(&m_nodata, true);
288 :
289 76 : AddValidationAction(
290 156 : [this]()
291 : {
292 29 : if (m_method.size() > 1 && m_method.size() != m_kernel.size())
293 : {
294 1 : ReportError(
295 : CE_Failure, CPLE_AppDefined,
296 : "The number of values for the 'method' argument should "
297 : "be one or exactly the number of values of 'kernel'");
298 1 : return false;
299 : }
300 :
301 28 : if (m_band == 0 && !m_inputDataset.empty())
302 : {
303 24 : auto poDS = m_inputDataset[0].GetDatasetRef();
304 24 : if (poDS && poDS->GetRasterCount() > 1)
305 : {
306 1 : ReportError(
307 : CE_Failure, CPLE_AppDefined,
308 : "'band' argument should be specified given input "
309 : "dataset has several bands.");
310 1 : return false;
311 : }
312 : }
313 :
314 27 : if (m_size > 0)
315 : {
316 6 : for (const std::string &kernel : m_kernel)
317 : {
318 5 : if (kernel == "gaussian")
319 : {
320 2 : if (m_size != 3 && m_size != 5)
321 : {
322 1 : ReportError(CE_Failure, CPLE_AppDefined,
323 : "Currently only size = 3 or 5 is "
324 : "supported for kernel '%s'",
325 : kernel.c_str());
326 4 : return false;
327 : }
328 : }
329 3 : else if (kernel == "unsharp-masking")
330 : {
331 1 : if (m_size != 5)
332 : {
333 1 : ReportError(CE_Failure, CPLE_AppDefined,
334 : "Currently only size = 5 is supported "
335 : "for kernel '%s'",
336 : kernel.c_str());
337 1 : return false;
338 : }
339 : }
340 2 : else if (kernel[0] == '[')
341 : {
342 : const CPLStringList aosValues(CSLTokenizeString2(
343 : kernel.c_str(), "[] ,",
344 1 : CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
345 : const double dfSize =
346 1 : static_cast<double>(aosValues.size());
347 : const int size =
348 1 : static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
349 1 : if (m_size != size)
350 : {
351 1 : ReportError(CE_Failure, CPLE_AppDefined,
352 : "Value of 'size' argument (%d) "
353 : "inconsistent with the one deduced "
354 : "from the kernel matrix (%d)",
355 : m_size, size);
356 1 : return false;
357 : }
358 : }
359 2 : else if (m_size != 3 && kernel != "equal" &&
360 1 : kernel[0] != '[')
361 : {
362 1 : ReportError(CE_Failure, CPLE_AppDefined,
363 : "Currently only size = 3 is supported for "
364 : "kernel '%s'",
365 : kernel.c_str());
366 1 : return false;
367 : }
368 : }
369 : }
370 :
371 23 : return true;
372 : });
373 76 : }
374 :
375 : /************************************************************************/
376 : /* GetKernelDef() */
377 : /************************************************************************/
378 :
379 10 : static KernelDef GetKernelDef(const std::string &name, bool normalizeCoefs,
380 : double weightIfNotNormalized)
381 : {
382 10 : auto it = oMapKernelNameToMatrix.find(name);
383 10 : CPLAssert(it != oMapKernelNameToMatrix.end());
384 10 : KernelDef def;
385 10 : def.size = it->second.first;
386 10 : int nSum = 0;
387 132 : for (const int nVal : it->second.second)
388 122 : nSum += nVal;
389 : const double dfWeight = normalizeCoefs
390 13 : ? 1.0 / (static_cast<double>(nSum) +
391 3 : std::numeric_limits<double>::min())
392 10 : : weightIfNotNormalized;
393 132 : for (const int nVal : it->second.second)
394 : {
395 122 : def.adfCoefficients.push_back(nVal * dfWeight);
396 : }
397 20 : return def;
398 : }
399 :
400 : /************************************************************************/
401 : /* GDALRasterNeighborsAlgorithm::RunStep() */
402 : /************************************************************************/
403 :
404 23 : bool GDALRasterNeighborsAlgorithm::RunStep(GDALPipelineStepRunContext &)
405 : {
406 23 : auto poSrcDS = m_inputDataset[0].GetDatasetRef();
407 23 : CPLAssert(!m_outputDataset.GetDatasetRef());
408 :
409 23 : if (m_band == 0)
410 23 : m_band = 1;
411 23 : CPLAssert(m_band <= poSrcDS->GetRasterCount());
412 :
413 23 : auto eType = GDALGetDataTypeByName(m_type.c_str());
414 23 : if (eType == GDT_Unknown)
415 : {
416 0 : eType = GDT_Float64;
417 : }
418 :
419 23 : if (m_method.size() <= 1)
420 : {
421 36 : while (m_method.size() < m_kernel.size())
422 : {
423 13 : m_method.push_back(
424 13 : m_method.empty()
425 43 : ? ((m_kernel[0] == "u" || m_kernel[0] == "v" ||
426 15 : m_kernel[0] == "edge1" || m_kernel[0] == "edge2")
427 : ? "sum"
428 : : "mean")
429 2 : : m_method.back());
430 : }
431 : }
432 :
433 23 : if (m_size == 0 && m_kernel[0][0] != '[')
434 20 : m_size = m_kernel[0] == "unsharp-masking" ? 5 : 3;
435 :
436 46 : std::vector<KernelDef> aKernelDefs;
437 23 : size_t i = 0;
438 48 : for (std::string &kernel : m_kernel)
439 : {
440 25 : KernelDef def;
441 25 : if (kernel == "edge1" || kernel == "edge2" || kernel == "sharpen")
442 : {
443 3 : CPLAssert(m_size == 3);
444 3 : def = GetKernelDef(kernel, false, 1.0);
445 : }
446 22 : else if (kernel == "u" || kernel == "v")
447 : {
448 4 : CPLAssert(m_size == 3);
449 4 : def = GetKernelDef(kernel, false, 0.5);
450 : }
451 18 : else if (kernel == "equal")
452 : {
453 12 : def.size = m_size;
454 : const double dfWeight =
455 12 : m_method[i] == "mean"
456 13 : ? 1.0 / (static_cast<double>(m_size) * m_size +
457 1 : std::numeric_limits<double>::min())
458 12 : : 1.0;
459 12 : def.adfCoefficients.resize(static_cast<size_t>(m_size) * m_size,
460 : dfWeight);
461 : }
462 6 : else if (kernel == "gaussian")
463 : {
464 2 : CPLAssert(m_size == 3 || m_size == 5);
465 4 : def = GetKernelDef(m_size == 3 ? "gaussian-3x3" : "gaussian-5x5",
466 2 : true, 0.0);
467 : }
468 4 : else if (kernel == "unsharp-masking")
469 : {
470 1 : CPLAssert(m_size == 5);
471 1 : def = GetKernelDef("unsharp-masking-5x5", true, 0.0);
472 : }
473 : else
474 : {
475 3 : CPLAssert(kernel.front() == '[');
476 : const CPLStringList aosValues(
477 : CSLTokenizeString2(kernel.c_str(), "[] ,",
478 6 : CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
479 3 : const double dfSize = static_cast<double>(aosValues.size());
480 3 : def.size = static_cast<int>(std::floor(sqrt(dfSize) + 0.5));
481 30 : for (const char *pszC : cpl::Iterate(aosValues))
482 : {
483 : // Already validated to be numeric by the validation action
484 27 : def.adfCoefficients.push_back(CPLAtof(pszC));
485 : }
486 : }
487 25 : aKernelDefs.push_back(std::move(def));
488 :
489 25 : ++i;
490 : }
491 :
492 23 : auto vrt = GDALNeighborsCreateVRTDerived(poSrcDS, m_band, eType, m_nodata,
493 23 : m_method, aKernelDefs);
494 23 : const bool ret = vrt != nullptr;
495 23 : if (vrt)
496 : {
497 22 : m_outputDataset.Set(std::move(vrt));
498 : }
499 46 : return ret;
500 : }
501 :
502 : GDALRasterNeighborsAlgorithmStandalone::
503 : ~GDALRasterNeighborsAlgorithmStandalone() = default;
504 :
505 : //! @endcond
|