TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
lu.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell
16
21
22#pragma once
23
24#include "./utils.hpp"
25#include "../basic_array.hpp"
27#include "../blas/tools.hpp"
28#include "../concepts.hpp"
29#include "../declarations.hpp"
30#include "../exceptions.hpp"
31#include "../lapack/getrf.hpp"
33#include "../layout/range.hpp"
34#include "../macros.hpp"
36#include "../traits.hpp"
37
#include <algorithm>
#include <tuple>
#include <type_traits>
#include <utility>
41
42namespace nda::linalg {
43
48
61 template <typename LP = F_layout, MemoryMatrix A>
63 auto get_lu_matrices(A const &a) {
64 using output_t = matrix<get_value_t<A>, LP>;
65
66 // copy the first k columns from A to L and U
67 auto const [m, n] = a.shape();
68 auto const k = std::min(m, n);
69 auto L = output_t::zeros(m, k);
70 auto U = output_t::zeros(k, n);
71 for (int i = 0; i < k; ++i) {
72 L(i, i) = get_value_t<A>{1};
73 L(range(i + 1, m), i) = a(range(i + 1, m), i);
74 U(range(i + 1), i) = a(range(i + 1), i);
75 }
76
77 // in case of n > m, copy the remaining columns to U
78 for (int i = k; i < n; ++i) U(range::all, i) = a(range::all, i);
79
80 return std::make_tuple(L, U);
81 }
82
110 template <typename LP = F_layout, MemoryMatrix A>
112 auto lu_in_place(A &&a, bool allow_singular = false) { // NOLINT (temporary views are allowed here)
113 // input, output types and static assertions
114
115 // pivot indices vector
116 auto ipiv = vector<int>{};
117
118 // call lapack getrf
119 int info = lapack::getrf(a, ipiv);
120 if (info < 0) {
121 NDA_RUNTIME_ERROR << "Error in nda::lu_in_place: getrf failed with invalid argument (info = " << info << ")";
122 } else if (info > 0 and not allow_singular) {
123 NDA_RUNTIME_ERROR << "Error in nda::lu_in_place: Matrix is singular, U(" << info << "," << info << ") is exactly zero";
124 }
125
126 // extract sigma, L, U from the output of getrf
127 auto sigma = get_permutation_vector(ipiv, a.extent(0));
128 auto [L, U] = get_lu_matrices<LP>(a);
129
130 return std::make_tuple(sigma, L, U);
131 }
132
146 template <Matrix A>
148 auto lu(A const &a, bool allow_singular = false) {
149 auto a_copy = matrix<get_value_t<A>, F_layout>(a);
150 if constexpr (nda::blas::has_F_layout<A>) {
151 return lu_in_place(a_copy, allow_singular);
152 } else {
153 return lu_in_place<C_layout>(a_copy, allow_singular);
154 }
155 }
156
158
159} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Provides a generic interface to the LAPACK getrf routine.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:182
static constexpr bool has_F_layout
Constexpr variable that is true if the given nda::Array type has nda::F_layout.
Definition tools.hpp:73
int getrf(A &&a, IPIV &&ipiv)
Interface to the LAPACK getrf routine.
Definition getrf.hpp:58
auto lu_in_place(A &&a, bool allow_singular=false)
Compute the LU factorization of a matrix in place.
Definition lu.hpp:112
auto get_permutation_vector(Vector auto const &ipiv, int m)
Get the permutation vector from the pivot indices returned by nda::lapack::getrf or other LAPACK routines.
Definition utils.hpp:56
auto lu(A const &a, bool allow_singular=false)
Compute the LU factorization of a matrix.
Definition lu.hpp:148
auto get_lu_matrices(A const &a)
Get the L and U matrices from the output of nda::lapack::getrf.
Definition lu.hpp:63
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:92
Provides definitions of various layout policies.
Provides utility functions for the nda::linalg namespace.
Macros used in the nda library.
Includes the itertools header and provides some additional utilities.
Contiguous layout policy with Fortran-order (column-major order).
Definition policies.hpp:52
Provides various traits and utilities for the BLAS interface.
Provides type traits for the nda library.