Line data Source code
1 : /******************************************************************************
2 : *
3 : * Project: GDAL
4 : * Purpose: Linear system solver
5 : * Author: VIZRT Development Team.
6 : *
7 : * This code was provided by Gilad Ronnen (gro at visrt dot com) with
8 : * permission to reuse under the following license.
9 : *
10 : ******************************************************************************
11 : * Copyright (c) 2004, VIZRT Inc.
12 : * Copyright (c) 2008-2014, Even Rouault <even dot rouault at spatialys.com>
13 : * Copyright (c) 2019, Martin Franzke <martin dot franzke at telekom dot de>
14 : *
15 : * SPDX-License-Identifier: MIT
16 : ****************************************************************************/
17 :
18 : /*! @cond Doxygen_Suppress */
19 :
20 : #include "cpl_port.h"
21 : #include "cpl_conv.h"
22 : #include "gdallinearsystem.h"
23 :
24 : #ifdef HAVE_ARMADILLO
25 : #include "armadillo_headers.h"
26 : #endif
27 :
28 : #include <cstdio>
29 : #include <algorithm>
30 : #include <cassert>
31 : #include <cmath>
32 :
33 : namespace
34 : {
35 : // LU decomposition of the quadratic matrix A
36 : // see https://en.wikipedia.org/wiki/LU_decomposition#C_code_examples
37 0 : bool solve(GDALMatrix &A, GDALMatrix &RHS, GDALMatrix &X, double eps)
38 : {
39 0 : assert(A.getNumRows() == A.getNumCols());
40 0 : if (eps < 0)
41 0 : return false;
42 0 : int const m = A.getNumRows();
43 0 : int const n = RHS.getNumCols();
44 : // row permutations
45 0 : std::vector<int> perm(m);
46 0 : for (int iRow = 0; iRow < m; ++iRow)
47 0 : perm[iRow] = iRow;
48 :
49 : // Arbitrary threshold to trigger progress in debug mode
50 0 : const bool bDebug = (m > 10000);
51 0 : int nLastPct = -1;
52 :
53 0 : for (int step = 0; step < m - 1; ++step)
54 : {
55 0 : if (bDebug)
56 : {
57 0 : const int nPct = (step * 100 * 10 / m) / 2;
58 0 : if (nPct != nLastPct)
59 : {
60 0 : CPLDebug("GDAL", "solve(): %d.%d %%", nPct / 10, nPct % 10);
61 0 : nLastPct = nPct;
62 : }
63 : }
64 :
65 : // determine pivot element
66 0 : int iMax = step;
67 0 : double dMax = std::abs(A(step, step));
68 0 : for (int i = step + 1; i < m; ++i)
69 : {
70 0 : if (std::abs(A(i, step)) > dMax)
71 : {
72 0 : iMax = i;
73 0 : dMax = std::abs(A(i, step));
74 : }
75 : }
76 0 : if (dMax <= eps)
77 : {
78 0 : CPLError(CE_Failure, CPLE_AppDefined,
79 : "GDALLinearSystemSolve: matrix not invertible");
80 0 : return false;
81 : }
82 : // swap rows
83 0 : if (iMax != step)
84 : {
85 0 : std::swap(perm[iMax], perm[step]);
86 0 : for (int iCol = 0; iCol < m; ++iCol)
87 : {
88 0 : std::swap(A(iMax, iCol), A(step, iCol));
89 : }
90 : }
91 0 : for (int iRow = step + 1; iRow < m; ++iRow)
92 : {
93 0 : A(iRow, step) /= A(step, step);
94 : }
95 0 : for (int iCol = step + 1; iCol < m; ++iCol)
96 : {
97 0 : for (int iRow = step + 1; iRow < m; ++iRow)
98 : {
99 0 : A(iRow, iCol) -= A(iRow, step) * A(step, iCol);
100 : }
101 : }
102 : }
103 :
104 : // LUP solve;
105 0 : for (int iCol = 0; iCol < n; ++iCol)
106 : {
107 0 : if (bDebug)
108 : {
109 0 : const int nPct = 500 + (iCol * 100 * 10 / n) / 2;
110 0 : if (nPct != nLastPct)
111 : {
112 0 : CPLDebug("GDAL", "solve(): %d.%d %%", nPct / 10, nPct % 10);
113 0 : nLastPct = nPct;
114 : }
115 : }
116 :
117 0 : for (int iRow = 0; iRow < m; ++iRow)
118 : {
119 0 : X(iRow, iCol) = RHS(perm[iRow], iCol);
120 0 : for (int k = 0; k < iRow; ++k)
121 : {
122 0 : X(iRow, iCol) -= A(iRow, k) * X(k, iCol);
123 : }
124 : }
125 0 : for (int iRow = m - 1; iRow >= 0; --iRow)
126 : {
127 0 : for (int k = iRow + 1; k < m; ++k)
128 : {
129 0 : X(iRow, iCol) -= A(iRow, k) * X(k, iCol);
130 : }
131 0 : X(iRow, iCol) /= A(iRow, iRow);
132 : }
133 : }
134 :
135 0 : if (bDebug)
136 : {
137 0 : CPLDebug("GDAL", "solve(): 100.0 %%");
138 : }
139 :
140 0 : return true;
141 : }
142 : } // namespace
143 :
144 : /************************************************************************/
145 : /* GDALLinearSystemSolve() */
146 : /* */
147 : /* Solves the linear system A*X_i = RHS_i for each column i */
148 : /* where A is a square matrix. */
149 : /************************************************************************/
150 69 : bool GDALLinearSystemSolve(GDALMatrix &A, GDALMatrix &RHS, GDALMatrix &X,
151 : [[maybe_unused]] bool bForceBuiltinMethod)
152 : {
153 69 : assert(A.getNumRows() == RHS.getNumRows());
154 69 : assert(A.getNumCols() == X.getNumRows());
155 69 : assert(RHS.getNumCols() == X.getNumCols());
156 :
157 : #ifdef HAVE_ARMADILLO
158 69 : if (!bForceBuiltinMethod)
159 : {
160 : try
161 : {
162 69 : arma::mat matA(A.data(), A.getNumRows(), A.getNumCols(), false,
163 138 : true);
164 69 : arma::mat matRHS(RHS.data(), RHS.getNumRows(), RHS.getNumCols(),
165 138 : false, true);
166 69 : arma::mat matOut(X.data(), X.getNumRows(), X.getNumCols(), false,
167 69 : true);
168 : #if ARMA_VERSION_MAJOR > 6 || \
169 : (ARMA_VERSION_MAJOR == 6 && ARMA_VERSION_MINOR >= 500)
170 : // Perhaps available in earlier versions, but didn't check
171 69 : return arma::solve(matOut, matA, matRHS,
172 69 : arma::solve_opts::equilibrate +
173 138 : arma::solve_opts::no_approx);
174 : #else
175 : return arma::solve(matOut, matA, matRHS);
176 : #endif
177 : }
178 0 : catch (std::exception const &e)
179 : {
180 0 : CPLError(CE_Failure, CPLE_AppDefined, "GDALLinearSystemSolve: %s",
181 0 : e.what());
182 0 : return false;
183 : }
184 : }
185 : #endif // HAVE_ARMADILLO
186 :
187 0 : return solve(A, RHS, X, 0);
188 : }
189 :
190 : /*! @endcond */
|