LCOV - code coverage report
Current view: top level - apps - gdalalg_raster_calc.cpp (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 402 438 91.8 %
Date: 2025-06-19 12:30:01 Functions: 14 14 100.0 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  GDAL
       4             :  * Purpose:  "gdal raster calc" subcommand
       5             :  * Author:   Daniel Baston
       6             :  *
       7             :  ******************************************************************************
       8             :  * Copyright (c) 2025, ISciences LLC
       9             :  *
      10             :  * SPDX-License-Identifier: MIT
      11             :  ****************************************************************************/
      12             : 
      13             : #include "gdalalg_raster_calc.h"
      14             : 
      15             : #include "../frmts/vrt/gdal_vrt.h"
      16             : #include "../frmts/vrt/vrtdataset.h"
      17             : 
      18             : #include "cpl_float.h"
      19             : #include "cpl_vsi_virtual.h"
      20             : #include "gdal_priv.h"
      21             : #include "gdal_utils.h"
      22             : #include "vrtdataset.h"
      23             : 
      24             : #include <algorithm>
      25             : #include <array>
      26             : #include <optional>
      27             : 
      28             : //! @cond Doxygen_Suppress
      29             : 
      30             : #ifndef _
      31             : #define _(x) (x)
      32             : #endif
      33             : 
      34             : struct GDALCalcOptions
      35             : {
      36             :     GDALDataType dstType{GDT_Unknown};
      37             :     bool checkSRS{true};
      38             :     bool checkExtent{true};
      39             : };
      40             : 
      41         175 : static bool MatchIsCompleteVariableNameWithNoIndex(const std::string &str,
      42             :                                                    size_t from, size_t to)
      43             : {
      44         175 :     if (to < str.size())
      45             :     {
      46             :         // If the character after the end of the match is:
      47             :         // * alphanumeric or _ : we've matched only part of a variable name
      48             :         // * [ : we've matched a variable that already has an index
      49             :         // * ( : we've matched a function name
      50         220 :         if (std::isalnum(str[to]) || str[to] == '_' || str[to] == '[' ||
      51          80 :             str[to] == '(')
      52             :         {
      53          61 :             return false;
      54             :         }
      55             :     }
      56         114 :     if (from > 0)
      57             :     {
      58             :         // If the character before the start of the match is alphanumeric or _,
      59             :         // we've matched only part of a variable name.
      60          69 :         if (std::isalnum(str[from - 1]) || str[from - 1] == '_')
      61             :         {
      62           3 :             return false;
      63             :         }
      64             :     }
      65             : 
      66         111 :     return true;
      67             : }
      68             : 
      69             : /**
      70             :  *  Add a band subscript to all instances of a specified variable that
      71             :  *  do not already have such a subscript. For example, "X" would be
      72             :  *  replaced with "X[3]" but "X[1]" would be left untouched.
      73             :  */
      74         115 : static std::string SetBandIndices(const std::string &origExpression,
      75             :                                   const std::string &variable, int band,
      76             :                                   bool &expressionChanged)
      77             : {
      78         115 :     std::string expression = origExpression;
      79         115 :     expressionChanged = false;
      80             : 
      81         115 :     std::string::size_type seekPos = 0;
      82         115 :     auto pos = expression.find(variable, seekPos);
      83         270 :     while (pos != std::string::npos)
      84             :     {
      85         155 :         auto end = pos + variable.size();
      86             : 
      87         155 :         if (MatchIsCompleteVariableNameWithNoIndex(expression, pos, end))
      88             :         {
      89             :             // No index specified for variable
      90         182 :             expression = expression.substr(0, pos + variable.size()) + '[' +
      91         273 :                          std::to_string(band) + ']' + expression.substr(end);
      92          91 :             expressionChanged = true;
      93             :         }
      94             : 
      95         155 :         seekPos = end;
      96         155 :         pos = expression.find(variable, seekPos);
      97             :     }
      98             : 
      99         115 :     return expression;
     100             : }
     101             : 
     102          40 : static bool PosIsFunctionArgument(const std::string &expression, size_t pos)
     103             : {
     104             :     // If this position is a function argument, we should be able to
     105             :     // scan backwards for a ( and find only variable names, literals or commas.
     106          40 :     while (pos != 0)
     107             :     {
     108          32 :         char c = expression[pos];
     109          32 :         if (c == '(')
     110             :         {
     111           8 :             pos--;
     112           8 :             break;
     113             :         }
     114          24 :         if (!(isspace(c) || isalnum(c) || c == ',' || c == '.' || c == '[' ||
     115             :               c == ']' || c == '_'))
     116             :         {
     117           4 :             return false;
     118             :         }
     119          20 :         pos--;
     120             :     }
     121             : 
     122             :     // Now what we've found the (, the preceding character should be part of a
     123             :     // value function name
     124          16 :     while (pos != 0)
     125             :     {
     126           8 :         char c = expression[pos];
     127           8 :         if (isalnum(c) || c == '_')
     128             :         {
     129           8 :             return true;
     130             :         }
     131           0 :         if (!isspace(c))
     132             :         {
     133           0 :             return false;
     134             :         }
     135           0 :         pos--;
     136             :     }
     137             : 
     138           8 :     return false;
     139             : }
     140             : 
     141             : /**
     142             :  *  Replace X by X[1],X[2],...X[n]
     143             :  */
     144             : static std::string
     145          16 : SetBandIndicesFlattenedExpression(const std::string &origExpression,
     146             :                                   const std::string &variable, int nBands)
     147             : {
     148          16 :     std::string expression = origExpression;
     149             : 
     150          16 :     std::string::size_type seekPos = 0;
     151          16 :     auto pos = expression.find(variable, seekPos);
     152          36 :     while (pos != std::string::npos)
     153             :     {
     154          20 :         auto end = pos + variable.size();
     155             : 
     156          40 :         if (MatchIsCompleteVariableNameWithNoIndex(expression, pos, end) &&
     157          20 :             PosIsFunctionArgument(expression, pos))
     158             :         {
     159           8 :             std::string newExpr = expression.substr(0, pos);
     160          24 :             for (int i = 1; i <= nBands; ++i)
     161             :             {
     162          16 :                 if (i > 1)
     163           8 :                     newExpr += ',';
     164          16 :                 newExpr += variable;
     165          16 :                 newExpr += '[';
     166          16 :                 newExpr += std::to_string(i);
     167          16 :                 newExpr += ']';
     168             :             }
     169           8 :             const size_t oldExprSize = expression.size();
     170           8 :             newExpr += expression.substr(end);
     171           8 :             expression = std::move(newExpr);
     172           8 :             end += expression.size() - oldExprSize;
     173             :         }
     174             : 
     175          20 :         seekPos = end;
     176          20 :         pos = expression.find(variable, seekPos);
     177             :     }
     178             : 
     179          16 :     return expression;
     180             : }
     181             : 
     182             : struct SourceProperties
     183             : {
     184             :     int nBands{0};
     185             :     int nX{0};
     186             :     int nY{0};
     187             :     std::array<double, 6> gt{};
     188             :     std::unique_ptr<OGRSpatialReference, OGRSpatialReferenceReleaser> srs{
     189             :         nullptr};
     190             :     GDALDataType eDT{GDT_Unknown};
     191             : };
     192             : 
     193             : static std::optional<SourceProperties>
     194         117 : UpdateSourceProperties(SourceProperties &out, const std::string &dsn,
     195             :                        const GDALCalcOptions &options)
     196             : {
     197         234 :     SourceProperties source;
     198         117 :     bool srsMismatch = false;
     199         117 :     bool extentMismatch = false;
     200         117 :     bool dimensionMismatch = false;
     201             : 
     202             :     {
     203             :         std::unique_ptr<GDALDataset> ds(
     204         117 :             GDALDataset::Open(dsn.c_str(), GDAL_OF_RASTER));
     205             : 
     206         117 :         if (!ds)
     207             :         {
     208           0 :             CPLError(CE_Failure, CPLE_AppDefined, "Failed to open %s",
     209             :                      dsn.c_str());
     210           0 :             return std::nullopt;
     211             :         }
     212             : 
     213         117 :         source.nBands = ds->GetRasterCount();
     214         117 :         source.nX = ds->GetRasterXSize();
     215         117 :         source.nY = ds->GetRasterYSize();
     216             : 
     217         117 :         if (options.checkExtent)
     218             :         {
     219         113 :             ds->GetGeoTransform(source.gt.data());
     220             :         }
     221             : 
     222         117 :         if (options.checkSRS && out.srs)
     223             :         {
     224          54 :             const OGRSpatialReference *srs = ds->GetSpatialRef();
     225          54 :             srsMismatch = srs && !srs->IsSame(out.srs.get());
     226             :         }
     227             : 
     228             :         // Store the source data type if it is the same for all bands in the source
     229         312 :         for (int i = 0; i < source.nBands; ++i)
     230             :         {
     231         195 :             if (i == 0)
     232             :             {
     233         117 :                 source.eDT = ds->GetRasterBand(1)->GetRasterDataType();
     234             :             }
     235         156 :             else if (source.eDT !=
     236          78 :                      ds->GetRasterBand(i + 1)->GetRasterDataType())
     237             :             {
     238           0 :                 source.eDT = GDT_Unknown;
     239           0 :                 break;
     240             :             }
     241             :         }
     242             :     }
     243             : 
     244         117 :     if (source.nX != out.nX || source.nY != out.nY)
     245             :     {
     246           3 :         dimensionMismatch = true;
     247             :     }
     248             : 
     249         234 :     if (source.gt[0] != out.gt[0] || source.gt[2] != out.gt[2] ||
     250         234 :         source.gt[3] != out.gt[3] || source.gt[4] != out.gt[4])
     251             :     {
     252           5 :         extentMismatch = true;
     253             :     }
     254         117 :     if (source.gt[1] != out.gt[1] || source.gt[5] != out.gt[5])
     255             :     {
     256             :         // Resolutions are different. Are the extents the same?
     257           8 :         double xmaxOut = out.gt[0] + out.nX * out.gt[1] + out.nY * out.gt[2];
     258           8 :         double yminOut = out.gt[3] + out.nX * out.gt[4] + out.nY * out.gt[5];
     259             : 
     260             :         double xmax =
     261           8 :             source.gt[0] + source.nX * source.gt[1] + source.nY * source.gt[2];
     262             :         double ymin =
     263           8 :             source.gt[3] + source.nX * source.gt[4] + source.nY * source.gt[5];
     264             : 
     265             :         // Max allowable extent misalignment, expressed as fraction of a pixel
     266           8 :         constexpr double EXTENT_RTOL = 1e-3;
     267             : 
     268          11 :         if (std::abs(xmax - xmaxOut) > EXTENT_RTOL * std::abs(source.gt[1]) ||
     269           3 :             std::abs(ymin - yminOut) > EXTENT_RTOL * std::abs(source.gt[5]))
     270             :         {
     271           5 :             extentMismatch = true;
     272             :         }
     273             :     }
     274             : 
     275         117 :     if (options.checkExtent && extentMismatch)
     276             :     {
     277           1 :         CPLError(CE_Failure, CPLE_AppDefined,
     278             :                  "Input extents are inconsistent.");
     279           1 :         return std::nullopt;
     280             :     }
     281             : 
     282         116 :     if (!options.checkExtent && dimensionMismatch)
     283             :     {
     284           1 :         CPLError(CE_Failure, CPLE_AppDefined,
     285             :                  "Inputs do not have the same dimensions.");
     286           1 :         return std::nullopt;
     287             :     }
     288             : 
     289             :     // Find a common resolution
     290         115 :     if (source.nX > out.nX)
     291             :     {
     292           1 :         auto dx = CPLGreatestCommonDivisor(out.gt[1], source.gt[1]);
     293           1 :         if (dx == 0)
     294             :         {
     295           0 :             CPLError(CE_Failure, CPLE_AppDefined,
     296             :                      "Failed to find common resolution for inputs.");
     297           0 :             return std::nullopt;
     298             :         }
     299           1 :         out.nX = static_cast<int>(
     300           1 :             std::round(static_cast<double>(out.nX) * out.gt[1] / dx));
     301           1 :         out.gt[1] = dx;
     302             :     }
     303         115 :     if (source.nY > out.nY)
     304             :     {
     305           1 :         auto dy = CPLGreatestCommonDivisor(out.gt[5], source.gt[5]);
     306           1 :         if (dy == 0)
     307             :         {
     308           0 :             CPLError(CE_Failure, CPLE_AppDefined,
     309             :                      "Failed to find common resolution for inputs.");
     310           0 :             return std::nullopt;
     311             :         }
     312           1 :         out.nY = static_cast<int>(
     313           1 :             std::round(static_cast<double>(out.nY) * out.gt[5] / dy));
     314           1 :         out.gt[5] = dy;
     315             :     }
     316             : 
     317         115 :     if (srsMismatch)
     318             :     {
     319           1 :         CPLError(CE_Failure, CPLE_AppDefined,
     320             :                  "Input spatial reference systems are inconsistent.");
     321           1 :         return std::nullopt;
     322             :     }
     323             : 
     324         114 :     return source;
     325             : }
     326             : 
     327             : /** Create XML nodes for one or more derived bands resulting from the evaluation
     328             :  *  of a single expression
     329             :  *
     330             :  * @param root VRTDataset node to which the band nodes should be added
     331             :  * @param bandType the type of the band(s) to create
     332             :  * @param nXOut Number of columns in VRT dataset
     333             :  * @param nYOut Number of rows in VRT dataset
     334             :  * @param expression Expression for which band(s) should be added
     335             :  * @param dialect Expression dialect
     336             :  * @param flatten Generate a single band output raster per expression, even if
     337             :  *                input datasets are multiband.
     338             :  * @param pixelFunctionArguments Pixel function arguments.
     339             :  * @param sources Mapping of source names to DSNs
     340             :  * @param sourceProps Mapping of source names to properties
     341             :  * @param fakeSourceFilename If not empty, used instead of real input filenames.
     342             :  * @return true if the band(s) were added, false otherwise
     343             :  */
     344             : static bool
     345          91 : CreateDerivedBandXML(CPLXMLNode *root, int nXOut, int nYOut,
     346             :                      GDALDataType bandType, const std::string &expression,
     347             :                      const std::string &dialect, bool flatten,
     348             :                      const std::vector<std::string> &pixelFunctionArguments,
     349             :                      const std::map<std::string, std::string> &sources,
     350             :                      const std::map<std::string, SourceProperties> &sourceProps,
     351             :                      const std::string &fakeSourceFilename)
     352             : {
     353          91 :     int nOutBands = 1;  // By default, each expression produces a single output
     354             :                         // band. When processing the expression below, we may
     355             :                         // discover that the expression produces multiple bands,
     356             :                         // in which case this will be updated.
     357             : 
     358         204 :     for (int nOutBand = 1; nOutBand <= nOutBands; nOutBand++)
     359             :     {
     360             :         // Copy the expression for each output band, because we may modify it
     361             :         // when adding band indices (e.g., X -> X[1]) to the variables in the
     362             :         // expression.
     363         116 :         std::string bandExpression = expression;
     364             : 
     365         116 :         CPLXMLNode *band = CPLCreateXMLNode(root, CXT_Element, "VRTRasterBand");
     366         116 :         CPLAddXMLAttributeAndValue(band, "subClass", "VRTDerivedRasterBand");
     367         116 :         const char *pszDataType = nullptr;
     368         116 :         if (bandType != GDT_Unknown)
     369             :         {
     370           7 :             pszDataType = GDALGetDataTypeName(bandType);
     371             :         }
     372         116 :         CPLAddXMLAttributeAndValue(band, "dataType",
     373             :                                    pszDataType ? pszDataType : "Float64");
     374             : 
     375         262 :         for (const auto &[source_name, dsn] : sources)
     376             :         {
     377         149 :             auto it = sourceProps.find(source_name);
     378         149 :             CPLAssert(it != sourceProps.end());
     379         149 :             const auto &props = it->second;
     380             : 
     381         149 :             bool expressionAppliedPerBand = false;
     382         149 :             if (dialect == "builtin")
     383             :             {
     384          34 :                 expressionAppliedPerBand = !flatten;
     385             :             }
     386             :             else
     387             :             {
     388         115 :                 const int nDefaultInBand = std::min(props.nBands, nOutBand);
     389             : 
     390         115 :                 if (flatten)
     391             :                 {
     392          16 :                     bandExpression = SetBandIndicesFlattenedExpression(
     393          16 :                         bandExpression, source_name, props.nBands);
     394             :                 }
     395             : 
     396             :                 bandExpression =
     397         230 :                     SetBandIndices(bandExpression, source_name, nDefaultInBand,
     398         115 :                                    expressionAppliedPerBand);
     399             :             }
     400             : 
     401         149 :             if (expressionAppliedPerBand)
     402             :             {
     403         107 :                 if (nOutBands <= 1)
     404             :                 {
     405          68 :                     nOutBands = props.nBands;
     406             :                 }
     407          39 :                 else if (props.nBands != 1 && props.nBands != nOutBands)
     408             :                 {
     409           3 :                     CPLError(CE_Failure, CPLE_AppDefined,
     410             :                              "Expression cannot operate on all bands of "
     411             :                              "rasters with incompatible numbers of bands "
     412             :                              "(source %s has %d bands but expected to have "
     413             :                              "1 or %d bands).",
     414           3 :                              source_name.c_str(), props.nBands, nOutBands);
     415           3 :                     return false;
     416             :                 }
     417             :             }
     418             : 
     419             :             // Create a <SimpleSource> for each input band that is used in
     420             :             // the expression.
     421         411 :             for (int nInBand = 1; nInBand <= props.nBands; nInBand++)
     422             :             {
     423         265 :                 CPLString inBandVariable;
     424         265 :                 if (dialect == "builtin")
     425             :                 {
     426          64 :                     if (!flatten && props.nBands >= 2 && nInBand != nOutBand)
     427          11 :                         continue;
     428             :                 }
     429             :                 else
     430             :                 {
     431             :                     inBandVariable.Printf("%s[%d]", source_name.c_str(),
     432         201 :                                           nInBand);
     433         201 :                     if (bandExpression.find(inBandVariable) ==
     434             :                         std::string::npos)
     435             :                     {
     436          75 :                         continue;
     437             :                     }
     438             :                 }
     439             : 
     440             :                 CPLXMLNode *source =
     441         179 :                     CPLCreateXMLNode(band, CXT_Element, "SimpleSource");
     442         179 :                 if (!inBandVariable.empty())
     443             :                 {
     444         126 :                     CPLAddXMLAttributeAndValue(source, "name",
     445             :                                                inBandVariable.c_str());
     446             :                 }
     447             : 
     448             :                 CPLXMLNode *sourceFilename =
     449         179 :                     CPLCreateXMLNode(source, CXT_Element, "SourceFilename");
     450         179 :                 if (fakeSourceFilename.empty())
     451             :                 {
     452         132 :                     CPLAddXMLAttributeAndValue(sourceFilename, "relativeToVRT",
     453             :                                                "0");
     454         132 :                     CPLCreateXMLNode(sourceFilename, CXT_Text, dsn.c_str());
     455             :                 }
     456             :                 else
     457             :                 {
     458          47 :                     CPLCreateXMLNode(sourceFilename, CXT_Text,
     459             :                                      fakeSourceFilename.c_str());
     460             :                 }
     461             : 
     462             :                 CPLXMLNode *sourceBand =
     463         179 :                     CPLCreateXMLNode(source, CXT_Element, "SourceBand");
     464         179 :                 CPLCreateXMLNode(sourceBand, CXT_Text,
     465         358 :                                  std::to_string(nInBand).c_str());
     466             : 
     467         179 :                 if (fakeSourceFilename.empty())
     468             :                 {
     469             :                     CPLXMLNode *srcRect =
     470         132 :                         CPLCreateXMLNode(source, CXT_Element, "SrcRect");
     471         132 :                     CPLAddXMLAttributeAndValue(srcRect, "xOff", "0");
     472         132 :                     CPLAddXMLAttributeAndValue(srcRect, "yOff", "0");
     473         132 :                     CPLAddXMLAttributeAndValue(
     474         264 :                         srcRect, "xSize", std::to_string(props.nX).c_str());
     475         132 :                     CPLAddXMLAttributeAndValue(
     476         264 :                         srcRect, "ySize", std::to_string(props.nY).c_str());
     477             : 
     478             :                     CPLXMLNode *dstRect =
     479         132 :                         CPLCreateXMLNode(source, CXT_Element, "DstRect");
     480         132 :                     CPLAddXMLAttributeAndValue(dstRect, "xOff", "0");
     481         132 :                     CPLAddXMLAttributeAndValue(dstRect, "yOff", "0");
     482         132 :                     CPLAddXMLAttributeAndValue(dstRect, "xSize",
     483         264 :                                                std::to_string(nXOut).c_str());
     484         132 :                     CPLAddXMLAttributeAndValue(dstRect, "ySize",
     485         264 :                                                std::to_string(nYOut).c_str());
     486             :                 }
     487             :             }
     488             :         }
     489             : 
     490             :         CPLXMLNode *pixelFunctionType =
     491         113 :             CPLCreateXMLNode(band, CXT_Element, "PixelFunctionType");
     492         113 :         if (dialect == "builtin")
     493             :         {
     494          24 :             CPLCreateXMLNode(pixelFunctionType, CXT_Text, expression.c_str());
     495          24 :             if (!pixelFunctionArguments.empty())
     496             :             {
     497           4 :                 CPLXMLNode *arguments = CPLCreateXMLNode(
     498             :                     band, CXT_Element, "PixelFunctionArguments");
     499           8 :                 const CPLStringList args(pixelFunctionArguments);
     500           8 :                 for (const auto &[key, value] : cpl::IterateNameValue(args))
     501             :                 {
     502           4 :                     CPLAddXMLAttributeAndValue(arguments, key, value);
     503             :                 }
     504             :             }
     505             :         }
     506             :         else
     507             :         {
     508          89 :             CPLCreateXMLNode(pixelFunctionType, CXT_Text, "expression");
     509             :             CPLXMLNode *arguments =
     510          89 :                 CPLCreateXMLNode(band, CXT_Element, "PixelFunctionArguments");
     511             : 
     512             :             // Add the expression as a last step, because we may modify the
     513             :             // expression as we iterate through the bands.
     514          89 :             CPLAddXMLAttributeAndValue(arguments, "expression",
     515             :                                        bandExpression.c_str());
     516          89 :             CPLAddXMLAttributeAndValue(arguments, "dialect", "muparser");
     517             :         }
     518             :     }
     519             : 
     520          88 :     return true;
     521             : }
     522             : 
     523          96 : static bool ParseSourceDescriptors(const std::vector<std::string> &inputs,
     524             :                                    std::map<std::string, std::string> &datasets,
     525             :                                    std::string &firstSourceName,
     526             :                                    bool requireSourceNames)
     527             : {
     528         216 :     for (size_t iInput = 0; iInput < inputs.size(); iInput++)
     529             :     {
     530         125 :         const std::string &input = inputs[iInput];
     531         125 :         std::string name;
     532             : 
     533         125 :         const auto pos = input.find('=');
     534         125 :         if (pos == std::string::npos)
     535             :         {
     536          50 :             if (requireSourceNames && inputs.size() > 1)
     537             :             {
     538           1 :                 CPLError(CE_Failure, CPLE_AppDefined,
     539             :                          "Inputs must be named when more than one input is "
     540             :                          "provided.");
     541           1 :                 return false;
     542             :             }
     543          49 :             name = "X";
     544          49 :             if (iInput > 0)
     545             :             {
     546           2 :                 name += std::to_string(iInput);
     547             :             }
     548             :         }
     549             :         else
     550             :         {
     551          75 :             name = input.substr(0, pos);
     552             :         }
     553             : 
     554             :         // Check input name is legal
     555         269 :         for (size_t i = 0; i < name.size(); ++i)
     556             :         {
     557         148 :             const char c = name[i];
     558         148 :             if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'))
     559             :             {
     560             :                 // ok
     561             :             }
     562          20 :             else if (c == '_' || (c >= '0' && c <= '9'))
     563             :             {
     564          19 :                 if (i == 0)
     565             :                 {
     566             :                     // Reserved constants in MuParser start with an underscore
     567           2 :                     CPLError(
     568             :                         CE_Failure, CPLE_AppDefined,
     569             :                         "Name '%s' is illegal because it starts with a '%c'",
     570             :                         name.c_str(), c);
     571           2 :                     return false;
     572             :                 }
     573             :             }
     574             :             else
     575             :             {
     576           1 :                 CPLError(CE_Failure, CPLE_AppDefined,
     577             :                          "Name '%s' is illegal because character '%c' is not "
     578             :                          "allowed",
     579             :                          name.c_str(), c);
     580           1 :                 return false;
     581             :             }
     582             :         }
     583             : 
     584             :         std::string dsn =
     585         121 :             (pos == std::string::npos) ? input : input.substr(pos + 1);
     586         121 :         if (datasets.find(name) != datasets.end())
     587             :         {
     588           1 :             CPLError(CE_Failure, CPLE_AppDefined,
     589             :                      "An input with name '%s' has already been provided",
     590             :                      name.c_str());
     591           1 :             return false;
     592             :         }
     593         120 :         datasets[name] = std::move(dsn);
     594             : 
     595         120 :         if (iInput == 0)
     596             :         {
     597          92 :             firstSourceName = std::move(name);
     598             :         }
     599             :     }
     600             : 
     601          91 :     return true;
     602             : }
     603             : 
     604          74 : static bool ReadFileLists(const std::vector<GDALArgDatasetValue> &inputDS,
     605             :                           std::vector<std::string> &inputFilenames)
     606             : {
     607         172 :     for (const auto &dsVal : inputDS)
     608             :     {
     609          98 :         const auto &input = dsVal.GetName();
     610          98 :         if (!input.empty() && input[0] == '@')
     611             :         {
     612             :             auto f =
     613           2 :                 VSIVirtualHandleUniquePtr(VSIFOpenL(input.c_str() + 1, "r"));
     614           2 :             if (!f)
     615             :             {
     616           0 :                 CPLError(CE_Failure, CPLE_FileIO, "Cannot open %s",
     617           0 :                          input.c_str() + 1);
     618           0 :                 return false;
     619             :             }
     620           6 :             while (const char *filename = CPLReadLineL(f.get()))
     621             :             {
     622           4 :                 inputFilenames.push_back(filename);
     623           4 :             }
     624             :         }
     625             :         else
     626             :         {
     627          96 :             inputFilenames.push_back(input);
     628             :         }
     629             :     }
     630             : 
     631          74 :     return true;
     632             : }
     633             : 
     634             : /** Creates a VRT datasource with one or more derived raster bands containing
     635             :  *  results of an expression.
     636             :  *
     637             :  * To make this work with muparser (which does not support vector types), we
     638             :  * do a simple parsing of the expression internally, transforming it into
     639             :  * multiple expressions with explicit band indices. For example, for a two-band
     640             :  * raster "X", the expression "X + 3" will be transformed into "X[1] + 3" and
     641             :  * "X[2] + 3". The use of brackets is for readability only; as far as the
     642             :  * expression engine is concerned, the variables "X[1]" and "X[2]" have nothing
     643             :  * to do with each other.
     644             :  *
     645             :  * @param inputs A list of sources, expressed as NAME=DSN
     646             :  * @param expressions A list of expressions to be evaluated
     647             :  * @param dialect Expression dialect
     648             :  * @param flatten Generate a single band output raster per expression, even if
     649             :  *                input datasets are multiband.
     650             :  * @param pixelFunctionArguments Pixel function arguments.
     651             :  * @param options flags controlling which checks should be performed on the inputs
     652             :  * @param[out] maxSourceBands Maximum number of bands in source dataset(s)
     653             :  * @param fakeSourceFilename If not empty, used instead of real input filenames.
     654             :  *
     655             :  * @return a newly created VRTDataset, or nullptr on error
     656             :  */
     657          96 : static std::unique_ptr<GDALDataset> GDALCalcCreateVRTDerived(
     658             :     const std::vector<std::string> &inputs,
     659             :     const std::vector<std::string> &expressions, const std::string &dialect,
     660             :     bool flatten,
     661             :     const std::vector<std::vector<std::string>> &pixelFunctionArguments,
     662             :     const GDALCalcOptions &options, int &maxSourceBands,
     663             :     const std::string &fakeSourceFilename = std::string())
     664             : {
     665          96 :     if (inputs.empty())
     666             :     {
     667           0 :         return nullptr;
     668             :     }
     669             : 
     670         192 :     std::map<std::string, std::string> sources;
     671         192 :     std::string firstSource;
     672          96 :     bool requireSourceNames = dialect != "builtin";
     673          96 :     if (!ParseSourceDescriptors(inputs, sources, firstSource,
     674             :                                 requireSourceNames))
     675             :     {
     676           5 :         return nullptr;
     677             :     }
     678             : 
     679             :     // Use the first source provided to determine properties of the output
     680          91 :     const char *firstDSN = sources[firstSource].c_str();
     681             : 
     682          91 :     maxSourceBands = 0;
     683             : 
     684             :     // Read properties from the first source
     685         182 :     SourceProperties out;
     686             :     {
     687             :         std::unique_ptr<GDALDataset> ds(
     688          91 :             GDALDataset::Open(firstDSN, GDAL_OF_RASTER));
     689             : 
     690          91 :         if (!ds)
     691             :         {
     692           0 :             CPLError(CE_Failure, CPLE_AppDefined, "Failed to open %s",
     693             :                      firstDSN);
     694           0 :             return nullptr;
     695             :         }
     696             : 
     697          91 :         out.nX = ds->GetRasterXSize();
     698          91 :         out.nY = ds->GetRasterYSize();
     699          91 :         out.nBands = 1;
     700          91 :         out.srs.reset(ds->GetSpatialRef() ? ds->GetSpatialRef()->Clone()
     701             :                                           : nullptr);
     702          91 :         ds->GetGeoTransform(out.gt.data());
     703             :     }
     704             : 
     705         182 :     CPLXMLTreeCloser root(CPLCreateXMLNode(nullptr, CXT_Element, "VRTDataset"));
     706             : 
     707          91 :     maxSourceBands = 0;
     708             : 
     709             :     // Collect properties of the different sources, and verity them for
     710             :     // consistency.
     711         182 :     std::map<std::string, SourceProperties> sourceProps;
     712         205 :     for (const auto &[source_name, dsn] : sources)
     713             :     {
     714             :         // TODO avoid opening the first source twice.
     715         117 :         auto props = UpdateSourceProperties(out, dsn, options);
     716         117 :         if (props.has_value())
     717             :         {
     718         114 :             maxSourceBands = std::max(maxSourceBands, props->nBands);
     719         114 :             sourceProps[source_name] = std::move(props.value());
     720             :         }
     721             :         else
     722             :         {
     723           3 :             return nullptr;
     724             :         }
     725             :     }
     726             : 
     727          88 :     size_t iExpr = 0;
     728         176 :     for (const auto &origExpression : expressions)
     729             :     {
     730          91 :         GDALDataType bandType = options.dstType;
     731             : 
     732             :         // If output band type has not been specified, set it equal to the
     733             :         // input band type for certain pixel functions, if the inputs have
     734             :         // a consistent band type.
     735         133 :         if (bandType == GDT_Unknown && dialect == "builtin" &&
     736          60 :             (origExpression == "min" || origExpression == "max" ||
     737          18 :              origExpression == "mode"))
     738             :         {
     739          12 :             for (const auto &[_, props] : sourceProps)
     740             :             {
     741           6 :                 if (bandType == GDT_Unknown)
     742             :                 {
     743           6 :                     bandType = props.eDT;
     744             :                 }
     745           0 :                 else if (props.eDT == GDT_Unknown || props.eDT != bandType)
     746             :                 {
     747           0 :                     bandType = GDT_Unknown;
     748           0 :                     break;
     749             :                 }
     750             :             }
     751             :         }
     752             : 
     753          91 :         if (!CreateDerivedBandXML(root.get(), out.nX, out.nY, bandType,
     754             :                                   origExpression, dialect, flatten,
     755          91 :                                   pixelFunctionArguments[iExpr], sources,
     756             :                                   sourceProps, fakeSourceFilename))
     757             :         {
     758           3 :             return nullptr;
     759             :         }
     760          88 :         ++iExpr;
     761             :     }
     762             : 
     763             :     //CPLDebug("VRT", "%s", CPLSerializeXMLTree(root.get()));
     764             : 
     765          85 :     auto ds = fakeSourceFilename.empty()
     766             :                   ? std::make_unique<VRTDataset>(out.nX, out.nY)
     767         170 :                   : std::make_unique<VRTDataset>(1, 1);
     768          85 :     if (ds->XMLInit(root.get(), "") != CE_None)
     769             :     {
     770           0 :         return nullptr;
     771             :     };
     772          85 :     ds->SetGeoTransform(out.gt.data());
     773          85 :     if (out.srs)
     774             :     {
     775          52 :         ds->SetSpatialRef(out.srs.get());
     776             :     }
     777             : 
     778          85 :     return ds;
     779             : }
     780             : 
     781             : /************************************************************************/
     782             : /*          GDALRasterCalcAlgorithm::GDALRasterCalcAlgorithm()          */
     783             : /************************************************************************/
     784             : 
     785          93 : GDALRasterCalcAlgorithm::GDALRasterCalcAlgorithm(bool standaloneStep) noexcept
     786             :     : GDALRasterPipelineStepAlgorithm(NAME, DESCRIPTION, HELP_URL,
     787         279 :                                       ConstructorOptions()
     788          93 :                                           .SetStandaloneStep(standaloneStep)
     789          93 :                                           .SetAddDefaultArguments(false)
     790          93 :                                           .SetAutoOpenInputDatasets(false)
     791         186 :                                           .SetInputDatasetMetaVar("INPUTS")
     792         279 :                                           .SetInputDatasetMaxCount(INT_MAX))
     793             : {
     794          93 :     AddRasterInputArgs(false, false);
     795          93 :     if (standaloneStep)
     796             :     {
     797          72 :         AddProgressArg();
     798          72 :         AddRasterOutputArgs(false);
     799             :     }
     800             : 
     801          93 :     AddOutputDataTypeArg(&m_type);
     802             : 
     803             :     AddArg("no-check-srs", 0,
     804             :            _("Do not check consistency of input spatial reference systems"),
     805          93 :            &m_NoCheckSRS);
     806             :     AddArg("no-check-extent", 0, _("Do not check consistency of input extents"),
     807          93 :            &m_NoCheckExtent);
     808             : 
     809         186 :     AddArg("calc", 0, _("Expression(s) to evaluate"), &m_expr)
     810          93 :         .SetRequired()
     811          93 :         .SetPackedValuesAllowed(false)
     812          93 :         .SetMinCount(1)
     813             :         .SetAutoCompleteFunction(
     814           4 :             [this](const std::string &currentValue)
     815             :             {
     816           4 :                 std::vector<std::string> ret;
     817           2 :                 if (m_dialect == "builtin")
     818             :                 {
     819           1 :                     if (currentValue.find('(') == std::string::npos)
     820           1 :                         return VRTDerivedRasterBand::GetPixelFunctionNames();
     821             :                 }
     822           1 :                 return ret;
     823          93 :             });
     824             : 
     825         186 :     AddArg("dialect", 0, _("Expression dialect"), &m_dialect)
     826          93 :         .SetDefault(m_dialect)
     827          93 :         .SetChoices("muparser", "builtin");
     828             : 
     829             :     AddArg("flatten", 0,
     830             :            _("Generate a single band output raster per expression, even if "
     831             :              "input datasets are multiband"),
     832          93 :            &m_flatten);
     833             : 
     834             :     // This is a hidden option only used by test_gdalalg_raster_calc_expression_rewriting()
     835             :     // for now
     836             :     AddArg("no-check-expression", 0,
     837             :            _("Whether to skip expression validity checks for virtual format "
     838             :              "output"),
     839         186 :            &m_noCheckExpression)
     840          93 :         .SetHidden();
     841             : 
     842          93 :     AddValidationAction(
     843         146 :         [this]()
     844             :         {
     845          78 :             GDALPipelineStepRunContext ctxt;
     846          78 :             return m_noCheckExpression || !IsGDALGOutput() || RunStep(ctxt);
     847             :         });
     848          93 : }
     849             : 
     850             : /************************************************************************/
     851             : /*                  GDALRasterCalcAlgorithm::RunImpl()                  */
     852             : /************************************************************************/
     853             : 
     854          69 : bool GDALRasterCalcAlgorithm::RunImpl(GDALProgressFunc pfnProgress,
     855             :                                       void *pProgressData)
     856             : {
     857          69 :     GDALPipelineStepRunContext stepCtxt;
     858          69 :     stepCtxt.m_pfnProgress = pfnProgress;
     859          69 :     stepCtxt.m_pProgressData = pProgressData;
     860          69 :     return RunPreStepPipelineValidations() && RunStep(stepCtxt);
     861             : }
     862             : 
     863             : /************************************************************************/
     864             : /*                GDALRasterCalcAlgorithm::RunStep()                    */
     865             : /************************************************************************/
     866             : 
     867          74 : bool GDALRasterCalcAlgorithm::RunStep(GDALPipelineStepRunContext &ctxt)
     868             : {
     869          74 :     CPLAssert(!m_outputDataset.GetDatasetRef());
     870             : 
     871          74 :     GDALCalcOptions options;
     872          74 :     options.checkExtent = !m_NoCheckExtent;
     873          74 :     options.checkSRS = !m_NoCheckSRS;
     874          74 :     if (!m_type.empty())
     875             :     {
     876           1 :         options.dstType = GDALGetDataTypeByName(m_type.c_str());
     877             :     }
     878             : 
     879         148 :     std::vector<std::string> inputFilenames;
     880          74 :     if (!ReadFileLists(m_inputDataset, inputFilenames))
     881             :     {
     882           0 :         return false;
     883             :     }
     884             : 
     885         148 :     std::vector<std::vector<std::string>> pixelFunctionArgs;
     886          74 :     if (m_dialect == "builtin")
     887             :     {
     888          23 :         for (std::string &expr : m_expr)
     889             :         {
     890             :             const CPLStringList aosTokens(
     891             :                 CSLTokenizeString2(expr.c_str(), "()",
     892          12 :                                    CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     893          12 :             const char *pszFunction = aosTokens[0];
     894             :             const auto *pair =
     895          12 :                 VRTDerivedRasterBand::GetPixelFunction(pszFunction);
     896          12 :             if (!pair)
     897             :             {
     898           0 :                 ReportError(CE_Failure, CPLE_NotSupported,
     899             :                             "'%s' is a unknown builtin function", pszFunction);
     900           0 :                 return false;
     901             :             }
     902          12 :             if (aosTokens.size() == 2)
     903             :             {
     904           2 :                 std::vector<std::string> validArguments;
     905           2 :                 AddOptionsSuggestions(pair->second.c_str(), 0, std::string(),
     906             :                                       validArguments);
     907           6 :                 for (std::string &s : validArguments)
     908             :                 {
     909           4 :                     if (!s.empty() && s.back() == '=')
     910           4 :                         s.pop_back();
     911             :                 }
     912             : 
     913             :                 const CPLStringList aosTokensArgs(CSLTokenizeString2(
     914             :                     aosTokens[1], ",",
     915           2 :                     CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     916           4 :                 for (const auto &[key, value] :
     917           6 :                      cpl::IterateNameValue(aosTokensArgs))
     918             :                 {
     919           2 :                     if (std::find(validArguments.begin(), validArguments.end(),
     920           2 :                                   key) == validArguments.end())
     921             :                     {
     922           0 :                         if (validArguments.empty())
     923             :                         {
     924           0 :                             ReportError(
     925             :                                 CE_Failure, CPLE_IllegalArg,
     926             :                                 "'%s' is a unrecognized argument for builtin "
     927             :                                 "function '%s'. It does not accept any "
     928             :                                 "argument",
     929             :                                 key, pszFunction);
     930             :                         }
     931             :                         else
     932             :                         {
     933           0 :                             std::string validArgumentsStr;
     934           0 :                             for (const std::string &s : validArguments)
     935             :                             {
     936           0 :                                 if (!validArgumentsStr.empty())
     937           0 :                                     validArgumentsStr += ", ";
     938           0 :                                 validArgumentsStr += '\'';
     939           0 :                                 validArgumentsStr += s;
     940           0 :                                 validArgumentsStr += '\'';
     941             :                             }
     942           0 :                             ReportError(
     943             :                                 CE_Failure, CPLE_IllegalArg,
     944             :                                 "'%s' is a unrecognized argument for builtin "
     945             :                                 "function '%s'. Only %s %s supported",
     946             :                                 key, pszFunction,
     947           0 :                                 validArguments.size() == 1 ? "is" : "are",
     948             :                                 validArgumentsStr.c_str());
     949             :                         }
     950           0 :                         return false;
     951             :                     }
     952           2 :                     CPL_IGNORE_RET_VAL(value);
     953             :                 }
     954           2 :                 pixelFunctionArgs.emplace_back(aosTokensArgs);
     955             :             }
     956             :             else
     957             :             {
     958          10 :                 pixelFunctionArgs.push_back(std::vector<std::string>());
     959             :             }
     960          12 :             expr = pszFunction;
     961             :         }
     962             :     }
     963             :     else
     964             :     {
     965          63 :         pixelFunctionArgs.resize(m_expr.size());
     966             :     }
     967             : 
     968          74 :     int maxSourceBands = 0;
     969             :     auto vrt =
     970          74 :         GDALCalcCreateVRTDerived(inputFilenames, m_expr, m_dialect, m_flatten,
     971         148 :                                  pixelFunctionArgs, options, maxSourceBands);
     972          74 :     if (vrt == nullptr)
     973             :     {
     974          11 :         return false;
     975             :     }
     976             : 
     977          63 :     if (!m_noCheckExpression)
     978             :     {
     979             :         const bool bIsVRT =
     980         126 :             m_format == "VRT" ||
     981          49 :             (m_format.empty() &&
     982          54 :              EQUAL(
     983             :                  CPLGetExtensionSafe(m_outputDataset.GetName().c_str()).c_str(),
     984          50 :                  "VRT"));
     985             :         const bool bIsGDALG =
     986         126 :             m_format == "GDALG" ||
     987          49 :             (m_format.empty() &&
     988          27 :              cpl::ends_with(m_outputDataset.GetName(), ".gdalg.json"));
     989          50 :         if (!m_standaloneStep || m_format == "stream" || bIsVRT || bIsGDALG)
     990             :         {
     991             :             // Try reading a single pixel to check formulas are valid.
     992          22 :             std::vector<GByte> dummyData(vrt->GetRasterCount());
     993             : 
     994          22 :             auto poGTIFFDrv = GetGDALDriverManager()->GetDriverByName("GTiff");
     995          22 :             std::string osTmpFilename;
     996          22 :             if (poGTIFFDrv)
     997             :             {
     998             :                 std::string osFilename =
     999          44 :                     VSIMemGenerateHiddenFilename("tmp.tif");
    1000             :                 auto poDS = std::unique_ptr<GDALDataset>(
    1001             :                     poGTIFFDrv->Create(osFilename.c_str(), 1, 1, maxSourceBands,
    1002          44 :                                        GDT_Byte, nullptr));
    1003          22 :                 if (poDS)
    1004          22 :                     osTmpFilename = std::move(osFilename);
    1005             :             }
    1006          22 :             if (!osTmpFilename.empty())
    1007             :             {
    1008             :                 auto fakeVRT = GDALCalcCreateVRTDerived(
    1009          22 :                     inputFilenames, m_expr, m_dialect, m_flatten,
    1010          22 :                     pixelFunctionArgs, options, maxSourceBands, osTmpFilename);
    1011          44 :                 if (fakeVRT &&
    1012          22 :                     fakeVRT->RasterIO(GF_Read, 0, 0, 1, 1, dummyData.data(), 1,
    1013             :                                       1, GDT_Byte, vrt->GetRasterCount(),
    1014          22 :                                       nullptr, 0, 0, 0, nullptr) != CE_None)
    1015             :                 {
    1016           5 :                     return false;
    1017             :                 }
    1018             :             }
    1019          17 :             if (bIsGDALG)
    1020             :             {
    1021           1 :                 return true;
    1022             :             }
    1023             :         }
    1024             :     }
    1025             : 
    1026          57 :     if (m_format == "stream" || !m_standaloneStep)
    1027             :     {
    1028          15 :         m_outputDataset.Set(std::move(vrt));
    1029          15 :         return true;
    1030             :     }
    1031             : 
    1032          84 :     CPLStringList translateArgs;
    1033          42 :     if (!m_format.empty())
    1034             :     {
    1035           7 :         translateArgs.AddString("-of");
    1036           7 :         translateArgs.AddString(m_format.c_str());
    1037             :     }
    1038          43 :     for (const auto &co : m_creationOptions)
    1039             :     {
    1040           1 :         translateArgs.AddString("-co");
    1041           1 :         translateArgs.AddString(co.c_str());
    1042             :     }
    1043             : 
    1044             :     GDALTranslateOptions *translateOptions =
    1045          42 :         GDALTranslateOptionsNew(translateArgs.List(), nullptr);
    1046          42 :     GDALTranslateOptionsSetProgress(translateOptions, ctxt.m_pfnProgress,
    1047             :                                     ctxt.m_pProgressData);
    1048             : 
    1049             :     auto poOutDS =
    1050             :         std::unique_ptr<GDALDataset>(GDALDataset::FromHandle(GDALTranslate(
    1051          42 :             m_outputDataset.GetName().c_str(), GDALDataset::ToHandle(vrt.get()),
    1052          84 :             translateOptions, nullptr)));
    1053          42 :     GDALTranslateOptionsFree(translateOptions);
    1054             : 
    1055          42 :     const bool bOK = poOutDS != nullptr;
    1056          42 :     m_outputDataset.Set(std::move(poOutDS));
    1057             : 
    1058          42 :     return bOK;
    1059             : }
    1060             : 
    1061             : GDALRasterCalcAlgorithmStandalone::~GDALRasterCalcAlgorithmStandalone() =
    1062             :     default;
    1063             : 
    1064             : //! @endcond

Generated by: LCOV version 1.14