TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
matvecmul.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../basic_array.hpp"
15#include "../blas/gemv.hpp"
16#include "../blas/tools.hpp"
17#include "../concepts.hpp"
18#include "../declarations.hpp"
19#include "../exceptions.hpp"
22#include "../mem/policies.hpp"
23#include "../traits.hpp"
24
25#include <type_traits>
26#include <utility>
27
28namespace nda::linalg {
29
34
35 namespace detail {
36
37 // Generic matrix-vector multiplication for types not supported by BLAS.
38 template <Matrix A, Vector X, MemoryVector Y>
40 void gemv_generic(auto alpha, A const &a, X const &x, auto beta, Y &&y) { // NOLINT (temporary views are allowed here)
41 // check the dimensions of the input/output arrays/views
42 auto const [m, n] = a.shape();
43 EXPECTS(n == x.size());
44 EXPECTS(m == y.size());
45
46 // perform the matrix-vector multiplication
47 for (int i = 0; i < m; ++i) {
48 y(i) = beta * y(i);
49 for (int j = 0; j < n; ++j) y(i) += alpha * a(i, j) * x(j);
50 }
51 }
52
53 // Make compile time checks if blas::gemv can handle the given input vector. If it can, simply forward the vector.
54 // Otherwise, return a copy with the given value type T and container policy CP.
55 template <typename T, typename CP, Vector X>
56 decltype(auto) get_gemv_vector(X &&x) {
57 if constexpr (std::is_same_v<get_value_t<X>, T> and MemoryVector<X>) {
58 return std::forward<X>(x);
59 } else {
60 return vector<T, CP>{x};
61 }
62 }
63
64 // Make compile time checks if blas::gemv can handle the given input matrix. If it can, simply forward the matrix.
65 // Otherwise, return a copy with the given value type T and container policy CP.
// The matrix is forwarded unchanged only if all of the following hold:
//   - blas::get_array can be applied to it,
//   - its value type is exactly T,
//   - it is either a MemoryMatrix or a conjugate lazy expression (blas::is_conj_array_expr) with C layout.
// NOTE(review): in this listing both `else` branches below are empty. The original source lines 72 and 75
// (presumably the `return <copy of a>` statements) are missing here. TODO: restore them from the upstream file.
// As shown, those branches fall off the end, so decltype(auto) would deduce `void` for such inputs.
66 template <typename T, typename CP, Matrix A>
67 decltype(auto) get_gemv_matrix(A &&a) {
68 if constexpr (requires { blas::get_array(a); } and std::is_same_v<get_value_t<A>, T>) {
69 if constexpr (MemoryMatrix<A> or (blas::is_conj_array_expr<A> and blas::has_C_layout<A>)) {
// directly usable by blas::gemv: forward without copying
70 return std::forward<A>(a);
71 } else {
// NOTE(review): copy branch; its return statement (original line 72) is missing in this view
73 }
74 } else {
// NOTE(review): copy branch for inputs not handled by blas::get_array or with a different value type;
// its return statement (original line 75) is missing in this view
76 }
77 }
78
79 // Make the call to nda::blas::gemv with a copy of the matrix if it is not contiguous.
80 template <Matrix A, Vector X, MemoryVector Y>
81 void make_gemv_call(A const &a, X const &x, Y &y) {
82 if (blas::get_array(a).is_contiguous()) {
83 blas::gemv(1, a, x, 0, y);
84 } else {
85 blas::gemv(1, nda::make_regular(a), x, 0, y);
86 }
87 }
88
89 } // namespace detail
90
117 template <Matrix A, Vector X>
119 auto matvecmul(A const &a, X const &x) {
120 // get the return type
121 using value_t = decltype(a(0, 0) * x(0));
122 using cont_pol = heap<mem::common_addr_space<A, X>>;
123 using return_t = vector<value_t, cont_pol>;
124
125 // result vector (MSAN complains if it is not initialized)
126 auto res = return_t(a.shape()[0]);
127#if defined(__has_feature)
128#if __has_feature(memory_sanitizer)
129 res = 0;
130#endif
131#endif
132
133 // perform matrix-vector multiplication (if possible we try to call blas::gemv even if this requires making copies)
134 if constexpr (is_blas_lapack_v<value_t>) {
135 auto &&a_mat = detail::get_gemv_matrix<value_t, cont_pol>(a);
136 auto &&x_vec = detail::get_gemv_vector<value_t, cont_pol>(x);
137
138 // check at runtime if the input matrix is contiguous, make copies if not and call blas::gemv
139 detail::make_gemv_call(a_mat, x_vec, res);
140 } else {
141 detail::gemv_generic(1, a, x, 0, res);
142 }
143 return res;
144 }
145
147
148} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Provides a generic interface to the BLAS gemv routine.
decltype(auto) make_regular(A &&a)
Make a given object regular.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
static constexpr bool has_C_layout
Constexpr variable that is true if the given nda::Array type has nda::C_layout.
Definition tools.hpp:83
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an nda::MemoryArray.
Definition tools.hpp:62
void gemv(get_value_t< A > alpha, A const &a, X const &x, get_value_t< A > beta, Y &&y)
Interface to the BLAS gemv routine.
Definition gemv.hpp:57
auto matvecmul(A const &a, X const &x)
Compute the matrix-vector product of an nda::matrix and an nda::vector object.
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:92
Provides definitions of various layout policies.
Defines various memory handling policies.
Provides various traits and utilities for the BLAS interface.
Provides type traits for the nda library.