TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
solve.hpp
Go to the documentation of this file.
// Copyright (c) 2019--present, The Simons Foundation
// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
// SPDX-License-Identifier: Apache-2.0
// See LICENSE in the root of this distribution for details.

#pragma once

#include "../basic_array.hpp"
#include "../blas/tools.hpp"
#include "../concepts.hpp"
#include "../declarations.hpp"
#include "../exceptions.hpp"
#include "../lapack/getrf.hpp"
#include "../lapack/getrs.hpp"
#include "../layout/policies.hpp"
#include "../macros.hpp"
#include "../mem/address_space.hpp"
#include "../mem/policies.hpp"
#include "../traits.hpp"

#include <type_traits>
27
28namespace nda::linalg {
29
34
60 template <blas_lapack::BlasArray<2> A, blas_lapack::BlasArrayFor<A> B>
61 requires((get_rank<B> == 1 || get_rank<B> == 2) and blas_lapack::has_F_layout<B>)
62 void solve_in_place(A &&a, B &&b) { // NOLINT (temporary views are allowed here)
63 // check the dimensions of the input/output arrays/views
64 EXPECTS_WITH_MESSAGE(a.extent(0) == a.extent(1), "Error in nda::linalg::solve_in_place: Matrix A is not square");
65 EXPECTS_WITH_MESSAGE(a.extent(0) == b.extent(0), "Error in nda::linalg::solve_in_place: Dimension mismatch between matrix A and B");
66
67 // pivot indices vector
68 auto ipiv = vector<int, heap<mem::common_addr_space<A, B>>>(a.extent(0));
69
70 // call getrf to compute LU factorization
71 int info = lapack::getrf(a, ipiv);
72 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::linalg::solve_in_place: getrf returned a non-zero value: info = " << info;
73
74 // call getrs to solve AX=B using the LU factorization
75 info = lapack::getrs(a, b, ipiv);
76 if (info != 0) NDA_RUNTIME_ERROR << "Error in nda::linalg::solve_in_place: getrs returned a non-zero value: info = " << info;
77 }
78
99 template <Matrix A, Array B>
101 auto solve(A const &a, B const &b) { // NOLINT (temporary views are allowed here)
102 // copy A and preserve its layout
103 using a_layout_policy = nda::detail::layout_to_policy<typename std::remove_cvref_t<A>::layout_t>::type;
104 auto a_copy = matrix<get_value_t<A>, a_layout_policy, heap<mem::common_addr_space<A, B>>>(a);
105
106 // copy B and enforce Fortran layout for the matrix case
109 using b_type = std::conditional_t<get_rank<B> == 1, vector_t, matrix_t>;
110 auto b_copy = b_type(b);
111
112 // call solve_in_place with the copies
113 solve_in_place(a_copy, b_copy);
114 return b_copy;
115 }
116
118
119} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Provides a generic interface to the LAPACK/cuSOLVER getrf routine.
Provides a generic interface to the LAPACK/cuSOLVER getrs routine.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
constexpr bool have_same_value_type_v
Constexpr variable that is true if all types in As have the same value type as A0.
Definition traits.hpp:225
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:147
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
Definition tools.hpp:79
int getrf(A &&a, IPIV &&ipiv, W &&work=vector_value_t< A >{})
Interface to the LAPACK/cuSOLVER getrf routine.
Definition getrf.hpp:61
int getrs(A const &a, B &&b, IPIV const &ipiv)
Interface to the LAPACK/cuSOLVER getrs routine.
Definition getrs.hpp:56
void solve_in_place(A &&a, B &&b)
Solve a system of linear equations in place.
Definition solve.hpp:62
auto solve(A const &a, B const &b)
Solve a system of linear equations.
Definition solve.hpp:101
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr bool is_blas_lapack_v
Constexpr variable that is true if type `T` is either of type `float`, `double`, `std::complex<float>` or `std::complex<double>`.
Definition traits.hpp:95
Provides definitions of various layout policies.
Macros used in the nda library.
Defines various memory handling policies.
Contiguous layout policy with Fortran-order (column-major order).
Definition policies.hpp:52
Provides type traits for the nda library.