LCOV - code coverage report
Current view: top level - frmts/vrt - vrtexpression_exprtk.cpp (source / functions) Hit Total Coverage
Test: gdal_filtered.info Lines: 128 133 96.2 %
Date: 2025-05-31 00:00:17 Functions: 18 18 100.0 %

          Line data    Source code
       1             : /******************************************************************************
       2             :  *
       3             :  * Project:  Virtual GDAL Datasets
       4             :  * Purpose:  Implementation of GDALExpressionEvaluator.
       5             :  * Author:   Daniel Baston
       6             :  *
       7             :  ******************************************************************************
       8             :  * Copyright (c) 2024, ISciences LLC
       9             :  *
      10             :  * SPDX-License-Identifier: MIT
      11             :  ****************************************************************************/
      12             : 
      13             : #include "cpl_conv.h"
      14             : #include "cpl_error.h"
      15             : #include "cpl_string.h"
      16             : #include "vrtexpression.h"
      17             : 
      18             : #define exprtk_disable_caseinsensitivity
      19             : #define exprtk_disable_rtl_io
      20             : #define exprtk_disable_rtl_io_file
      21             : #define exprtk_disable_rtl_vecops
      22             : #define exprtk_disable_string_capabilities
      23             : 
      24             : #if defined(__GNUC__)
      25             : #pragma GCC diagnostic push
      26             : #pragma GCC diagnostic ignored "-Wnull-dereference"
      27             : #endif
      28             : 
      29             : #include <exprtk.hpp>
      30             : 
      31             : #if defined(__GNUC__)
      32             : #pragma GCC diagnostic pop
      33             : #endif
      34             : 
      35             : #include <chrono>
      36             : #include <cstdint>
      37             : #include <limits>
      38             : #include <sstream>
      39             : #include <thread>
      40             : 
      41             : namespace gdal
      42             : {
      43             : 
      44             : /*! @cond Doxygen_Suppress */
      45             : struct vector_access_check final : public exprtk::vector_access_runtime_check
      46             : {
      47           1 :     bool handle_runtime_violation(violation_context &context) override
      48             :     {
      49           1 :         auto nElements = (static_cast<std::uint8_t *>(context.end_ptr) -
      50           1 :                           static_cast<std::uint8_t *>(context.base_ptr)) /
      51           1 :                          context.type_size;
      52           1 :         auto nIndexAccessed = (static_cast<std::uint8_t *>(context.access_ptr) -
      53           1 :                                static_cast<std::uint8_t *>(context.base_ptr)) /
      54           1 :                               context.type_size;
      55             : 
      56           2 :         std::ostringstream oss;
      57           1 :         oss << "Attempted to access index " << nIndexAccessed
      58           1 :             << " in a vector of " << nElements
      59           1 :             << " elements when evaluating VRT expression.";
      60           1 :         throw std::runtime_error(oss.str());
      61             :     }
      62             : };
      63             : 
      64             : struct loop_timeout_check final : public exprtk::loop_runtime_check
      65             : {
      66             :     using time_point_t = std::chrono::time_point<std::chrono::steady_clock>;
      67             : 
      68          27 :     loop_timeout_check() : exprtk::loop_runtime_check()
      69             :     {
      70             :         double dfMaxLoopIterationSeconds =
      71          27 :             CPLAtofM(CPLGetConfigOption("GDAL_EXPRTK_TIMEOUT_SECONDS", "1"));
      72          54 :         max_duration = std::chrono::microseconds(
      73          27 :             static_cast<size_t>(dfMaxLoopIterationSeconds * 1e6));
      74          27 :     }
      75             : 
      76          26 :     void start_timer()
      77             :     {
      78          26 :         timeout_t = std::chrono::steady_clock::now() + max_duration;
      79          26 :     }
      80             : 
      81       10114 :     bool check() override
      82             :     {
      83             : 
      84       10114 :         if (++iterations >= max_iters_per_check)
      85             :         {
      86           1 :             if (std::chrono::steady_clock::now() > timeout_t)
      87             :             {
      88           1 :                 timeout = true;
      89           1 :                 return false;
      90             :             }
      91             : 
      92           0 :             iterations = 0;
      93             :         }
      94             : 
      95       10113 :         return true;
      96             :     }
      97             : 
      98           1 :     void handle_runtime_violation(const violation_context &) override
      99             :     {
     100           2 :         std::ostringstream oss;
     101             : 
     102             :         // current version of exprtk does not report the correct
     103             :         // violation in case of timeout, so we track the error category
     104             :         // ourselves
     105           1 :         if (timeout)
     106             :         {
     107           1 :             oss << "Expression evaluation time exceeded maximum of "
     108           1 :                 << static_cast<double>(max_duration.count() / 1e6)
     109             :                 << " seconds. You can increase this threshold by setting the "
     110             :                 << "GDAL_EXPRTK_TIMEOUT_SECONDS configuration "
     111           1 :                 << "option.";
     112             :         }
     113             :         else
     114             :         {
     115           0 :             oss << "Exceeded maximum of " << max_loop_iterations
     116           0 :                 << " loop iterations.";
     117             :         }
     118             : 
     119           1 :         throw std::runtime_error(oss.str());
     120             :     }
     121             : 
     122             :     static constexpr size_t max_iters_per_check = 10000;
     123             :     size_t iterations = 0;
     124             :     time_point_t timeout_t{};
     125             :     std::chrono::microseconds max_duration{};
     126             :     bool timeout{false};
     127             : };
     128             : 
     129             : class ExprtkExpression::Impl
     130             : {
     131             :   public:
     132             :     exprtk::expression<double> m_oExpression{};
     133             :     exprtk::parser<double> m_oParser{};
     134             :     exprtk::symbol_table<double> m_oSymbolTable{};
     135             :     std::string m_osExpression{};
     136             : 
     137             :     std::vector<std::pair<std::string, double *>> m_aoVariables{};
     138             :     std::vector<std::pair<std::string, std::vector<double> *>> m_aoVectors{};
     139             :     std::vector<double> m_adfResults{};
     140             :     vector_access_check m_oVectorAccessCheck{};
     141             :     loop_timeout_check m_oLoopRuntimeCheck{};
     142             : 
     143             :     bool m_bIsCompiled{false};
     144             : 
     145          27 :     explicit Impl()
     146          27 :     {
     147             :         using settings_t = std::decay_t<decltype(m_oParser.settings())>;
     148             : 
     149          27 :         m_oLoopRuntimeCheck.loop_set = loop_timeout_check::e_all_loops;
     150          27 :         m_oLoopRuntimeCheck.max_loop_iterations = std::numeric_limits<
     151          27 :             decltype(m_oLoopRuntimeCheck.max_loop_iterations)>::max();
     152          27 :         m_oParser.register_vector_access_runtime_check(m_oVectorAccessCheck);
     153          27 :         m_oParser.register_loop_runtime_check(m_oLoopRuntimeCheck);
     154             : 
     155             : #ifndef NDEBUG
     156             :         // Only used for automated testing of GDAL_EXPRTK_TIMEOUT_SECONDS
     157          27 :         m_oSymbolTable.add_function("sleep", sleep);
     158             : #endif
     159             : 
     160          27 :         int nMaxVectorLength = std::atoi(
     161             :             CPLGetConfigOption("GDAL_EXPRTK_MAX_VECTOR_LENGTH", "100000"));
     162             : 
     163          27 :         if (nMaxVectorLength > 0)
     164             :         {
     165          27 :             m_oParser.settings().set_max_local_vector_size(nMaxVectorLength);
     166             :         }
     167             : 
     168             :         bool bEnableLoops =
     169          27 :             CPLTestBool(CPLGetConfigOption("GDAL_EXPRTK_ENABLE_LOOPS", "YES"));
     170          27 :         if (!bEnableLoops)
     171             :         {
     172           1 :             m_oParser.settings().disable_control_structure(
     173           1 :                 settings_t::e_ctrl_for_loop);
     174           1 :             m_oParser.settings().disable_control_structure(
     175           1 :                 settings_t::e_ctrl_while_loop);
     176           1 :             m_oParser.settings().disable_control_structure(
     177           1 :                 settings_t::e_ctrl_repeat_loop);
     178             :         }
     179          27 :     }
     180             : 
     181          27 :     CPLErr compile()
     182             :     {
     183          27 :         int nMaxExpressionLength = std::atoi(
     184             :             CPLGetConfigOption("GDAL_EXPRTK_MAX_EXPRESSION_LENGTH", "100000"));
     185          27 :         if (m_osExpression.size() >
     186          27 :             static_cast<std::size_t>(nMaxExpressionLength))
     187             :         {
     188           1 :             CPLError(CE_Failure, CPLE_AppDefined,
     189             :                      "Expression length of %d exceeds maximum of %d set by "
     190             :                      "GDAL_EXPRTK_MAX_EXPRESSION_LENGTH",
     191           1 :                      static_cast<int>(m_osExpression.size()),
     192             :                      nMaxExpressionLength);
     193           1 :             return CE_Failure;
     194             :         }
     195             : 
     196         138 :         for (const auto &[osVariable, pdfValueLoc] : m_aoVariables)
     197             :         {
     198         112 :             m_oSymbolTable.add_variable(osVariable, *pdfValueLoc);
     199             :         }
     200             : 
     201          32 :         for (const auto &[osVariable, padfVectorLoc] : m_aoVectors)
     202             :         {
     203           6 :             m_oSymbolTable.add_vector(osVariable, *padfVectorLoc);
     204             :         }
     205             : 
     206          26 :         m_oExpression.register_symbol_table(m_oSymbolTable);
     207          26 :         bool bSuccess = m_oParser.compile(m_osExpression, m_oExpression);
     208             : 
     209          26 :         if (!bSuccess)
     210             :         {
     211         208 :             for (size_t i = 0; i < m_oParser.error_count(); i++)
     212             :             {
     213         204 :                 const auto &oError = m_oParser.get_error(i);
     214             : 
     215         408 :                 CPLError(CE_Warning, CPLE_AppDefined,
     216             :                          "Position: %02d "
     217             :                          "Type: [%s] "
     218             :                          "Message: %s\n",
     219         204 :                          static_cast<int>(oError.token.position),
     220         408 :                          exprtk::parser_error::to_str(oError.mode).c_str(),
     221             :                          oError.diagnostic.c_str());
     222             :             }
     223             : 
     224           4 :             CPLError(CE_Failure, CPLE_AppDefined,
     225             :                      "Failed to parse expression.");
     226           4 :             return CE_Failure;
     227             :         }
     228             : 
     229          22 :         m_bIsCompiled = true;
     230             : 
     231          22 :         return CE_None;
     232             :     }
     233             : 
     234          31 :     CPLErr evaluate()
     235             :     {
     236          31 :         if (!m_bIsCompiled)
     237             :         {
     238          15 :             auto eErr = compile();
     239          15 :             if (eErr != CE_None)
     240             :             {
     241           3 :                 return eErr;
     242             :             }
     243             :         }
     244             : 
     245          28 :         m_adfResults.clear();
     246             :         double value;
     247             :         try
     248             :         {
     249          28 :             value = m_oExpression.value();  // force evaluation
     250             :         }
     251           2 :         catch (const std::exception &e)
     252             :         {
     253           2 :             CPLError(CE_Failure, CPLE_AppDefined, "%s", e.what());
     254           2 :             return CE_Failure;
     255             :         }
     256             : 
     257          26 :         m_oLoopRuntimeCheck.start_timer();
     258          26 :         const auto &results = m_oExpression.results();
     259             : 
     260             :         // We follow a different method to get the result depending on
     261             :         // how the expression was formed. If a "return" statement was
     262             :         // used, the result will be accessible via the "result" object.
     263             :         // If no "return" statement was used, the result is accessible
     264             :         // from the "value" variable (and must not be a vector.)
     265          26 :         if (results.count() == 0)
     266             :         {
     267          16 :             m_adfResults.resize(1);
     268          16 :             m_adfResults[0] = value;
     269             :         }
     270          10 :         else if (results.count() == 1)
     271             :         {
     272             : 
     273           6 :             if (results[0].type == exprtk::type_store<double>::e_scalar)
     274             :             {
     275           2 :                 m_adfResults.resize(1);
     276           2 :                 results.get_scalar(0, m_adfResults[0]);
     277             :             }
     278           4 :             else if (results[0].type == exprtk::type_store<double>::e_vector)
     279             :             {
     280           4 :                 results.get_vector(0, m_adfResults);
     281             :             }
     282             :             else
     283             :             {
     284           0 :                 CPLError(CE_Failure, CPLE_AppDefined,
     285             :                          "Expression returned an unexpected type.");
     286           0 :                 return CE_Failure;
     287             :             }
     288             :         }
     289             :         else
     290             :         {
     291           4 :             m_adfResults.resize(results.count());
     292          10 :             for (size_t i = 0; i < results.count(); i++)
     293             :             {
     294           7 :                 if (results[i].type != exprtk::type_store<double>::e_scalar)
     295             :                 {
     296           1 :                     CPLError(CE_Failure, CPLE_AppDefined,
     297             :                              "Expression must return a vector or a list of "
     298             :                              "scalars.");
     299           1 :                     return CE_Failure;
     300             :                 }
     301             :                 else
     302             :                 {
     303           6 :                     results.get_scalar(i, m_adfResults[i]);
     304             :                 }
     305             :             }
     306             :         }
     307             : 
     308          25 :         return CE_None;
     309             :     }
     310             : 
     311             :   private:
     312             : #ifndef NDEBUG
     313             :     struct sleep_fn : public exprtk::ifunction<double>
     314             :     {
     315          27 :         sleep_fn() : exprtk::ifunction<double>(1)
     316             :         {
     317          27 :         }
     318             : 
     319             :         using exprtk::ifunction<double>::operator();
     320             : 
     321        9999 :         double operator()(const double &seconds) override
     322             :         {
     323        9999 :             std::this_thread::sleep_for(
     324        9999 :                 std::chrono::microseconds(static_cast<int>(seconds * 1e6)));
     325        9999 :             return 0;
     326             :         }
     327             :     };
     328             : 
     329             :     sleep_fn sleep{};
     330             : #endif
     331             : };
     332             : 
     333             : /**
     334             :  * Define an expression to be evaluated using the exprtk library.
     335             :  *
     336             :  * @param osExpression the expression to evaluate. Refer to exprtk library documentation
     337             :  *                     for details of the allowable syntax.
     338             :  *
     339             :  * @since 3.11
     340             :  */
     341          27 : ExprtkExpression::ExprtkExpression(std::string_view osExpression)
     342          27 :     : m_pImpl(std::make_unique<Impl>())
     343             : {
     344          27 :     m_pImpl->m_osExpression = osExpression;
     345          27 : }
     346             : 
     347          54 : ExprtkExpression::~ExprtkExpression()
     348             : {
     349          54 : }
     350             : 
     351         113 : void ExprtkExpression::RegisterVariable(std::string_view osVariable,
     352             :                                         double *pdfValue)
     353             : {
     354         113 :     m_pImpl->m_aoVariables.emplace_back(osVariable, pdfValue);
     355         113 : }
     356             : 
     357           6 : void ExprtkExpression::RegisterVector(std::string_view osVariable,
     358             :                                       std::vector<double> *padfValue)
     359             : {
     360           6 :     m_pImpl->m_aoVectors.emplace_back(osVariable, padfValue);
     361           6 : }
     362             : 
     363          12 : CPLErr ExprtkExpression::Compile()
     364             : {
     365          12 :     return m_pImpl->compile();
     366             : }
     367             : 
     368          51 : const std::vector<double> &ExprtkExpression::Results() const
     369             : {
     370          51 :     return m_pImpl->m_adfResults;
     371             : }
     372             : 
     373          31 : CPLErr ExprtkExpression::Evaluate()
     374             : {
     375          31 :     return m_pImpl->evaluate();
     376             : }
     377             : 
     378             : /*! @endcond Doxygen_Suppress */
     379             : 
     380             : }  // namespace gdal

Generated by: LCOV version 1.14