TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
solve.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell
16
21
22#pragma once
23
24#include "../basic_array.hpp"
25#include "../blas/tools.hpp"
26#include "../concepts.hpp"
27#include "../declarations.hpp"
28#include "../exceptions.hpp"
29#include "../lapack/getrf.hpp"
30#include "../lapack/getrs.hpp"
32#include "../macros.hpp"
34#include "../mem/policies.hpp"
35#include "../traits.hpp"
36
37#include <type_traits>
38
39namespace nda::linalg {
40
45
73 template <MemoryMatrix A, MemoryArray B>
75 void solve_in_place(A &&a, B &&b) { // NOLINT (temporary views are allowed here)
76 static_assert(get_rank<B> == 1 || get_rank<B> == 2, "Error in nda::linalg::solve_in_place: Right hand side must have rank 1 or 2");
77 static_assert(nda::blas::has_F_layout<B>, "Error in nda::linalg::solve_in_place: Right hand side must have Fortran layout");
78
79 constexpr auto addr_space = nda::mem::common_addr_space<A, B>;
80
81 // check the dimensions of the input/output arrays/views
82 EXPECTS_WITH_MESSAGE(a.shape()[0] == a.shape()[1], "Error in nda::linalg::solve_in_place: Matrix A is not square");
83 EXPECTS_WITH_MESSAGE(a.shape()[0] == b.extent(0), "Error in nda::linalg::solve_in_place: Dimension mismatch between matrix A and B");
84
85 // pivot indices vector
86 auto ipiv = vector<int, heap<addr_space>>(a.extent(0));
87
88 // call lapack getrf to compute LU factorization
89 int info = lapack::getrf(a, ipiv);
90 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::linalg::solve_in_place: getrf returned a non-zero value: info = " << info;
91
92 // call lapack getrs to solve using the LU factorization
93 info = lapack::getrs(a, b, ipiv);
94 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::linalg::solve_in_place: getrs returned a non-zero value: info = " << info;
95 }
96
124 template <Matrix A, Array B>
126 auto solve(A const &a, B const &b) { // NOLINT (temporary views are allowed here)
127 // copy A and preserve its layout
128 using a_layout_policy = nda::detail::layout_to_policy<typename std::remove_cvref_t<A>::layout_t>::type;
129 auto a_copy = matrix<get_value_t<A>, a_layout_policy, heap<nda::mem::common_addr_space<A, B>>>(a);
130
131 // copy B and enforce Fortran layout for the matrix case
134 using b_type = std::conditional_t<get_rank<B> == 1, vector_t, matrix_t>;
135 auto b_copy = b_type(b);
136
137 // call solve_in_place with the copies
138 solve_in_place(a_copy, b_copy);
139 return b_copy;
140 }
141
143
144} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Provides a generic interface to the LAPACK getrf routine.
Provides a generic interface to the LAPACK getrs routine.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
constexpr bool have_same_value_type_v
Constexpr variable that is true if all types in As have the same value type as A0.
Definition traits.hpp:186
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:126
static constexpr bool has_F_layout
Constexpr variable that is true if the given nda::Array type has nda::F_layout.
Definition tools.hpp:73
int getrf(A &&a, IPIV &&ipiv)
Interface to the LAPACK getrf routine.
Definition getrf.hpp:58
int getrs(A const &a, B &&b, IPIV const &ipiv)
Interface to the LAPACK getrs routine.
Definition getrs.hpp:53
auto solve(A const &a, B const &b)
Solve a system of linear equations.
Definition solve.hpp:126
void solve_in_place(A &&a, B &&b)
Solve a system of linear equations in-place.
Definition solve.hpp:75
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
constexpr AddressSpace common_addr_space
Get common address space for a number of given nda::Array types.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:92
Provides definitions of various layout policies.
Macros used in the nda library.
Defines various memory handling policies.
Contiguous layout policy with Fortran-order (column-major order).
Definition policies.hpp:52
Provides various traits and utilities for the BLAS interface.
Provides type traits for the nda library.