LCOV - code coverage report
Current view: top level - apps - gdalalg_vector_filter.cpp (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 205 240 85.4 %
Date: 2026-06-01 10:05:33 Functions: 11 12 91.7 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  GDAL
       4             :  * Purpose:  "filter" step of "vector pipeline"
       5             :  * Author:   Even Rouault <even dot rouault at spatialys.com>
       6             :  *
       7             :  ******************************************************************************
       8             :  * Copyright (c) 2024, Even Rouault <even dot rouault at spatialys.com>
       9             :  *
      10             :  * SPDX-License-Identifier: MIT
      11             :  ****************************************************************************/
      12             : 
      13             : #include "gdalalg_vector_filter.h"
      14             : 
      15             : #include "gdal_priv.h"
      16             : #include "ogrsf_frmts.h"
      17             : #include "ogr_p.h"
      18             : 
      19             : #include <set>
      20             : 
      21             : //! @cond Doxygen_Suppress
      22             : 
      23             : #ifndef _
      24             : #define _(x) (x)
      25             : #endif
      26             : 
      27             : /************************************************************************/
      28             : /*        GDALVectorFilterAlgorithm::GDALVectorFilterAlgorithm()        */
      29             : /************************************************************************/
      30             : 
      31         114 : GDALVectorFilterAlgorithm::GDALVectorFilterAlgorithm(bool standaloneStep)
      32             :     : GDALVectorPipelineStepAlgorithm(NAME, DESCRIPTION, HELP_URL,
      33         114 :                                       standaloneStep)
      34             : {
      35         114 :     auto &layerArg = AddActiveLayerArg(&m_activeLayer);
      36         114 :     AddBBOXArg(&m_bbox);
      37             :     AddArg("where", 0,
      38             :            _("Attribute query in a restricted form of the queries used in the "
      39             :              "SQL WHERE statement"),
      40         228 :            &m_where)
      41         114 :         .SetReadFromFileAtSyntaxAllowed()
      42         228 :         .SetMetaVar("<WHERE>|@<filename>")
      43         114 :         .SetRemoveSQLCommentsEnabled()
      44             :         .SetAutoCompleteFunction(
      45          11 :             [this, &layerArg](const std::string &currentValue)
      46         125 :             { return CompleteWhere(layerArg, currentValue); });
      47             :     AddArg("update-extent", 0,
      48             :            _("Update layer extent to take into account the filter"),
      49         114 :            &m_updateExtent);
      50         114 : }
      51             : 
      52             : /************************************************************************/
      53             : /*              GDALVectorFilterAlgorithmLayerChangeExtent              */
      54             : /************************************************************************/
      55             : 
      56             : constexpr const char *const SQL_OPERATORS[] = {
      57             :     "=", "<>", "<", "<=", ">", ">=", "AND", "OR", "LIKE", "BETWEEN"};
      58             : 
      59           7 : static bool IsSQLOperator(const char *pszStr)
      60             : {
      61           7 :     return std::find_if(std::begin(SQL_OPERATORS), std::end(SQL_OPERATORS),
      62          31 :                         [pszStr](const char *pszStr2)
      63          31 :                         { return EQUAL(pszStr, pszStr2); }) !=
      64           7 :            std::end(SQL_OPERATORS);
      65             : }
      66             : 
      67          24 : static std::string GetSQLIdentifier(const std::string &name)
      68             : {
      69          24 :     if (name.find_first_of("'\" ") != std::string::npos)
      70             :     {
      71           0 :         char *pszEscaped = CPLEscapeString(name.c_str(), -1, CPLES_SQLI);
      72           0 :         std::string ret = std::string("\"").append(pszEscaped).append("\"");
      73           0 :         CPLFree(pszEscaped);
      74           0 :         return ret;
      75             :     }
      76             :     else
      77             :     {
      78          24 :         return name;
      79             :     }
      80             : }
      81             : 
      82          30 : static std::string GetSQLStringLiteral(const char *val)
      83             : {
      84          30 :     char *pszEscaped = CPLEscapeString(val, -1, CPLES_SQL);
      85          60 :     std::string ret = std::string("'").append(pszEscaped).append("'");
      86          30 :     CPLFree(pszEscaped);
      87          30 :     return ret;
      88             : }
      89             : 
      90             : std::vector<std::string>
      91          11 : GDALVectorFilterAlgorithm::CompleteWhere(const GDALAlgorithmArg &layerArg,
      92             :                                          const std::string &currentValue) const
      93             : {
      94          11 :     std::vector<std::string> ret;
      95          22 :     if (currentValue.empty() || currentValue[0] != '"' ||
      96          11 :         m_inputDataset.empty())
      97           0 :         return ret;
      98             : 
      99             :     auto poDS = std::unique_ptr<GDALDataset>(
     100          11 :         GDALDataset::Open(m_inputDataset[0].GetName().c_str(),
     101          22 :                           GDAL_OF_VECTOR | GDAL_OF_READONLY));
     102          11 :     if (!poDS)
     103           0 :         return ret;
     104             : 
     105             :     // Collect field names
     106          22 :     std::string layerName;
     107          11 :     if (layerArg.IsExplicitlySet())
     108           2 :         layerName = layerArg.Get<std::string>();
     109          22 :     std::map<std::string, std::vector<OGRLayer *>> fieldNames;
     110          11 :     if (layerName.empty())
     111             :     {
     112          18 :         for (auto *poLayer : poDS->GetLayers())
     113             :         {
     114          37 :             for (const auto *poFieldDefn : poLayer->GetLayerDefn()->GetFields())
     115             :             {
     116          28 :                 fieldNames[poFieldDefn->GetNameRef()].push_back(poLayer);
     117             :             }
     118             :         }
     119             :     }
     120           2 :     else if (auto *poLayer = poDS->GetLayerByName(layerName.c_str()))
     121             :     {
     122           5 :         for (const auto *poFieldDefn : poLayer->GetLayerDefn()->GetFields())
     123             :         {
     124           4 :             fieldNames[poFieldDefn->GetNameRef()].push_back(poLayer);
     125             :         }
     126             :     }
     127          11 :     if (fieldNames.empty())
     128           1 :         return ret;
     129             : 
     130             :     const CPLStringList aosTokens(CSLTokenizeString2(
     131          10 :         currentValue.c_str() + 1, " ",
     132             :         CSLT_HONOURSTRINGS | CSLT_HONOURSINGLEQUOTES | CSLT_PRESERVEQUOTES |
     133          20 :             CSLT_PRESERVEESCAPES | CSLT_STRIPLEADSPACES | CSLT_STRIPENDSPACES));
     134             : 
     135          20 :     std::string prefix;
     136          10 :     const int nTokens = aosTokens.size();
     137          10 :     if (nTokens > 0 && cpl::contains(fieldNames, aosTokens[nTokens - 1]))
     138             :     {
     139           1 :         prefix = currentValue.substr(1);
     140           1 :         if (!prefix.empty() && prefix.back() != ' ')
     141           0 :             prefix += ' ';
     142          11 :         for (const char *op : SQL_OPERATORS)
     143             :         {
     144          10 :             if (nTokens > 1 && strcmp(op, "AND") == 0)
     145           0 :                 break;
     146          10 :             ret.push_back(prefix + op);
     147             :         }
     148             :     }
     149             :     else
     150             :     {
     151             :         const bool bLastTokenIsSQLOperator =
     152           9 :             nTokens > 0 && IsSQLOperator(aosTokens[nTokens - 1]);
     153           9 :         const int nCompleteTokens =
     154           9 :             nTokens + (bLastTokenIsSQLOperator ? 0 : -1);
     155          23 :         for (int i = 0; i < nCompleteTokens; ++i)
     156             :         {
     157          14 :             if (!prefix.empty())
     158           8 :                 prefix += ' ';
     159          14 :             prefix += aosTokens[i];
     160             :         }
     161           9 :         if (!prefix.empty())
     162           6 :             prefix += ' ';
     163             : 
     164           9 :         const char *pszLastFieldName = nullptr;
     165           9 :         OGRLayer *poLastFieldLayer = nullptr;
     166           9 :         OGRFieldType eLastFieldType = OFTString;
     167           9 :         if (nCompleteTokens >= 2)
     168             :         {
     169           6 :             const char *pszFieldName = aosTokens[nCompleteTokens - 2];
     170           6 :             const auto oIter = fieldNames.find(pszFieldName);
     171           6 :             if (oIter != fieldNames.end() && oIter->second.size() == 1)
     172             :             {
     173             :                 const int nIdx =
     174           5 :                     oIter->second[0]->GetLayerDefn()->GetFieldIndex(
     175           5 :                         pszFieldName);
     176           5 :                 if (nIdx >= 0)
     177             :                 {
     178           5 :                     pszLastFieldName = pszFieldName;
     179           5 :                     poLastFieldLayer = oIter->second[0];
     180           5 :                     eLastFieldType = poLastFieldLayer->GetLayerDefn()
     181           5 :                                          ->GetFieldDefn(nIdx)
     182           5 :                                          ->GetType();
     183             :                 }
     184             :             }
     185             :         }
     186             : 
     187          38 :         for (const auto &[name, layers] : fieldNames)
     188             :         {
     189          29 :             bool canAdd = false;
     190          29 :             if (!pszLastFieldName)
     191             :             {
     192          12 :                 canAdd = true;
     193             :             }
     194          17 :             else if (name != pszLastFieldName)
     195             :             {
     196          12 :                 if (layers.size() > 1)
     197             :                 {
     198           0 :                     canAdd = true;
     199             :                 }
     200          12 :                 else if (layers[0]
     201          12 :                              ->GetLayerDefn()
     202             :                              ->GetFieldDefn(
     203          12 :                                  layers[0]->GetLayerDefn()->GetFieldIndex(
     204          12 :                                      name.c_str()))
     205          12 :                              ->GetType() == eLastFieldType)
     206             :                 {
     207           2 :                     canAdd = true;
     208             :                 }
     209             :             }
     210          29 :             if (canAdd)
     211             :             {
     212          14 :                 ret.push_back(prefix + GetSQLIdentifier(name));
     213             :             }
     214             :         }
     215             : 
     216           9 :         if (pszLastFieldName && poLastFieldLayer)
     217             :         {
     218           5 :             auto poLayer = poLastFieldLayer;
     219           5 :             constexpr int NOT_TOO_LARGE = 1000;
     220           5 :             const auto nFeatureCount = poLayer->GetFeatureCount();
     221           5 :             if (nFeatureCount > 0 && nFeatureCount < NOT_TOO_LARGE)
     222             :             {
     223           5 :                 constexpr int VALUES_COUNT = 10;
     224             :                 const std::string osSQLField =
     225          15 :                     GetSQLIdentifier(pszLastFieldName);
     226             :                 const std::string osSQLLayer =
     227          15 :                     GetSQLIdentifier(poLayer->GetName());
     228           5 :                 if (eLastFieldType == OFTString)
     229             :                 {
     230           6 :                     CPLString osSQL;
     231           3 :                     const char *pszDialect = nullptr;
     232           3 :                     if (GetGDALDriverManager()->GetDriverByName("SQLite"))
     233             :                     {
     234           3 :                         pszDialect = "SQLite";
     235             :                         // Find 10 most frequent strings
     236             :                         osSQL.Printf("SELECT %s, COUNT(%s) cnt FROM %s GROUP "
     237             :                                      "BY %s ORDER BY cnt DESC, %s ASC LIMIT %d",
     238             :                                      osSQLField.c_str(), osSQLField.c_str(),
     239             :                                      osSQLLayer.c_str(), osSQLField.c_str(),
     240           3 :                                      osSQLField.c_str(), VALUES_COUNT + 1);
     241             :                     }
     242             :                     else
     243             :                     {
     244             :                         osSQL.Printf("SELECT DISTINCT %s FROM %s LIMIT %d",
     245             :                                      osSQLField.c_str(), osSQLLayer.c_str(),
     246           0 :                                      VALUES_COUNT + 1);
     247             :                     }
     248             :                     auto poSQLLayer =
     249           3 :                         poDS->ExecuteSQL(osSQL.c_str(), nullptr, pszDialect);
     250           3 :                     if (poSQLLayer)
     251             :                     {
     252           3 :                         int nCount = 0;
     253          34 :                         for (auto &&poFeature : poSQLLayer)
     254             :                         {
     255          31 :                             if (nCount == VALUES_COUNT)
     256             :                             {
     257           1 :                                 ret.push_back(prefix + "'...other values...");
     258           1 :                                 break;
     259             :                             }
     260          30 :                             ret.push_back(prefix +
     261          60 :                                           GetSQLStringLiteral(
     262             :                                               poFeature->GetFieldAsString(0)));
     263          30 :                             nCount++;
     264             :                         }
     265           3 :                         poDS->ReleaseResultSet(poSQLLayer);
     266             :                     }
     267             :                 }
     268           2 :                 else if (eLastFieldType == OFTInteger ||
     269           0 :                          eLastFieldType == OFTInteger64 ||
     270             :                          eLastFieldType == OFTReal)
     271             :                 {
     272           4 :                     CPLString osSQL;
     273           2 :                     const char *pszDialect = nullptr;
     274           3 :                     if (nFeatureCount > VALUES_COUNT + 2 &&
     275           1 :                         GetGDALDriverManager()->GetDriverByName("SQLite"))
     276             :                     {
     277           1 :                         pszDialect = "SQLite";
     278             :                         // Collect lowest and highest values
     279             :                         osSQL.Printf("SELECT DISTINCT %s FROM ("
     280             :                                      "SELECT * FROM (SELECT DISTINCT %s FROM "
     281             :                                      "%s ORDER BY %s ASC LIMIT %d) UNION ALL "
     282             :                                      "SELECT * FROM (SELECT DISTINCT %s FROM "
     283             :                                      "%s ORDER BY %s DESC LIMIT %d)"
     284             :                                      ") x ORDER BY %s",
     285             :                                      osSQLField.c_str(), osSQLField.c_str(),
     286             :                                      osSQLLayer.c_str(), osSQLField.c_str(),
     287             :                                      VALUES_COUNT / 2, osSQLField.c_str(),
     288             :                                      osSQLLayer.c_str(), osSQLField.c_str(),
     289           1 :                                      VALUES_COUNT / 2 + 1, osSQLField.c_str());
     290             :                     }
     291             :                     else
     292             :                     {
     293             :                         osSQL.Printf("SELECT DISTINCT %s FROM %s LIMIT %d",
     294             :                                      osSQLField.c_str(), osSQLLayer.c_str(),
     295           1 :                                      VALUES_COUNT + 1);
     296             :                     }
     297             :                     auto poSQLLayer =
     298           2 :                         poDS->ExecuteSQL(osSQL.c_str(), nullptr, pszDialect);
     299           2 :                     if (poSQLLayer)
     300             :                     {
     301           2 :                         int nCount = 0;
     302          23 :                         for (auto &&poFeature : poSQLLayer)
     303             :                         {
     304          21 :                             if (nCount == VALUES_COUNT)
     305             :                             {
     306           1 :                                 if (pszDialect)
     307             :                                 {
     308           1 :                                     ret.erase(ret.begin() + VALUES_COUNT / 2 +
     309           1 :                                               1);
     310           1 :                                     ret.push_back(
     311           2 :                                         prefix +
     312             :                                         poFeature->GetFieldAsString(0));
     313             :                                 }
     314           1 :                                 ret.push_back(prefix + "...other values...");
     315           1 :                                 break;
     316             :                             }
     317          20 :                             ret.push_back(prefix +
     318             :                                           poFeature->GetFieldAsString(0));
     319          20 :                             nCount++;
     320             :                         }
     321           2 :                         poDS->ReleaseResultSet(poSQLLayer);
     322             :                     }
     323             :                 }
     324             :             }
     325             :         }
     326             :     }
     327          10 :     return ret;
     328             : }
     329             : 
     330             : /************************************************************************/
     331             : /*              GDALVectorFilterAlgorithmLayerChangeExtent              */
     332             : /************************************************************************/
     333             : 
     334             : namespace
     335             : {
     336             : class GDALVectorFilterAlgorithmLayerChangeExtent final
     337             :     : public GDALVectorPipelinePassthroughLayer
     338             : {
     339             :   public:
     340           1 :     GDALVectorFilterAlgorithmLayerChangeExtent(
     341             :         OGRLayer &oSrcLayer, const OGREnvelope3D &sLayerEnvelope)
     342           1 :         : GDALVectorPipelinePassthroughLayer(oSrcLayer),
     343           1 :           m_sLayerEnvelope(sLayerEnvelope)
     344             :     {
     345           1 :     }
     346             : 
     347           1 :     OGRErr IGetExtent(int /*iGeomField*/, OGREnvelope *psExtent,
     348             :                       bool /* bForce */) override
     349             :     {
     350           1 :         if (m_sLayerEnvelope.IsInit())
     351             :         {
     352           1 :             *psExtent = m_sLayerEnvelope;
     353           1 :             return OGRERR_NONE;
     354             :         }
     355             :         else
     356             :         {
     357           0 :             return OGRERR_FAILURE;
     358             :         }
     359             :     }
     360             : 
     361           1 :     OGRErr IGetExtent3D(int /*iGeomField*/, OGREnvelope3D *psExtent,
     362             :                         bool /* bForce */) override
     363             :     {
     364           1 :         if (m_sLayerEnvelope.IsInit())
     365             :         {
     366           1 :             *psExtent = m_sLayerEnvelope;
     367           1 :             return OGRERR_NONE;
     368             :         }
     369             :         else
     370             :         {
     371           0 :             return OGRERR_FAILURE;
     372             :         }
     373             :     }
     374             : 
     375           0 :     int TestCapability(const char *pszCap) const override
     376             :     {
     377           0 :         if (EQUAL(pszCap, OLCFastGetExtent))
     378           0 :             return true;
     379           0 :         return m_srcLayer.TestCapability(pszCap);
     380             :     }
     381             : 
     382             :   private:
     383             :     const OGREnvelope3D m_sLayerEnvelope;
     384             : };
     385             : 
     386             : }  // namespace
     387             : 
     388             : /************************************************************************/
     389             : /*                 GDALVectorFilterAlgorithm::RunStep()                 */
     390             : /************************************************************************/
     391             : 
     392          18 : bool GDALVectorFilterAlgorithm::RunStep(GDALPipelineStepRunContext &ctxt)
     393             : {
     394          18 :     auto poSrcDS = m_inputDataset[0].GetDatasetRef();
     395          18 :     CPLAssert(poSrcDS);
     396             : 
     397          18 :     CPLAssert(m_outputDataset.GetName().empty());
     398          18 :     CPLAssert(!m_outputDataset.GetDatasetRef());
     399             : 
     400          18 :     const int nLayerCount = poSrcDS->GetLayerCount();
     401             : 
     402          18 :     bool ret = true;
     403          18 :     if (m_bbox.size() == 4)
     404             :     {
     405           4 :         const double xmin = m_bbox[0];
     406           4 :         const double ymin = m_bbox[1];
     407           4 :         const double xmax = m_bbox[2];
     408           4 :         const double ymax = m_bbox[3];
     409           8 :         for (int i = 0; i < nLayerCount; ++i)
     410             :         {
     411           4 :             auto poSrcLayer = poSrcDS->GetLayer(i);
     412           4 :             ret = ret && (poSrcLayer != nullptr);
     413           4 :             if (poSrcLayer && (m_activeLayer.empty() ||
     414           0 :                                m_activeLayer == poSrcLayer->GetDescription()))
     415           4 :                 poSrcLayer->SetSpatialFilterRect(xmin, ymin, xmax, ymax);
     416             :         }
     417             :     }
     418             : 
     419          18 :     if (ret && !m_where.empty())
     420             :     {
     421          28 :         for (int i = 0; i < nLayerCount; ++i)
     422             :         {
     423          16 :             auto poSrcLayer = poSrcDS->GetLayer(i);
     424          16 :             ret = ret && (poSrcLayer != nullptr);
     425          22 :             if (ret && (m_activeLayer.empty() ||
     426           6 :                         m_activeLayer == poSrcLayer->GetDescription()))
     427             :             {
     428          13 :                 ret = poSrcLayer->SetAttributeFilter(m_where.c_str()) ==
     429             :                       OGRERR_NONE;
     430             :             }
     431             :         }
     432             :     }
     433             : 
     434          18 :     if (ret)
     435             :     {
     436             :         auto outDS =
     437          17 :             std::make_unique<GDALVectorPipelineOutputDataset>(*poSrcDS);
     438             : 
     439          17 :         int64_t nTotalFeatures = 0;
     440          17 :         if (m_updateExtent && ctxt.m_pfnProgress)
     441             :         {
     442           0 :             for (int i = 0; ret && i < nLayerCount; ++i)
     443             :             {
     444           0 :                 auto poSrcLayer = poSrcDS->GetLayer(i);
     445           0 :                 ret = (poSrcLayer != nullptr);
     446           0 :                 if (ret)
     447             :                 {
     448           0 :                     if (m_activeLayer.empty() ||
     449           0 :                         m_activeLayer == poSrcLayer->GetDescription())
     450             :                     {
     451           0 :                         if (poSrcLayer->TestCapability(OLCFastFeatureCount))
     452             :                         {
     453           0 :                             const auto nFC = poSrcLayer->GetFeatureCount(false);
     454           0 :                             if (nFC < 0)
     455             :                             {
     456           0 :                                 nTotalFeatures = 0;
     457           0 :                                 break;
     458             :                             }
     459           0 :                             nTotalFeatures += nFC;
     460             :                         }
     461             :                     }
     462             :                 }
     463             :             }
     464             :         }
     465             : 
     466          17 :         int64_t nFeatureCounter = 0;
     467          38 :         for (int i = 0; ret && i < nLayerCount; ++i)
     468             :         {
     469          21 :             auto poSrcLayer = poSrcDS->GetLayer(i);
     470          21 :             ret = (poSrcLayer != nullptr);
     471          21 :             if (ret)
     472             :             {
     473          25 :                 if (m_updateExtent &&
     474           4 :                     (m_activeLayer.empty() ||
     475           2 :                      m_activeLayer == poSrcLayer->GetDescription()))
     476             :                 {
     477           1 :                     OGREnvelope3D sLayerEnvelope, sFeatureEnvelope;
     478           2 :                     for (auto &&poFeature : poSrcLayer)
     479             :                     {
     480           1 :                         const auto poGeom = poFeature->GetGeometryRef();
     481           1 :                         if (poGeom && !poGeom->IsEmpty())
     482             :                         {
     483           1 :                             poGeom->getEnvelope(&sFeatureEnvelope);
     484           1 :                             sLayerEnvelope.Merge(sFeatureEnvelope);
     485             :                         }
     486             : 
     487           1 :                         ++nFeatureCounter;
     488           1 :                         if (nTotalFeatures > 0 && ctxt.m_pfnProgress &&
     489           0 :                             !ctxt.m_pfnProgress(
     490           0 :                                 static_cast<double>(nFeatureCounter) /
     491           0 :                                     static_cast<double>(nTotalFeatures),
     492             :                                 "", ctxt.m_pProgressData))
     493             :                         {
     494           0 :                             ReportError(CE_Failure, CPLE_UserInterrupt,
     495             :                                         "Interrupted by user");
     496           0 :                             return false;
     497             :                         }
     498             :                     }
     499           2 :                     outDS->AddLayer(
     500             :                         *poSrcLayer,
     501             :                         std::make_unique<
     502           2 :                             GDALVectorFilterAlgorithmLayerChangeExtent>(
     503             :                             *poSrcLayer, sLayerEnvelope));
     504             :                 }
     505             :                 else
     506             :                 {
     507          40 :                     outDS->AddLayer(
     508             :                         *poSrcLayer,
     509          40 :                         std::make_unique<GDALVectorPipelinePassthroughLayer>(
     510             :                             *poSrcLayer));
     511             :                 }
     512             :             }
     513             :         }
     514             : 
     515          17 :         if (ret)
     516          17 :             m_outputDataset.Set(std::move(outDS));
     517             :     }
     518             : 
     519          18 :     return ret;
     520             : }
     521             : 
     522             : GDALVectorFilterAlgorithmStandalone::~GDALVectorFilterAlgorithmStandalone() =
     523             :     default;
     524             : 
     525             : //! @endcond

Generated by: LCOV version 1.14