Line data Source code
1 : /******************************************************************************
2 : *
3 : * Project: Virtual GDAL Datasets
4 : * Purpose: Implementation of GDALExpressionEvaluator.
5 : * Author: Daniel Baston
6 : *
7 : ******************************************************************************
8 : * Copyright (c) 2024, ISciences LLC
9 : *
10 : * SPDX-License-Identifier: MIT
11 : ****************************************************************************/
12 :
13 : #include "cpl_conv.h"
14 : #include "cpl_error.h"
15 : #include "cpl_string.h"
16 : #include "vrtexpression.h"
17 :
18 : #define exprtk_disable_caseinsensitivity
19 : #define exprtk_disable_rtl_io
20 : #define exprtk_disable_rtl_io_file
21 : #define exprtk_disable_rtl_vecops
22 : #define exprtk_disable_string_capabilities
23 :
24 : #if defined(__GNUC__)
25 : #pragma GCC diagnostic push
26 : #pragma GCC diagnostic ignored "-Wnull-dereference"
27 : #endif
28 :
29 : #include <exprtk.hpp>
30 :
31 : #if defined(__GNUC__)
32 : #pragma GCC diagnostic pop
33 : #endif
34 :
35 : #include <chrono>
36 : #include <cstdint>
37 : #include <limits>
38 : #include <sstream>
39 : #include <thread>
40 :
41 : namespace gdal
42 : {
43 :
44 : /*! @cond Doxygen_Suppress */
45 : struct vector_access_check final : public exprtk::vector_access_runtime_check
46 : {
47 1 : bool handle_runtime_violation(violation_context &context) override
48 : {
49 1 : auto nElements = (static_cast<std::uint8_t *>(context.end_ptr) -
50 1 : static_cast<std::uint8_t *>(context.base_ptr)) /
51 1 : context.type_size;
52 1 : auto nIndexAccessed = (static_cast<std::uint8_t *>(context.access_ptr) -
53 1 : static_cast<std::uint8_t *>(context.base_ptr)) /
54 1 : context.type_size;
55 :
56 2 : std::ostringstream oss;
57 1 : oss << "Attempted to access index " << nIndexAccessed
58 1 : << " in a vector of " << nElements
59 1 : << " elements when evaluating VRT expression.";
60 1 : throw std::runtime_error(oss.str());
61 : }
62 : };
63 :
64 : struct loop_timeout_check final : public exprtk::loop_runtime_check
65 : {
66 : using time_point_t = std::chrono::time_point<std::chrono::steady_clock>;
67 :
68 27 : loop_timeout_check() : exprtk::loop_runtime_check()
69 : {
70 : double dfMaxLoopIterationSeconds =
71 27 : CPLAtofM(CPLGetConfigOption("GDAL_EXPRTK_TIMEOUT_SECONDS", "1"));
72 54 : max_duration = std::chrono::microseconds(
73 27 : static_cast<size_t>(dfMaxLoopIterationSeconds * 1e6));
74 27 : }
75 :
76 26 : void start_timer()
77 : {
78 26 : timeout_t = std::chrono::steady_clock::now() + max_duration;
79 26 : }
80 :
81 10114 : bool check() override
82 : {
83 :
84 10114 : if (++iterations >= max_iters_per_check)
85 : {
86 1 : if (std::chrono::steady_clock::now() > timeout_t)
87 : {
88 1 : timeout = true;
89 1 : return false;
90 : }
91 :
92 0 : iterations = 0;
93 : }
94 :
95 10113 : return true;
96 : }
97 :
98 1 : void handle_runtime_violation(const violation_context &) override
99 : {
100 2 : std::ostringstream oss;
101 :
102 : // current version of exprtk does not report the correct
103 : // violation in case of timeout, so we track the error category
104 : // ourselves
105 1 : if (timeout)
106 : {
107 1 : oss << "Expression evaluation time exceeded maximum of "
108 1 : << static_cast<double>(max_duration.count() / 1e6)
109 : << " seconds. You can increase this threshold by setting the "
110 : << "GDAL_EXPRTK_TIMEOUT_SECONDS configuration "
111 1 : << "option.";
112 : }
113 : else
114 : {
115 0 : oss << "Exceeded maximum of " << max_loop_iterations
116 0 : << " loop iterations.";
117 : }
118 :
119 1 : throw std::runtime_error(oss.str());
120 : }
121 :
122 : static constexpr size_t max_iters_per_check = 10000;
123 : size_t iterations = 0;
124 : time_point_t timeout_t{};
125 : std::chrono::microseconds max_duration{};
126 : bool timeout{false};
127 : };
128 :
129 : class ExprtkExpression::Impl
130 : {
131 : public:
132 : exprtk::expression<double> m_oExpression{};
133 : exprtk::parser<double> m_oParser{};
134 : exprtk::symbol_table<double> m_oSymbolTable{};
135 : std::string m_osExpression{};
136 :
137 : std::vector<std::pair<std::string, double *>> m_aoVariables{};
138 : std::vector<std::pair<std::string, std::vector<double> *>> m_aoVectors{};
139 : std::vector<double> m_adfResults{};
140 : vector_access_check m_oVectorAccessCheck{};
141 : loop_timeout_check m_oLoopRuntimeCheck{};
142 :
143 : bool m_bIsCompiled{false};
144 :
145 27 : explicit Impl()
146 27 : {
147 : using settings_t = std::decay_t<decltype(m_oParser.settings())>;
148 :
149 27 : m_oLoopRuntimeCheck.loop_set = loop_timeout_check::e_all_loops;
150 27 : m_oLoopRuntimeCheck.max_loop_iterations = std::numeric_limits<
151 27 : decltype(m_oLoopRuntimeCheck.max_loop_iterations)>::max();
152 27 : m_oParser.register_vector_access_runtime_check(m_oVectorAccessCheck);
153 27 : m_oParser.register_loop_runtime_check(m_oLoopRuntimeCheck);
154 :
155 : #ifndef NDEBUG
156 : // Only used for automated testing of GDAL_EXPRTK_TIMEOUT_SECONDS
157 27 : m_oSymbolTable.add_function("sleep", sleep);
158 : #endif
159 :
160 27 : int nMaxVectorLength = std::atoi(
161 : CPLGetConfigOption("GDAL_EXPRTK_MAX_VECTOR_LENGTH", "100000"));
162 :
163 27 : if (nMaxVectorLength > 0)
164 : {
165 27 : m_oParser.settings().set_max_local_vector_size(nMaxVectorLength);
166 : }
167 :
168 : bool bEnableLoops =
169 27 : CPLTestBool(CPLGetConfigOption("GDAL_EXPRTK_ENABLE_LOOPS", "YES"));
170 27 : if (!bEnableLoops)
171 : {
172 1 : m_oParser.settings().disable_control_structure(
173 1 : settings_t::e_ctrl_for_loop);
174 1 : m_oParser.settings().disable_control_structure(
175 1 : settings_t::e_ctrl_while_loop);
176 1 : m_oParser.settings().disable_control_structure(
177 1 : settings_t::e_ctrl_repeat_loop);
178 : }
179 27 : }
180 :
181 27 : CPLErr compile()
182 : {
183 27 : int nMaxExpressionLength = std::atoi(
184 : CPLGetConfigOption("GDAL_EXPRTK_MAX_EXPRESSION_LENGTH", "100000"));
185 27 : if (m_osExpression.size() >
186 27 : static_cast<std::size_t>(nMaxExpressionLength))
187 : {
188 1 : CPLError(CE_Failure, CPLE_AppDefined,
189 : "Expression length of %d exceeds maximum of %d set by "
190 : "GDAL_EXPRTK_MAX_EXPRESSION_LENGTH",
191 1 : static_cast<int>(m_osExpression.size()),
192 : nMaxExpressionLength);
193 1 : return CE_Failure;
194 : }
195 :
196 138 : for (const auto &[osVariable, pdfValueLoc] : m_aoVariables)
197 : {
198 112 : m_oSymbolTable.add_variable(osVariable, *pdfValueLoc);
199 : }
200 :
201 32 : for (const auto &[osVariable, padfVectorLoc] : m_aoVectors)
202 : {
203 6 : m_oSymbolTable.add_vector(osVariable, *padfVectorLoc);
204 : }
205 :
206 26 : m_oExpression.register_symbol_table(m_oSymbolTable);
207 26 : bool bSuccess = m_oParser.compile(m_osExpression, m_oExpression);
208 :
209 26 : if (!bSuccess)
210 : {
211 208 : for (size_t i = 0; i < m_oParser.error_count(); i++)
212 : {
213 204 : const auto &oError = m_oParser.get_error(i);
214 :
215 408 : CPLError(CE_Warning, CPLE_AppDefined,
216 : "Position: %02d "
217 : "Type: [%s] "
218 : "Message: %s\n",
219 204 : static_cast<int>(oError.token.position),
220 408 : exprtk::parser_error::to_str(oError.mode).c_str(),
221 : oError.diagnostic.c_str());
222 : }
223 :
224 4 : CPLError(CE_Failure, CPLE_AppDefined,
225 : "Failed to parse expression.");
226 4 : return CE_Failure;
227 : }
228 :
229 22 : m_bIsCompiled = true;
230 :
231 22 : return CE_None;
232 : }
233 :
234 31 : CPLErr evaluate()
235 : {
236 31 : if (!m_bIsCompiled)
237 : {
238 15 : auto eErr = compile();
239 15 : if (eErr != CE_None)
240 : {
241 3 : return eErr;
242 : }
243 : }
244 :
245 28 : m_adfResults.clear();
246 : double value;
247 : try
248 : {
249 28 : value = m_oExpression.value(); // force evaluation
250 : }
251 2 : catch (const std::exception &e)
252 : {
253 2 : CPLError(CE_Failure, CPLE_AppDefined, "%s", e.what());
254 2 : return CE_Failure;
255 : }
256 :
257 26 : m_oLoopRuntimeCheck.start_timer();
258 26 : const auto &results = m_oExpression.results();
259 :
260 : // We follow a different method to get the result depending on
261 : // how the expression was formed. If a "return" statement was
262 : // used, the result will be accessible via the "result" object.
263 : // If no "return" statement was used, the result is accessible
264 : // from the "value" variable (and must not be a vector.)
265 26 : if (results.count() == 0)
266 : {
267 16 : m_adfResults.resize(1);
268 16 : m_adfResults[0] = value;
269 : }
270 10 : else if (results.count() == 1)
271 : {
272 :
273 6 : if (results[0].type == exprtk::type_store<double>::e_scalar)
274 : {
275 2 : m_adfResults.resize(1);
276 2 : results.get_scalar(0, m_adfResults[0]);
277 : }
278 4 : else if (results[0].type == exprtk::type_store<double>::e_vector)
279 : {
280 4 : results.get_vector(0, m_adfResults);
281 : }
282 : else
283 : {
284 0 : CPLError(CE_Failure, CPLE_AppDefined,
285 : "Expression returned an unexpected type.");
286 0 : return CE_Failure;
287 : }
288 : }
289 : else
290 : {
291 4 : m_adfResults.resize(results.count());
292 10 : for (size_t i = 0; i < results.count(); i++)
293 : {
294 7 : if (results[i].type != exprtk::type_store<double>::e_scalar)
295 : {
296 1 : CPLError(CE_Failure, CPLE_AppDefined,
297 : "Expression must return a vector or a list of "
298 : "scalars.");
299 1 : return CE_Failure;
300 : }
301 : else
302 : {
303 6 : results.get_scalar(i, m_adfResults[i]);
304 : }
305 : }
306 : }
307 :
308 25 : return CE_None;
309 : }
310 :
311 : private:
312 : #ifndef NDEBUG
313 : struct sleep_fn : public exprtk::ifunction<double>
314 : {
315 27 : sleep_fn() : exprtk::ifunction<double>(1)
316 : {
317 27 : }
318 :
319 : using exprtk::ifunction<double>::operator();
320 :
321 9999 : double operator()(const double &seconds) override
322 : {
323 9999 : std::this_thread::sleep_for(
324 9999 : std::chrono::microseconds(static_cast<int>(seconds * 1e6)));
325 9999 : return 0;
326 : }
327 : };
328 :
329 : sleep_fn sleep{};
330 : #endif
331 : };
332 :
333 : /**
334 : * Define an expression to be evaluated using the exprtk library.
335 : *
336 : * @param osExpression the expression to evaluate. Refer to exprtk library documentation
337 : * for details of the allowable syntax.
338 : *
339 : * @since 3.11
340 : */
341 27 : ExprtkExpression::ExprtkExpression(std::string_view osExpression)
342 27 : : m_pImpl(std::make_unique<Impl>())
343 : {
344 27 : m_pImpl->m_osExpression = osExpression;
345 27 : }
346 :
347 54 : ExprtkExpression::~ExprtkExpression()
348 : {
349 54 : }
350 :
351 113 : void ExprtkExpression::RegisterVariable(std::string_view osVariable,
352 : double *pdfValue)
353 : {
354 113 : m_pImpl->m_aoVariables.emplace_back(osVariable, pdfValue);
355 113 : }
356 :
357 6 : void ExprtkExpression::RegisterVector(std::string_view osVariable,
358 : std::vector<double> *padfValue)
359 : {
360 6 : m_pImpl->m_aoVectors.emplace_back(osVariable, padfValue);
361 6 : }
362 :
363 12 : CPLErr ExprtkExpression::Compile()
364 : {
365 12 : return m_pImpl->compile();
366 : }
367 :
368 51 : const std::vector<double> &ExprtkExpression::Results() const
369 : {
370 51 : return m_pImpl->m_adfResults;
371 : }
372 :
373 31 : CPLErr ExprtkExpression::Evaluate()
374 : {
375 31 : return m_pImpl->evaluate();
376 : }
377 :
378 : /*! @endcond Doxygen_Suppress */
379 :
380 : } // namespace gdal
|