Line data Source code
1 : /******************************************************************************
2 : *
3 : * Project: Virtual GDAL Datasets
4 : * Purpose: Implementation of GDALExpressionEvaluator.
5 : * Author: Daniel Baston
6 : *
7 : ******************************************************************************
8 : * Copyright (c) 2024, ISciences LLC
9 : *
10 : * SPDX-License-Identifier: MIT
11 : ****************************************************************************/
12 :
13 : #include "cpl_conv.h"
14 : #include "cpl_error.h"
15 : #include "cpl_string.h"
16 : #include "vrtexpression.h"
17 :
18 : #define exprtk_disable_caseinsensitivity
19 : #define exprtk_disable_rtl_io
20 : #define exprtk_disable_rtl_io_file
21 : #define exprtk_disable_rtl_vecops
22 : #define exprtk_disable_string_capabilities
23 :
24 : #if defined(__GNUC__)
25 : #pragma GCC diagnostic push
26 : #pragma GCC diagnostic ignored "-Wnull-dereference"
27 : #endif
28 :
29 : #include <exprtk.hpp>
30 :
31 : #if defined(__GNUC__)
32 : #pragma GCC diagnostic pop
33 : #endif
34 :
35 : #include <chrono>
36 : #include <cstdint>
37 : #include <limits>
38 : #include <sstream>
39 : #include <thread>
40 :
41 : namespace gdal
42 : {
43 :
44 : /*! @cond Doxygen_Suppress */
45 : struct vector_access_check final : public exprtk::vector_access_runtime_check
46 : {
47 : bool handle_runtime_violation(violation_context &context) override;
48 : };
49 :
50 1 : bool vector_access_check::handle_runtime_violation(violation_context &context)
51 : {
52 1 : auto nElements = (static_cast<std::uint8_t *>(context.end_ptr) -
53 1 : static_cast<std::uint8_t *>(context.base_ptr)) /
54 1 : context.type_size;
55 1 : auto nIndexAccessed = (static_cast<std::uint8_t *>(context.access_ptr) -
56 1 : static_cast<std::uint8_t *>(context.base_ptr)) /
57 1 : context.type_size;
58 :
59 2 : std::ostringstream oss;
60 1 : oss << "Attempted to access index " << nIndexAccessed << " in a vector of "
61 1 : << nElements << " elements when evaluating VRT expression.";
62 1 : throw std::runtime_error(oss.str());
63 : }
64 :
65 : struct loop_timeout_check final : public exprtk::loop_runtime_check
66 : {
67 : using time_point_t = std::chrono::time_point<std::chrono::steady_clock>;
68 :
69 27 : loop_timeout_check() : exprtk::loop_runtime_check()
70 : {
71 : double dfMaxLoopIterationSeconds =
72 27 : CPLAtofM(CPLGetConfigOption("GDAL_EXPRTK_TIMEOUT_SECONDS", "1"));
73 54 : max_duration = std::chrono::microseconds(
74 27 : static_cast<size_t>(dfMaxLoopIterationSeconds * 1e6));
75 27 : }
76 :
77 26 : void start_timer()
78 : {
79 26 : timeout_t = std::chrono::steady_clock::now() + max_duration;
80 26 : }
81 :
82 10114 : bool check() override
83 : {
84 :
85 10114 : if (++iterations >= max_iters_per_check)
86 : {
87 1 : if (std::chrono::steady_clock::now() > timeout_t)
88 : {
89 1 : timeout = true;
90 1 : return false;
91 : }
92 :
93 0 : iterations = 0;
94 : }
95 :
96 10113 : return true;
97 : }
98 :
99 : void handle_runtime_violation(const violation_context &) override;
100 :
101 : private:
102 : static constexpr size_t max_iters_per_check = 10000;
103 : size_t iterations = 0;
104 : time_point_t timeout_t{};
105 : std::chrono::microseconds max_duration{};
106 : bool timeout{false};
107 : };
108 :
109 1 : void loop_timeout_check::handle_runtime_violation(const violation_context &)
110 : {
111 2 : std::ostringstream oss;
112 :
113 : // current version of exprtk does not report the correct
114 : // violation in case of timeout, so we track the error category
115 : // ourselves
116 1 : if (timeout)
117 : {
118 1 : oss << "Expression evaluation time exceeded maximum of "
119 1 : << static_cast<double>(max_duration.count() / 1e6)
120 : << " seconds. You can increase this threshold by setting the "
121 : << "GDAL_EXPRTK_TIMEOUT_SECONDS configuration "
122 1 : << "option.";
123 : }
124 : else
125 : {
126 0 : oss << "Exceeded maximum of " << max_loop_iterations
127 0 : << " loop iterations.";
128 : }
129 :
130 1 : throw std::runtime_error(oss.str());
131 : }
132 :
133 : class ExprtkExpression::Impl
134 : {
135 : public:
136 : exprtk::expression<double> m_oExpression{};
137 : exprtk::parser<double> m_oParser{};
138 : exprtk::symbol_table<double> m_oSymbolTable{};
139 : std::string m_osExpression{};
140 :
141 : std::vector<std::pair<std::string, double *>> m_aoVariables{};
142 : std::vector<std::pair<std::string, std::vector<double> *>> m_aoVectors{};
143 : std::vector<double> m_adfResults{};
144 : vector_access_check m_oVectorAccessCheck{};
145 : loop_timeout_check m_oLoopRuntimeCheck{};
146 :
147 : bool m_bIsCompiled{false};
148 :
149 27 : explicit Impl()
150 27 : {
151 : using settings_t = std::decay_t<decltype(m_oParser.settings())>;
152 :
153 27 : m_oLoopRuntimeCheck.loop_set = loop_timeout_check::e_all_loops;
154 27 : m_oLoopRuntimeCheck.max_loop_iterations = std::numeric_limits<
155 27 : decltype(m_oLoopRuntimeCheck.max_loop_iterations)>::max();
156 27 : m_oParser.register_vector_access_runtime_check(m_oVectorAccessCheck);
157 27 : m_oParser.register_loop_runtime_check(m_oLoopRuntimeCheck);
158 :
159 : #ifndef NDEBUG
160 : // Only used for automated testing of GDAL_EXPRTK_TIMEOUT_SECONDS
161 27 : m_oSymbolTable.add_function("sleep", sleep);
162 : #endif
163 :
164 27 : int nMaxVectorLength = std::atoi(
165 : CPLGetConfigOption("GDAL_EXPRTK_MAX_VECTOR_LENGTH", "100000"));
166 :
167 27 : if (nMaxVectorLength > 0)
168 : {
169 27 : m_oParser.settings().set_max_local_vector_size(nMaxVectorLength);
170 : }
171 :
172 : bool bEnableLoops =
173 27 : CPLTestBool(CPLGetConfigOption("GDAL_EXPRTK_ENABLE_LOOPS", "YES"));
174 27 : if (!bEnableLoops)
175 : {
176 1 : m_oParser.settings().disable_control_structure(
177 1 : settings_t::e_ctrl_for_loop);
178 1 : m_oParser.settings().disable_control_structure(
179 1 : settings_t::e_ctrl_while_loop);
180 1 : m_oParser.settings().disable_control_structure(
181 1 : settings_t::e_ctrl_repeat_loop);
182 : }
183 27 : }
184 :
185 27 : CPLErr compile()
186 : {
187 27 : int nMaxExpressionLength = std::atoi(
188 : CPLGetConfigOption("GDAL_EXPRTK_MAX_EXPRESSION_LENGTH", "100000"));
189 27 : if (m_osExpression.size() >
190 27 : static_cast<std::size_t>(nMaxExpressionLength))
191 : {
192 1 : CPLError(CE_Failure, CPLE_AppDefined,
193 : "Expression length of %d exceeds maximum of %d set by "
194 : "GDAL_EXPRTK_MAX_EXPRESSION_LENGTH",
195 1 : static_cast<int>(m_osExpression.size()),
196 : nMaxExpressionLength);
197 1 : return CE_Failure;
198 : }
199 :
200 138 : for (const auto &[osVariable, pdfValueLoc] : m_aoVariables)
201 : {
202 112 : m_oSymbolTable.add_variable(osVariable, *pdfValueLoc);
203 : }
204 :
205 32 : for (const auto &[osVariable, padfVectorLoc] : m_aoVectors)
206 : {
207 6 : m_oSymbolTable.add_vector(osVariable, *padfVectorLoc);
208 : }
209 :
210 26 : m_oExpression.register_symbol_table(m_oSymbolTable);
211 26 : bool bSuccess = m_oParser.compile(m_osExpression, m_oExpression);
212 :
213 26 : if (!bSuccess)
214 : {
215 208 : for (size_t i = 0; i < m_oParser.error_count(); i++)
216 : {
217 204 : const auto &oError = m_oParser.get_error(i);
218 :
219 408 : CPLError(CE_Warning, CPLE_AppDefined,
220 : "Position: %02d "
221 : "Type: [%s] "
222 : "Message: %s\n",
223 204 : static_cast<int>(oError.token.position),
224 408 : exprtk::parser_error::to_str(oError.mode).c_str(),
225 : oError.diagnostic.c_str());
226 : }
227 :
228 4 : CPLError(CE_Failure, CPLE_AppDefined,
229 : "Failed to parse expression.");
230 4 : return CE_Failure;
231 : }
232 :
233 22 : m_bIsCompiled = true;
234 :
235 22 : return CE_None;
236 : }
237 :
238 31 : CPLErr evaluate()
239 : {
240 31 : if (!m_bIsCompiled)
241 : {
242 15 : auto eErr = compile();
243 15 : if (eErr != CE_None)
244 : {
245 3 : return eErr;
246 : }
247 : }
248 :
249 28 : m_adfResults.clear();
250 : double value;
251 : try
252 : {
253 28 : value = m_oExpression.value(); // force evaluation
254 : }
255 2 : catch (const std::exception &e)
256 : {
257 2 : CPLError(CE_Failure, CPLE_AppDefined, "%s", e.what());
258 2 : return CE_Failure;
259 : }
260 :
261 26 : m_oLoopRuntimeCheck.start_timer();
262 26 : const auto &results = m_oExpression.results();
263 :
264 : // We follow a different method to get the result depending on
265 : // how the expression was formed. If a "return" statement was
266 : // used, the result will be accessible via the "result" object.
267 : // If no "return" statement was used, the result is accessible
268 : // from the "value" variable (and must not be a vector.)
269 26 : if (results.count() == 0)
270 : {
271 16 : m_adfResults.resize(1);
272 16 : m_adfResults[0] = value;
273 : }
274 10 : else if (results.count() == 1)
275 : {
276 :
277 6 : if (results[0].type == exprtk::type_store<double>::e_scalar)
278 : {
279 2 : m_adfResults.resize(1);
280 2 : results.get_scalar(0, m_adfResults[0]);
281 : }
282 4 : else if (results[0].type == exprtk::type_store<double>::e_vector)
283 : {
284 4 : results.get_vector(0, m_adfResults);
285 : }
286 : else
287 : {
288 0 : CPLError(CE_Failure, CPLE_AppDefined,
289 : "Expression returned an unexpected type.");
290 0 : return CE_Failure;
291 : }
292 : }
293 : else
294 : {
295 4 : m_adfResults.resize(results.count());
296 10 : for (size_t i = 0; i < results.count(); i++)
297 : {
298 7 : if (results[i].type != exprtk::type_store<double>::e_scalar)
299 : {
300 1 : CPLError(CE_Failure, CPLE_AppDefined,
301 : "Expression must return a vector or a list of "
302 : "scalars.");
303 1 : return CE_Failure;
304 : }
305 : else
306 : {
307 6 : results.get_scalar(i, m_adfResults[i]);
308 : }
309 : }
310 : }
311 :
312 25 : return CE_None;
313 : }
314 :
315 : private:
316 : #ifndef NDEBUG
317 : struct sleep_fn : public exprtk::ifunction<double>
318 : {
319 27 : sleep_fn() : exprtk::ifunction<double>(1)
320 : {
321 27 : }
322 :
323 : using exprtk::ifunction<double>::operator();
324 :
325 9999 : double operator()(const double &seconds) override
326 : {
327 9999 : std::this_thread::sleep_for(
328 9999 : std::chrono::microseconds(static_cast<int>(seconds * 1e6)));
329 9999 : return 0;
330 : }
331 : };
332 :
333 : sleep_fn sleep{};
334 : #endif
335 : };
336 :
337 : /**
338 : * Define an expression to be evaluated using the exprtk library.
339 : *
340 : * @param osExpression the expression to evaluate. Refer to exprtk library documentation
341 : * for details of the allowable syntax.
342 : *
343 : * @since 3.11
344 : */
345 27 : ExprtkExpression::ExprtkExpression(std::string_view osExpression)
346 27 : : m_pImpl(std::make_unique<Impl>())
347 : {
348 27 : m_pImpl->m_osExpression = osExpression;
349 27 : }
350 :
351 54 : ExprtkExpression::~ExprtkExpression()
352 : {
353 54 : }
354 :
355 113 : void ExprtkExpression::RegisterVariable(std::string_view osVariable,
356 : double *pdfValue)
357 : {
358 113 : m_pImpl->m_aoVariables.emplace_back(osVariable, pdfValue);
359 113 : }
360 :
361 6 : void ExprtkExpression::RegisterVector(std::string_view osVariable,
362 : std::vector<double> *padfValue)
363 : {
364 6 : m_pImpl->m_aoVectors.emplace_back(osVariable, padfValue);
365 6 : }
366 :
367 12 : CPLErr ExprtkExpression::Compile()
368 : {
369 12 : return m_pImpl->compile();
370 : }
371 :
372 51 : const std::vector<double> &ExprtkExpression::Results() const
373 : {
374 51 : return m_pImpl->m_adfResults;
375 : }
376 :
377 31 : CPLErr ExprtkExpression::Evaluate()
378 : {
379 31 : return m_pImpl->evaluate();
380 : }
381 :
382 : /*! @endcond Doxygen_Suppress */
383 :
384 : } // namespace gdal
|