TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
det_and_inverse.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Harrison LaBollita, Olivier Parcollet, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides functions to compute the determinant and inverse of a matrix.
20 */
21
22#pragma once
23
24#include "../basic_array.hpp"
25#include "../basic_functions.hpp"
26#include "../clef/make_lazy.hpp"
27#include "../concepts.hpp"
28#include "../exceptions.hpp"
29#include "../lapack/getrf.hpp"
30#include "../lapack/getri.hpp"
31#include "../layout/policies.hpp"
32#include "../matrix_functions.hpp"
33#include "../mem/address_space.hpp"
34#include "../mem/policies.hpp"
35#include "../print.hpp"
36#include "../traits.hpp"
37
38#include <iostream>
39#include <type_traits>
40#include <utility>
41
42namespace nda {
43
44 /**
45 * @addtogroup linalg_tools
46 * @{
47 */
48
/**
 * @brief Check if a given array/view is square, i.e. if the first dimension has the same extent as the second
 * dimension.
 *
 * @note It does not check if the array/view has rank 2.
 *
 * @tparam A Array/View type.
 * @param a Array/View object.
 * @param print_error If true, print an error message to `std::cerr` if the matrix is not square.
 * @return True if the array/view is square, false otherwise.
 */
template <typename A>
bool is_matrix_square(A const &a, bool print_error = false) {
  // bind the shape once; works whether shape() returns by value or by reference
  auto const &shp   = a.shape();
  bool const square = (shp[0] == shp[1]);
  if (not square and print_error)
    std::cerr << "Error in nda::is_matrix_square: Dimensions are: (" << shp[0] << "," << shp[1] << ")" << std::endl;
  return square;
}
67
/**
 * @brief Check if a given array/view is diagonal, i.e. if it is square (see nda::is_matrix_square) and all the
 * off-diagonal elements are zero.
 *
 * @details The check compares the array/view with the matrix built from its own diagonal, i.e.
 * `a == diag(diagonal(a))`.
 *
 * @note It does not check if the array/view has rank 2.
 *
 * @tparam A Array/View type.
 * @param a Array/View object.
 * @param print_error If true, print an error message to `std::cerr` if the matrix is not diagonal.
 * @return True if the array/view is diagonal, false otherwise.
 */
template <typename A>
bool is_matrix_diagonal(A const &a, bool print_error = false) {
  // the square check comes first so that the comparison is only done for square matrices
  bool const is_diag = is_matrix_square(a) and a == diag(diagonal(a));
  if (not is_diag and print_error) std::cerr << "Error in nda::is_matrix_diagonal: Non-diagonal matrix: " << a << std::endl;
  return is_diag;
}
85
86 /**
87 * @brief Compute the determinant of a square matrix/view.
88 *
89 * @details It uses nda::lapack::getrf to compute the LU decomposition of the matrix and then calculates the
90 * determinant by multiplying the diagonal elements of the \f$ \mathbf{U} \f$ matrix and taking into account that
91 * `getrf` may change the ordering of the rows/columns of the matrix.
92 *
93 * The given matrix/view is modified in place.
94 *
95 * @tparam M Type of the matrix/view.
96 * @param m Matrix/view object.
97 * @return Determinant of the matrix/view.
98 */
99 template <typename M>
101 requires(is_matrix_or_view_v<M>)
102 {
103 using value_t = get_value_t<M>;
104 static_assert(std::is_convertible_v<value_t, double> or std::is_convertible_v<value_t, std::complex<double>>,
105 "Error in nda::determinant_in_place: Value type needs to be convertible to double or std::complex<double>");
106 static_assert(not std::is_const_v<M>, "Error in nda::determinant_in_place: Value type cannot be const");
107
108 // special case for an empty matrix
109 if (m.empty()) return value_t{1};
110
111 // check if the matrix is square
112 if (m.extent(0) != m.extent(1)) NDA_RUNTIME_ERROR << "Error in nda::determinant_in_place: Matrix is not square: " << m.shape();
113
114 // calculate the LU decomposition using lapack getrf
115 const int dim = m.extent(0);
116 basic_array<int, 1, C_layout, 'A', sso<100>> ipiv(dim);
117 int info = lapack::getrf(m, ipiv); // it is ok to be in C order
118 if (info < 0) NDA_RUNTIME_ERROR << "Error in nda::determinant_in_place: info = " << info;
119
120 // calculate the determinant from the LU decomposition
121 auto det = value_t{1};
122 int n_flips = 0;
123 for (int i = 0; i < dim; i++) {
124 det *= m(i, i);
125 // count the number of column interchanges performed by getrf
126 if (ipiv(i) != i + 1) ++n_flips;
127 }
128
129 return ((n_flips % 2 == 1) ? -det : det);
130 }
131
/**
 * @brief Compute the determinant of a square matrix/view without modifying it.
 *
 * @details A copy of the given matrix/view is created with nda::make_regular and handed to
 * nda::determinant_in_place, so the argument itself is left untouched.
 *
 * @tparam M Type of the matrix/view.
 * @param m Matrix/view object.
 * @return Determinant of the matrix/view.
 */
template <typename M>
auto determinant(M const &m) {
  // the in-place routine overwrites its argument, so it operates on a scratch copy
  auto scratch = make_regular(m);
  return determinant_in_place(scratch);
}
147
// For small matrices (1x1, 2x2 and 3x3), we directly compute the matrix inversion
// rather than calling the LAPACK routines.
//
// ---------- Small Inverse Benchmarks ---------
// Run on (16 X 2400 MHz CPUs) (see benchmarks/small_inv.cpp)
// ---------------------------------------------
// Matrix Size    Time (old)    Time (new)
// 1              502 ns        59.0 ns
// 2              595 ns        61.7 ns
// 3              701 ns        67.5 ns
158
159 /**
160 * @brief Compute the inverse of a 1-by-1 matrix.
161 *
162 * @details The inversion is performed in place.
163 *
164 * @tparam M nda::MemoryMatrix type.
165 * @param m nda::MemoryMatrix object to be inverted.
166 */
167 template <MemoryMatrix M>
168 requires(get_algebra<M> == 'M' and mem::on_host<M>)
169 void inverse1_in_place(M &&m) { // NOLINT (temporary views are allowed here)
170 if (m(0, 0) == 0.0) NDA_RUNTIME_ERROR << "Error in nda::inverse1_in_place: Matrix is not invertible";
171 m(0, 0) = 1.0 / m(0, 0);
172 }
173
174 /**
175 * @brief Compute the inverse of a 2-by-2 matrix.
176 *
177 * @details The inversion is performed in place.
178 *
179 * @tparam M nda::MemoryMatrix type.
180 * @param m nda::MemoryMatrix object to be inverted.
181 */
182 template <MemoryMatrix M>
183 requires(get_algebra<M> == 'M' and mem::on_host<M>)
184 void inverse2_in_place(M &&m) { // NOLINT (temporary views are allowed here)
185 // calculate the adjoint of the matrix
186 std::swap(m(0, 0), m(1, 1));
187
188 // calculate the inverse determinant of the matrix
189 auto det = (m(0, 0) * m(1, 1) - m(0, 1) * m(1, 0));
190 if (det == 0.0) NDA_RUNTIME_ERROR << "Error in nda::inverse2_in_place: Matrix is not invertible";
191 auto detinv = 1.0 / det;
192
193 // multiply the adjoint by the inverse determinant
194 m(0, 0) *= +detinv;
195 m(1, 1) *= +detinv;
196 m(1, 0) *= -detinv;
197 m(0, 1) *= -detinv;
198 }
199
200 /**
201 * @brief Compute the inverse of a 3-by-3 matrix.
202 *
203 * @details The inversion is performed in place.
204 *
205 * @tparam M nda::MemoryMatrix type.
206 * @param m nda::MemoryMatrix object to be inverted.
207 */
208 template <MemoryMatrix M>
209 requires(get_algebra<M> == 'M' and mem::on_host<M>)
210 void inverse3_in_place(M &&m) { // NOLINT (temporary views are allowed here)
211 // calculate the cofactors of the matrix
212 auto b00 = +m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1);
213 auto b10 = -m(1, 0) * m(2, 2) + m(1, 2) * m(2, 0);
214 auto b20 = +m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0);
215 auto b01 = -m(0, 1) * m(2, 2) + m(0, 2) * m(2, 1);
216 auto b11 = +m(0, 0) * m(2, 2) - m(0, 2) * m(2, 0);
217 auto b21 = -m(0, 0) * m(2, 1) + m(0, 1) * m(2, 0);
218 auto b02 = +m(0, 1) * m(1, 2) - m(0, 2) * m(1, 1);
219 auto b12 = -m(0, 0) * m(1, 2) + m(0, 2) * m(1, 0);
220 auto b22 = +m(0, 0) * m(1, 1) - m(0, 1) * m(1, 0);
221
222 // calculate the inverse determinant of the matrix
223 auto det = m(0, 0) * b00 + m(0, 1) * b10 + m(0, 2) * b20;
224 if (det == 0.0) NDA_RUNTIME_ERROR << "Error in nda::inverse3_in_place: Matrix is not invertible";
225 auto detinv = 1.0 / det;
226
227 // fill the matrix by multiplying the cofactors by the inverse determinant
228 m(0, 0) = detinv * b00;
229 m(0, 1) = detinv * b01;
230 m(0, 2) = detinv * b02;
231 m(1, 0) = detinv * b10;
232 m(1, 1) = detinv * b11;
233 m(1, 2) = detinv * b12;
234 m(2, 0) = detinv * b20;
235 m(2, 1) = detinv * b21;
236 m(2, 2) = detinv * b22;
237 }
238
239 /**
240 * @brief Compute the inverse of an n-by-n matrix.
241 *
242 * @details The inversion is performed in place.
243 *
244 * For small matrices (1-by-1, 2-by-2, 3-by-3), we directly compute the matrix inversion using the optimized routines:
245 * nda::inverse1_in_place, nda::inverse2_in_place, nda::inverse3_in_place.
246 *
247 * For larger matrices, it uses nda::lapack::getrf and nda::lapack::getri.
248 *
249 * @tparam M nda::MemoryMatrix type.
250 * @param m nda::MemoryMatrix object to be inverted.
251 */
252 template <MemoryMatrix M>
253 requires(get_algebra<M> == 'M')
254 void inverse_in_place(M &&m) { // NOLINT (temporary views are allowed here)
255 EXPECTS(is_matrix_square(m, true));
256
257 // nothing to do if the matrix/view is empty
258 if (m.empty()) return;
259
260 // use optimized routines for small matrices
261 if constexpr (mem::on_host<M>) {
262 if (m.shape()[0] == 1) {
263 inverse1_in_place(m);
264 return;
265 }
266
267 if (m.shape()[0] == 2) {
268 inverse2_in_place(m);
269 return;
270 }
271
272 if (m.shape()[0] == 3) {
273 inverse3_in_place(m);
274 return;
275 }
276 }
277
278 // use getrf and getri from lapack for larger matrices
279 array<int, 1> ipiv(m.extent(0));
280 int info = lapack::getrf(m, ipiv); // it is ok to be in C order
281 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::inverse_in_place: Matrix is not invertible: info = " << info;
282 info = lapack::getri(m, ipiv);
283 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::inverse_in_place: Matrix is not invertible: info = " << info;
284 }
285
286 /**
287 * @brief Compute the inverse of an n-by-n matrix.
288 *
289 * @details The given matrix/view is not modified. It first makes copy of the given matrix/view and then calls
290 * nda::inverse_in_place with the copy.
291 *
292 * @tparam M nda::MemoryMatrix type.
293 * @param m nda::MemoryMatrix object to be inverted.
294 * @return Inverse of the matrix.
295 */
296 template <Matrix M>
297 auto inverse(M const &m)
298 requires(get_algebra<M> == 'M')
299 {
300 EXPECTS(is_matrix_square(m, true));
301 auto r = make_regular(m);
302 inverse_in_place(r);
303 return r;
304 }
305
306 /** @} */
307
308} // namespace nda
309
310namespace nda::clef {
311
312 /**
313 * @ingroup linalg_tools
314 * @brief Lazy version of nda::determinant.
315 */
317
318} // namespace nda::clef
A generic multi-dimensional array.
#define NDA_RUNTIME_ERROR
#define CLEF_MAKE_FNT_LAZY(name)
Macro to make any function lazy, i.e. accept lazy arguments and return a function call expression nod...
auto determinant_in_place(M &m)
Compute the determinant of a square matrix/view.
void inverse3_in_place(M &&m)
Compute the inverse of a 3-by-3 matrix.
void inverse1_in_place(M &&m)
Compute the inverse of a 1-by-1 matrix.
void inverse2_in_place(M &&m)
Compute the inverse of a 2-by-2 matrix.
auto determinant(M const &m)
Compute the determinant of a square matrix/view.
auto inverse(M const &m)
Compute the inverse of an n-by-n matrix.
void inverse_in_place(M &&m)
Compute the inverse of an n-by-n matrix.
bool is_matrix_diagonal(A const &a, bool print_error=false)
Check if a given array/view is diagonal, i.e. if it is square (see nda::is_matrix_square) and all the...
bool is_matrix_square(A const &a, bool print_error=false)
Check if a given array/view is square, i.e. if the first dimension has the same extent as the second ...
#define EXPECTS(X)
Definition macros.hpp:59
Contiguous layout policy with C-order (row-major order).
Definition policies.hpp:47
Memory policy using an nda::mem::handle_sso.
Definition policies.hpp:70