TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
matvecmul.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../basic_array.hpp"
15#include "../blas/gemv.hpp"
16#include "../blas/tools.hpp"
17#include "../concepts.hpp"
18#include "../declarations.hpp"
19#include "../exceptions.hpp"
22#include "../mem/policies.hpp"
23#include "../traits.hpp"
24
25#include <type_traits>
26#include <utility>
27
28namespace nda::linalg {
29
34
35 namespace detail {
36
37 // Generic matrix-vector multiplication for types not supported by BLAS.
38 template <Matrix A, Vector X, MemoryVector Y>
40 void gemv_generic(auto alpha, A const &a, X const &x, auto beta, Y &&y) { // NOLINT (temporary views are allowed here)
41 // check the dimensions of the input/output arrays/views
42 auto const [m, n] = a.shape();
43 EXPECTS(n == x.size());
44 EXPECTS(m == y.size());
45
46 // perform the matrix-vector multiplication
47 for (int i = 0; i < m; ++i) {
48 y(i) = beta * y(i);
49 for (int j = 0; j < n; ++j) y(i) += alpha * a(i, j) * x(j);
50 }
51 }
52
53 // Make compile time checks if blas::gemv can handle the given input vector. If it can, simply forward the vector.
54 // Otherwise, return a copy with the given value type T and container policy CP.
55 template <typename T, typename CP, Vector X>
56 decltype(auto) get_gemv_vector(X &&x) {
57 if constexpr (std::is_same_v<get_value_t<X>, T> and MemoryVector<X>) {
58 return std::forward<X>(x);
59 } else {
60 return vector<T, CP>{x};
61 }
62 }
63
64 // Make compile time checks if blas::gemv can handle the given input matrix. If it can, simply forward the matrix.
65 // Otherwise, return a copy with the given value type T and container policy CP.
66 template <typename T, typename CP, Matrix A>
67 decltype(auto) get_gemv_matrix(A &&a) {
68 using namespace blas_lapack;
69 if constexpr (requires { get_array(a); } and std::is_same_v<get_value_t<A>, T>) {
70 if constexpr (MemoryMatrix<A> or (is_conj_array_expr<A> and has_C_layout<A>)) {
71 return std::forward<A>(a);
72 } else {
74 }
75 } else {
77 }
78 }
79
80 // Make the call to nda::blas::gemv with a copy of the matrix if it is not contiguous.
81 template <Matrix A, Vector X, MemoryVector Y>
82 void make_gemv_call(A const &a, X const &x, Y &y) {
83 if (blas_lapack::get_array(a).is_contiguous()) {
84 blas::gemv(1, a, x, 0, y);
85 } else {
86 blas::gemv(1, make_regular(a), x, 0, y);
87 }
88 }
89
90 } // namespace detail
91
128 template <Matrix A, Vector X>
130 auto matvecmul(A const &a, X const &x) {
131 // get the return type
132 using value_t = decltype(a(0, 0) * x(0));
133 using cont_pol = heap<mem::common_addr_space<A, X>>;
134 using return_t = vector<value_t, cont_pol>;
135
136 // result vector (MSAN complains if it is not initialized)
137 auto res = return_t(a.shape()[0]);
138#if defined(__has_feature)
139#if __has_feature(memory_sanitizer)
140 res = 0;
141#endif
142#endif
143
144 // perform matrix-vector multiplication (if possible we try to call blas::gemv even if this requires making copies)
145 if constexpr (is_blas_lapack_v<value_t>) {
146 auto &&a_mat = detail::get_gemv_matrix<value_t, cont_pol>(a);
147 auto &&x_vec = detail::get_gemv_vector<value_t, cont_pol>(x);
148
149 // check at runtime if the input matrix is contiguous, make copies if not and call blas::gemv
150 detail::make_gemv_call(a_mat, x_vec, res);
151 } else {
152 detail::gemv_generic(1, a, x, 0, res);
153 }
154 return res;
155 }
156
158
159} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Provides a generic interface to the BLAS/cuBLAS gemv routine.
decltype(auto) make_regular(A &&a)
Make a given object regular.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an nda::MemoryArray.
Definition tools.hpp:68
void gemv(get_value_t< A > alpha, A const &a, X const &x, get_value_t< A > beta, Y &&y)
Interface to the BLAS/cuBLAS gemv routine.
Definition gemv.hpp:55
auto matvecmul(A const &a, X const &x)
Compute the matrix-vector product of an nda::Matrix and an nda::Vector object.
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr bool is_blas_lapack_v
Constexpr variable that is true if type T is either of type float, double, std::complex<float> or std::complex<double>.
Definition traits.hpp:95
Provides definitions of various layout policies.
Defines various memory handling policies.
Provides type traits for the nda library.