TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
inv.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../basic_array.hpp"
15#include "../blas/tools.hpp"
16#include "../concepts.hpp"
17#include "../exceptions.hpp"
18#include "../lapack/getrf.hpp"
19#include "../lapack/getri.hpp"
20#include "../lapack/getrs.hpp"
22#include "../macros.hpp"
25#include "../mem/policies.hpp"
26#include "../traits.hpp"
27
28#include <utility>
29
30namespace nda::linalg {
31
36
37 namespace detail {
38
39 // Compute the inverse of a 2x2 matrix in place.
40 void inv_in_place_2d(MemoryMatrix auto &&m) { // NOLINT (temporary views are allowed here)
41 using value_t = get_value_t<decltype(m)>;
42
43 // calculate the determinant of the matrix
44 auto const det = (m(0, 0) * m(1, 1) - m(0, 1) * m(1, 0));
45 if (det == value_t{0.0}) NDA_RUNTIME_ERROR << "Error in nda::linalg::inv_in_place: Matrix is not invertible";
46 auto const detinv = value_t{1.0} / det;
47
48 // multiply the adjoint by the inverse determinant
49 std::swap(m(0, 0), m(1, 1));
50 m(0, 0) *= +detinv;
51 m(1, 1) *= +detinv;
52 m(1, 0) *= -detinv;
53 m(0, 1) *= -detinv;
54 }
55
56 // Compute the inverse of a 3x3 matrix in place.
57 void inv_in_place_3d(MemoryMatrix auto &&m) { // NOLINT (temporary views are allowed here)
58 using value_t = get_value_t<decltype(m)>;
59 EXPECTS(is_matrix_square(m) and m.extent(0) == 3);
60
61 // calculate the adjoint of the matrix
62 auto adj = stack_array<get_value_t<decltype(m)>, 3, 3>();
63 adj(0, 0) = +m(1, 1) * m(2, 2) - m(1, 2) * m(2, 1);
64 adj(1, 0) = -m(1, 0) * m(2, 2) + m(1, 2) * m(2, 0);
65 adj(2, 0) = +m(1, 0) * m(2, 1) - m(1, 1) * m(2, 0);
66 adj(0, 1) = -m(0, 1) * m(2, 2) + m(0, 2) * m(2, 1);
67 adj(1, 1) = +m(0, 0) * m(2, 2) - m(0, 2) * m(2, 0);
68 adj(2, 1) = -m(0, 0) * m(2, 1) + m(0, 1) * m(2, 0);
69 adj(0, 2) = +m(0, 1) * m(1, 2) - m(0, 2) * m(1, 1);
70 adj(1, 2) = -m(0, 0) * m(1, 2) + m(0, 2) * m(1, 0);
71 adj(2, 2) = +m(0, 0) * m(1, 1) - m(0, 1) * m(1, 0);
72
73 // calculate the determinant of the matrix
74 auto const det = m(0, 0) * adj(0, 0) + m(0, 1) * adj(1, 0) + m(0, 2) * adj(2, 0);
75 if (det == value_t{0.0}) NDA_RUNTIME_ERROR << "Error in nda::linalg::inv_in_place: Matrix is not invertible";
76 auto const detinv = value_t{1.0} / det;
77
78 // multiply the adjoint by the inverse determinant
79 m = detinv * adj;
80 }
81
82 } // namespace detail
83
99 template <blas_lapack::BlasArray<2> M>
101 void inv_in_place(M &&m) { // NOLINT (temporary views are allowed here)
102 using value_t = get_value_t<M>;
103 EXPECTS(is_matrix_square(m));
104
105 // use optimized routines for small matrices, otherwise use LAPACK routines
106 auto const dim = m.shape()[0];
107 if (dim == 1) {
108 if (m(0, 0) == value_t{0.0}) NDA_RUNTIME_ERROR << "Error in nda::linalg::inv_in_place: Matrix is not invertible";
109 m(0, 0) = value_t{1.0} / m(0, 0);
110 } else if (dim == 2) {
111 detail::inv_in_place_2d(m);
112 } else if (dim == 3) {
113 detail::inv_in_place_3d(m);
114 } else if (dim > 3) {
115 // LU factorization with getrf
117 int info = lapack::getrf(m, ipiv);
118 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::linalg::inv_in_place: getrf routine failed: info = " << info;
119
120 // calculate the inverse with getri
121 info = lapack::getri(m, ipiv);
122 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::linalg::inv_in_place: getri routine failed: info = " << info;
123 }
124 }
125
// Compute the inverse of a square matrix/view `m` and return it. `m` itself is not modified: all work is done on a
// regular copy created with make_regular.
// - Device compatible address spaces (per the comment below): the copy is LU factorized with lapack::getrf (pivot
//   vector allocated with a heap policy for M's address space), then lapack::getrs solves against the identity and
//   the solution `B` is returned.
// - Otherwise: the copy is inverted with inv_in_place and returned.
// Errors: EXPECTS checks that `m` is square; a non-zero info code from getrf/getrs raises NDA_RUNTIME_ERROR.
// NOTE(review): this rendered listing is missing original source lines 149, 157 and 164, so the block does not
// compile as shown.
// - Line 157 presumably holds the `if constexpr (...)` that the `} else {` below closes.
// - Line 164 presumably declares `B`, which per the comment "with getrs and the identity matrix" starts out as the
//   identity.
// TODO: restore these lines from the upstream inv.hpp.
148 template <Matrix M>
150 auto inv(M const &m) {
151 EXPECTS(is_matrix_square(m));
152
153 // make a copy of the input matrix/view
154 auto m_copy = make_regular(m);
155
156 // for device compatible address spaces, we use getrf and getrs, otherwise we call inv_in_place
158 // LU factorization with getrf
159 auto ipiv = vector<int, heap<mem::get_addr_space<M>>>(m_copy.extent(0));
160 int info = lapack::getrf(m_copy, ipiv);
161 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::linalg::inv: getrf routine failed: info = " << info;
162
163 // calculate the inverse with getrs and the identity matrix
// getrs overwrites B with the solution of (LU) X = B, i.e. the inverse when B is the identity
165 info = lapack::getrs(m_copy, B, ipiv);
166 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::linalg::inv: getrs routine failed: info = " << info;
167 return B;
168 } else {
// non-device path: invert the regular copy in place and return it
169 inv_in_place(m_copy);
170 return m_copy;
171 }
172 }
173
179
180
181
182} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
void swap(nda::basic_array_view< V1, R1, LP1, A1, AP1, OP1 > &a, nda::basic_array_view< V2, R2, LP2, A2, AP2, OP2 > &b)=delete
std::swap is deleted for nda::basic_array_view.
Provides basic functions to create and manipulate arrays and views.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Provides a generic interface to the LAPACK/cuSOLVER getrf routine.
Provides a generic interface to the LAPACK getri routine.
Provides a generic interface to the LAPACK/cuSOLVER getrs routine.
auto eye(Int dim)
Create an identity nda::matrix with ones on the diagonal.
decltype(auto) make_regular(A &&a)
Make a given object regular.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
bool is_matrix_square(A const &a, bool print_error=false)
Check if a given matrix is square, i.e. if the first dimension has the same extent as the second dimension.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
nda::basic_array< ValueType, 1+sizeof...(Ns), nda::basic_layout< nda::static_extents(N0, Ns...), nda::C_stride_order< 1+sizeof...(Ns)>, nda::layout_prop_e::contiguous >, 'A', nda::stack< N0 *(Ns *... *1)> > stack_array
Alias template of an nda::basic_array with static extents, contiguous C layout, 'A' algebra and nda::...
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
constexpr char get_algebra
Constexpr variable that specifies the algebra of a type.
Definition traits.hpp:137
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
#define CLEF_MAKE_FNT_LAZY(name)
Macro to make any function lazy, i.e. accept lazy arguments and return a function call expression node.
Definition make_lazy.hpp:89
int getri(A &&a, IPIV const &ipiv, W &&work=vector_value_t< A >{})
Interface to the LAPACK getri routine.
Definition getri.hpp:53
int getrf(A &&a, IPIV &&ipiv, W &&work=vector_value_t< A >{})
Interface to the LAPACK/cuSOLVER getrf routine.
Definition getrf.hpp:61
int getrs(A const &a, B &&b, IPIV const &ipiv)
Interface to the LAPACK/cuSOLVER getrs routine.
Definition getrs.hpp:56
auto det(M const &m)
Compute the determinant of an \( n \times n \) matrix \( \mathbf{M} \).
Definition det.hpp:99
void inv_in_place(M &&m)
Compute the inverse of an \( n \times n \) matrix \( \mathbf{M} \) in place.
Definition inv.hpp:101
auto inv(M const &m)
Compute the inverse of an \( n \times n \) matrix \( \mathbf{M} \).
Definition inv.hpp:150
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr bool is_blas_lapack_v
Constexpr variable that is true if type T is either of type `float`, `double`, `std::complex<float>` or `std::complex<double>`.
Definition traits.hpp:95
Provides definitions of various layout policies.
Macros used in the nda library.
Provides functions to create and manipulate matrices, i.e. arrays/view with 'M' algebra.
Defines various memory handling policies.
Contiguous layout policy with Fortran-order (column-major order).
Definition policies.hpp:52
Provides type traits for the nda library.