TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
lu.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./utils.hpp"
14#include "../basic_array.hpp"
16#include "../blas/tools.hpp"
17#include "../concepts.hpp"
18#include "../declarations.hpp"
19#include "../exceptions.hpp"
20#include "../lapack/getrf.hpp"
22#include "../layout/range.hpp"
23#include "../macros.hpp"
25#include "../traits.hpp"
26
27#include <algorithm>
28#include <tuple>
29#include <type_traits>
#include <utility>
30
31namespace nda::linalg {
32
37
53 template <typename LP = F_layout, MemoryMatrix A>
55 auto get_lu_matrices(A const &a) {
56 using output_t = matrix<get_value_t<A>, LP>;
57
58 // copy the first k columns from A to L and U
59 auto const [m, n] = a.shape();
60 auto const k = std::min(m, n);
61 auto L = output_t::zeros(m, k);
62 auto U = output_t::zeros(k, n);
63 for (int i = 0; i < k; ++i) {
64 L(i, i) = get_value_t<A>{1};
65 L(range(i + 1, m), i) = a(range(i + 1, m), i);
66 U(range(i + 1), i) = a(range(i + 1), i);
67 }
68
69 // in case of n > m, copy the remaining columns to U
70 for (int i = k; i < n; ++i) U(range::all, i) = a(range::all, i);
71
72 return std::make_tuple(L, U);
73 }
74
105 template <typename LP = F_layout, blas_lapack::BlasArray<2> A>
107 auto lu_in_place(A &&a, bool allow_singular = false) { // NOLINT (temporary views are allowed here)
108 // pivot indices vector
109 auto ipiv = vector<int>{};
110
111 // call getrf
112 int info = lapack::getrf(a, ipiv);
113 if (info < 0) {
114 NDA_RUNTIME_ERROR << "Error in nda::linalg::lu_in_place: getrf failed with invalid argument (info = " << info << ")";
115 } else if (info > 0 and not allow_singular) {
116 NDA_RUNTIME_ERROR << "Error in nda::linalg::lu_in_place: Matrix is singular, U(" << info << "," << info << ") is exactly zero";
117 }
118
119 // extract sigma, L, U from the output of getrf
120 auto sigma = get_permutation_vector(ipiv, a.extent(0));
121 auto [L, U] = get_lu_matrices<LP>(a);
122
123 return std::make_tuple(sigma, L, U);
124 }
125
143 template <Matrix A>
145 auto lu(A const &a, bool allow_singular = false) {
146 auto a_copy = matrix<get_value_t<A>, F_layout>(a);
147 if constexpr (blas_lapack::has_F_layout<A>) {
148 return lu_in_place(a_copy, allow_singular);
149 } else {
150 return lu_in_place<C_layout>(a_copy, allow_singular);
151 }
152 }
153
155
156} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_vi...
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Provides a generic interface to the LAPACK/cuSOLVER getrf routine.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
Definition tools.hpp:79
auto get_lu_matrices(A const &a)
Get the \(\mathbf{L}\) and \(\mathbf{U}\) matrices from the output of nda::lapack::getrf.
Definition lu.hpp:55
auto lu_in_place(A &&a, bool allow_singular=false)
Compute the LU factorization of a matrix in place.
Definition lu.hpp:107
auto lu(A const &a, bool allow_singular=false)
Compute the LU factorization of a matrix.
Definition lu.hpp:145
int getrf(A &&a, IPIV &&ipiv, W &&work=vector_value_t< A >{})
Interface to the LAPACK/cuSOLVER getrf routine.
Definition getrf.hpp:61
auto get_permutation_vector(Vector auto const &ipiv, int m)
Get the permutation vector from the pivot indices returned by nda::lapack::getrf or other LAPACK rou...
Definition utils.hpp:47
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
constexpr bool is_blas_lapack_v
Constexpr variable that is true if type T is either of type 'float', 'double', 'std::complex<float>' or ...
Definition traits.hpp:95
Provides definitions of various layout policies.
Provides utility functions for the nda::linalg namespace.
Macros used in the nda library.
Includes the itertools header and provides some additional utilities.
Contiguous layout policy with Fortran-order (column-major order).
Definition policies.hpp:52
Provides type traits for the nda library.