LCOV - code coverage report
Current view: top level - frmts/vrt - vrtexpression_exprtk.cpp (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 127 132 96.2 %
Date: 2025-12-08 00:14:33 Functions: 18 18 100.0 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  Virtual GDAL Datasets
       4             :  * Purpose:  Implementation of GDALExpressionEvaluator.
       5             :  * Author:   Daniel Baston
       6             :  *
       7             :  ******************************************************************************
       8             :  * Copyright (c) 2024, ISciences LLC
       9             :  *
      10             :  * SPDX-License-Identifier: MIT
      11             :  ****************************************************************************/
      12             : 
      13             : #include "cpl_conv.h"
      14             : #include "cpl_error.h"
      15             : #include "cpl_string.h"
      16             : #include "vrtexpression.h"
      17             : 
      18             : #define exprtk_disable_caseinsensitivity
      19             : #define exprtk_disable_rtl_io
      20             : #define exprtk_disable_rtl_io_file
      21             : #define exprtk_disable_rtl_vecops
      22             : #define exprtk_disable_string_capabilities
      23             : 
      24             : #if defined(__GNUC__)
      25             : #pragma GCC diagnostic push
      26             : #pragma GCC diagnostic ignored "-Wnull-dereference"
      27             : #endif
      28             : 
      29             : #include <exprtk.hpp>
      30             : 
      31             : #if defined(__GNUC__)
      32             : #pragma GCC diagnostic pop
      33             : #endif
      34             : 
      35             : #include <chrono>
      36             : #include <cstdint>
      37             : #include <limits>
      38             : #include <sstream>
      39             : #include <thread>
      40             : 
      41             : namespace gdal
      42             : {
      43             : 
      44             : /*! @cond Doxygen_Suppress */
      45             : struct vector_access_check final : public exprtk::vector_access_runtime_check
      46             : {
      47             :     bool handle_runtime_violation(violation_context &context) override;
      48             : };
      49             : 
      50           1 : bool vector_access_check::handle_runtime_violation(violation_context &context)
      51             : {
      52           1 :     auto nElements = (static_cast<std::uint8_t *>(context.end_ptr) -
      53           1 :                       static_cast<std::uint8_t *>(context.base_ptr)) /
      54           1 :                      context.type_size;
      55           1 :     auto nIndexAccessed = (static_cast<std::uint8_t *>(context.access_ptr) -
      56           1 :                            static_cast<std::uint8_t *>(context.base_ptr)) /
      57           1 :                           context.type_size;
      58             : 
      59           2 :     std::ostringstream oss;
      60           1 :     oss << "Attempted to access index " << nIndexAccessed << " in a vector of "
      61           1 :         << nElements << " elements when evaluating VRT expression.";
      62           1 :     throw std::runtime_error(oss.str());
      63             : }
      64             : 
      65             : struct loop_timeout_check final : public exprtk::loop_runtime_check
      66             : {
      67             :     using time_point_t = std::chrono::time_point<std::chrono::steady_clock>;
      68             : 
      69          27 :     loop_timeout_check() : exprtk::loop_runtime_check()
      70             :     {
      71             :         double dfMaxLoopIterationSeconds =
      72          27 :             CPLAtofM(CPLGetConfigOption("GDAL_EXPRTK_TIMEOUT_SECONDS", "1"));
      73          54 :         max_duration = std::chrono::microseconds(
      74          27 :             static_cast<size_t>(dfMaxLoopIterationSeconds * 1e6));
      75          27 :     }
      76             : 
      77          26 :     void start_timer()
      78             :     {
      79          26 :         timeout_t = std::chrono::steady_clock::now() + max_duration;
      80          26 :     }
      81             : 
      82       10114 :     bool check() override
      83             :     {
      84             : 
      85       10114 :         if (++iterations >= max_iters_per_check)
      86             :         {
      87           1 :             if (std::chrono::steady_clock::now() > timeout_t)
      88             :             {
      89           1 :                 timeout = true;
      90           1 :                 return false;
      91             :             }
      92             : 
      93           0 :             iterations = 0;
      94             :         }
      95             : 
      96       10113 :         return true;
      97             :     }
      98             : 
      99             :     void handle_runtime_violation(const violation_context &) override;
     100             : 
     101             :   private:
     102             :     static constexpr size_t max_iters_per_check = 10000;
     103             :     size_t iterations = 0;
     104             :     time_point_t timeout_t{};
     105             :     std::chrono::microseconds max_duration{};
     106             :     bool timeout{false};
     107             : };
     108             : 
     109           1 : void loop_timeout_check::handle_runtime_violation(const violation_context &)
     110             : {
     111           2 :     std::ostringstream oss;
     112             : 
     113             :     // current version of exprtk does not report the correct
     114             :     // violation in case of timeout, so we track the error category
     115             :     // ourselves
     116           1 :     if (timeout)
     117             :     {
     118           1 :         oss << "Expression evaluation time exceeded maximum of "
     119           1 :             << static_cast<double>(max_duration.count() / 1e6)
     120             :             << " seconds. You can increase this threshold by setting the "
     121             :             << "GDAL_EXPRTK_TIMEOUT_SECONDS configuration "
     122           1 :             << "option.";
     123             :     }
     124             :     else
     125             :     {
     126           0 :         oss << "Exceeded maximum of " << max_loop_iterations
     127           0 :             << " loop iterations.";
     128             :     }
     129             : 
     130           1 :     throw std::runtime_error(oss.str());
     131             : }
     132             : 
     133             : class ExprtkExpression::Impl
     134             : {
     135             :   public:
     136             :     exprtk::expression<double> m_oExpression{};
     137             :     exprtk::parser<double> m_oParser{};
     138             :     exprtk::symbol_table<double> m_oSymbolTable{};
     139             :     std::string m_osExpression{};
     140             : 
     141             :     std::vector<std::pair<std::string, double *>> m_aoVariables{};
     142             :     std::vector<std::pair<std::string, std::vector<double> *>> m_aoVectors{};
     143             :     std::vector<double> m_adfResults{};
     144             :     vector_access_check m_oVectorAccessCheck{};
     145             :     loop_timeout_check m_oLoopRuntimeCheck{};
     146             : 
     147             :     bool m_bIsCompiled{false};
     148             : 
     149          27 :     explicit Impl()
     150          27 :     {
     151             :         using settings_t = std::decay_t<decltype(m_oParser.settings())>;
     152             : 
     153          27 :         m_oLoopRuntimeCheck.loop_set = loop_timeout_check::e_all_loops;
     154          27 :         m_oLoopRuntimeCheck.max_loop_iterations = std::numeric_limits<
     155          27 :             decltype(m_oLoopRuntimeCheck.max_loop_iterations)>::max();
     156          27 :         m_oParser.register_vector_access_runtime_check(m_oVectorAccessCheck);
     157          27 :         m_oParser.register_loop_runtime_check(m_oLoopRuntimeCheck);
     158             : 
     159             : #ifndef NDEBUG
     160             :         // Only used for automated testing of GDAL_EXPRTK_TIMEOUT_SECONDS
     161          27 :         m_oSymbolTable.add_function("sleep", sleep);
     162             : #endif
     163             : 
     164          27 :         int nMaxVectorLength = std::atoi(
     165             :             CPLGetConfigOption("GDAL_EXPRTK_MAX_VECTOR_LENGTH", "100000"));
     166             : 
     167          27 :         if (nMaxVectorLength > 0)
     168             :         {
     169          27 :             m_oParser.settings().set_max_local_vector_size(nMaxVectorLength);
     170             :         }
     171             : 
     172             :         bool bEnableLoops =
     173          27 :             CPLTestBool(CPLGetConfigOption("GDAL_EXPRTK_ENABLE_LOOPS", "YES"));
     174          27 :         if (!bEnableLoops)
     175             :         {
     176           1 :             m_oParser.settings().disable_control_structure(
     177           1 :                 settings_t::e_ctrl_for_loop);
     178           1 :             m_oParser.settings().disable_control_structure(
     179           1 :                 settings_t::e_ctrl_while_loop);
     180           1 :             m_oParser.settings().disable_control_structure(
     181           1 :                 settings_t::e_ctrl_repeat_loop);
     182             :         }
     183          27 :     }
     184             : 
     185          27 :     CPLErr compile()
     186             :     {
     187          27 :         int nMaxExpressionLength = std::atoi(
     188             :             CPLGetConfigOption("GDAL_EXPRTK_MAX_EXPRESSION_LENGTH", "100000"));
     189          27 :         if (m_osExpression.size() >
     190          27 :             static_cast<std::size_t>(nMaxExpressionLength))
     191             :         {
     192           1 :             CPLError(CE_Failure, CPLE_AppDefined,
     193             :                      "Expression length of %d exceeds maximum of %d set by "
     194             :                      "GDAL_EXPRTK_MAX_EXPRESSION_LENGTH",
     195           1 :                      static_cast<int>(m_osExpression.size()),
     196             :                      nMaxExpressionLength);
     197           1 :             return CE_Failure;
     198             :         }
     199             : 
     200         138 :         for (const auto &[osVariable, pdfValueLoc] : m_aoVariables)
     201             :         {
     202         112 :             m_oSymbolTable.add_variable(osVariable, *pdfValueLoc);
     203             :         }
     204             : 
     205          32 :         for (const auto &[osVariable, padfVectorLoc] : m_aoVectors)
     206             :         {
     207           6 :             m_oSymbolTable.add_vector(osVariable, *padfVectorLoc);
     208             :         }
     209             : 
     210          26 :         m_oExpression.register_symbol_table(m_oSymbolTable);
     211          26 :         bool bSuccess = m_oParser.compile(m_osExpression, m_oExpression);
     212             : 
     213          26 :         if (!bSuccess)
     214             :         {
     215         208 :             for (size_t i = 0; i < m_oParser.error_count(); i++)
     216             :             {
     217         204 :                 const auto &oError = m_oParser.get_error(i);
     218             : 
     219         408 :                 CPLError(CE_Warning, CPLE_AppDefined,
     220             :                          "Position: %02d "
     221             :                          "Type: [%s] "
     222             :                          "Message: %s\n",
     223         204 :                          static_cast<int>(oError.token.position),
     224         408 :                          exprtk::parser_error::to_str(oError.mode).c_str(),
     225             :                          oError.diagnostic.c_str());
     226             :             }
     227             : 
     228           4 :             CPLError(CE_Failure, CPLE_AppDefined,
     229             :                      "Failed to parse expression.");
     230           4 :             return CE_Failure;
     231             :         }
     232             : 
     233          22 :         m_bIsCompiled = true;
     234             : 
     235          22 :         return CE_None;
     236             :     }
     237             : 
     238          31 :     CPLErr evaluate()
     239             :     {
     240          31 :         if (!m_bIsCompiled)
     241             :         {
     242          15 :             auto eErr = compile();
     243          15 :             if (eErr != CE_None)
     244             :             {
     245           3 :                 return eErr;
     246             :             }
     247             :         }
     248             : 
     249          28 :         m_adfResults.clear();
     250             :         double value;
     251             :         try
     252             :         {
     253          28 :             value = m_oExpression.value();  // force evaluation
     254             :         }
     255           2 :         catch (const std::exception &e)
     256             :         {
     257           2 :             CPLError(CE_Failure, CPLE_AppDefined, "%s", e.what());
     258           2 :             return CE_Failure;
     259             :         }
     260             : 
     261          26 :         m_oLoopRuntimeCheck.start_timer();
     262          26 :         const auto &results = m_oExpression.results();
     263             : 
     264             :         // We follow a different method to get the result depending on
     265             :         // how the expression was formed. If a "return" statement was
     266             :         // used, the result will be accessible via the "result" object.
     267             :         // If no "return" statement was used, the result is accessible
     268             :         // from the "value" variable (and must not be a vector.)
     269          26 :         if (results.count() == 0)
     270             :         {
     271          16 :             m_adfResults.resize(1);
     272          16 :             m_adfResults[0] = value;
     273             :         }
     274          10 :         else if (results.count() == 1)
     275             :         {
     276             : 
     277           6 :             if (results[0].type == exprtk::type_store<double>::e_scalar)
     278             :             {
     279           2 :                 m_adfResults.resize(1);
     280           2 :                 results.get_scalar(0, m_adfResults[0]);
     281             :             }
     282           4 :             else if (results[0].type == exprtk::type_store<double>::e_vector)
     283             :             {
     284           4 :                 results.get_vector(0, m_adfResults);
     285             :             }
     286             :             else
     287             :             {
     288           0 :                 CPLError(CE_Failure, CPLE_AppDefined,
     289             :                          "Expression returned an unexpected type.");
     290           0 :                 return CE_Failure;
     291             :             }
     292             :         }
     293             :         else
     294             :         {
     295           4 :             m_adfResults.resize(results.count());
     296          10 :             for (size_t i = 0; i < results.count(); i++)
     297             :             {
     298           7 :                 if (results[i].type != exprtk::type_store<double>::e_scalar)
     299             :                 {
     300           1 :                     CPLError(CE_Failure, CPLE_AppDefined,
     301             :                              "Expression must return a vector or a list of "
     302             :                              "scalars.");
     303           1 :                     return CE_Failure;
     304             :                 }
     305             :                 else
     306             :                 {
     307           6 :                     results.get_scalar(i, m_adfResults[i]);
     308             :                 }
     309             :             }
     310             :         }
     311             : 
     312          25 :         return CE_None;
     313             :     }
     314             : 
     315             :   private:
     316             : #ifndef NDEBUG
     317             :     struct sleep_fn : public exprtk::ifunction<double>
     318             :     {
     319          27 :         sleep_fn() : exprtk::ifunction<double>(1)
     320             :         {
     321          27 :         }
     322             : 
     323             :         using exprtk::ifunction<double>::operator();
     324             : 
     325        9999 :         double operator()(const double &seconds) override
     326             :         {
     327        9999 :             std::this_thread::sleep_for(
     328        9999 :                 std::chrono::microseconds(static_cast<int>(seconds * 1e6)));
     329        9999 :             return 0;
     330             :         }
     331             :     };
     332             : 
     333             :     sleep_fn sleep{};
     334             : #endif
     335             : };
     336             : 
     337             : /**
     338             :  * Define an expression to be evaluated using the exprtk library.
     339             :  *
     340             :  * @param osExpression the expression to evaluate. Refer to exprtk library documentation
     341             :  *                     for details of the allowable syntax.
     342             :  *
     343             :  * @since 3.11
     344             :  */
     345          27 : ExprtkExpression::ExprtkExpression(std::string_view osExpression)
     346          27 :     : m_pImpl(std::make_unique<Impl>())
     347             : {
     348          27 :     m_pImpl->m_osExpression = osExpression;
     349          27 : }
     350             : 
     351          54 : ExprtkExpression::~ExprtkExpression()
     352             : {
     353          54 : }
     354             : 
     355         113 : void ExprtkExpression::RegisterVariable(std::string_view osVariable,
     356             :                                         double *pdfValue)
     357             : {
     358         113 :     m_pImpl->m_aoVariables.emplace_back(osVariable, pdfValue);
     359         113 : }
     360             : 
     361           6 : void ExprtkExpression::RegisterVector(std::string_view osVariable,
     362             :                                       std::vector<double> *padfValue)
     363             : {
     364           6 :     m_pImpl->m_aoVectors.emplace_back(osVariable, padfValue);
     365           6 : }
     366             : 
     367          12 : CPLErr ExprtkExpression::Compile()
     368             : {
     369          12 :     return m_pImpl->compile();
     370             : }
     371             : 
     372          51 : const std::vector<double> &ExprtkExpression::Results() const
     373             : {
     374          51 :     return m_pImpl->m_adfResults;
     375             : }
     376             : 
     377          31 : CPLErr ExprtkExpression::Evaluate()
     378             : {
     379          31 :     return m_pImpl->evaluate();
     380             : }
     381             : 
     382             : /*! @endcond Doxygen_Suppress */
     383             : 
     384             : }  // namespace gdal

Generated by: LCOV version 1.14