TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
matmul.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Olivier Parcollet, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides matrix-matrix and matrix-vector multiplication.
20 */
21
22#pragma once
23
24#include "../basic_functions.hpp"
25#include "../blas/gemm.hpp"
26#include "../blas/gemv.hpp"
27#include "../blas/tools.hpp"
28#include "../concepts.hpp"
29#include "../declarations.hpp"
30#include "../layout/policies.hpp"
31#include "../mem/address_space.hpp"
32#include "../mem/policies.hpp"
33#include "../traits.hpp"
34
35#include <type_traits>
36#include <utility>
37
38namespace nda {
39
40 /**
41 * @addtogroup linalg_tools
42 * @{
43 */
44
45 namespace detail {
46
47 // Helper variable template to check if the three matrix types can be passed to gemm.
48 // The following combinations are allowed (gemm can only be called with 'N', 'T' or 'C' op tags):
49 // - C in Fortran layout:
50 // -- A/B is not a conj expression and has Fortran layout
51 // -- A/B is a conj expression and has C layout
52 // - C in C layout:
53 // -- A/B is not a conj expression and has C layout
54 // -- A/B is a conj expression and has Fortran layout
55 template <Matrix A, Matrix B, MemoryMatrix C, bool conj_A = blas::is_conj_array_expr<A>, bool conj_B = blas::is_conj_array_expr<B>>
56 requires((MemoryMatrix<A> or conj_A) and (MemoryMatrix<B> or conj_B))
57 static constexpr bool is_valid_gemm_triple = []() {
58 using blas::has_F_layout;
59 if constexpr (has_F_layout<C>) {
60 return !(conj_A and has_F_layout<A>)and!(conj_B and has_F_layout<B>);
61 } else {
62 return !(conj_B and !has_F_layout<B>)and!(conj_A and !has_F_layout<A>);
63 }
64 }();
65
66 // Get the layout policy for a given array type.
67 template <Array A>
68 using get_layout_policy = typename std::remove_reference_t<decltype(make_regular(std::declval<A>()))>::layout_policy_t;
69
70 } // namespace detail
71
72 /**
73 * @brief Perform a matrix-matrix multiplication.
74 *
75 * @details It is generic in the sense that it allows the input matrices to belong to a different
76 * nda::mem::AddressSpace (as long as they are compatible).
77 *
78 * If possible, it uses nda::blas::gemm, otherwise it calls nda::blas::gemm_generic.
79 *
80 * @tparam A nda::Matrix type of lhs operand.
81 * @tparam B nda::Matrix type of rhs operand.
82 * @param a Left hand side matrix operand.
83 * @param b Right hand side matrix operand.
84 * @return Result of the matrix-matrix multiplication.
85 */
86 template <Matrix A, Matrix B>
87 auto matmul(A &&a, B &&b) { // NOLINT (temporary views are allowed here)
88 // check dimensions
89 EXPECTS_WITH_MESSAGE(a.shape()[1] == b.shape()[0], "Error in nda::matmul: Dimension mismatch in matrix-matrix product");
90
91 // check address space compatibility
92 static constexpr auto L_adr_spc = mem::get_addr_space<A>;
93 static constexpr auto R_adr_spc = mem::get_addr_space<B>;
94 mem::check_adr_sp_valid<L_adr_spc, R_adr_spc>();
95
96 // get resulting value type, layout policy and matrix type
97 using value_t = decltype(get_value_t<A>{} * get_value_t<B>{});
98 using layout_policy =
99 std::conditional_t<get_layout_info<A>.stride_order == get_layout_info<B>.stride_order, detail::get_layout_policy<A>, C_layout>;
100 using matrix_t = basic_array<value_t, 2, layout_policy, 'M', nda::heap<mem::combine<L_adr_spc, R_adr_spc>>>;
101
102 // perform matrix-matrix multiplication
103 auto result = matrix_t(a.shape()[0], b.shape()[1]);
104 if constexpr (is_blas_lapack_v<value_t>) {
105 // for double or complex value types we use blas::gemm
106 // lambda to form a new matrix with the correct value type if necessary
107 auto as_container = []<Matrix M>(M &&m) -> decltype(auto) {
108 if constexpr (std::is_same_v<get_value_t<M>, value_t> and (MemoryMatrix<M> or blas::is_conj_array_expr<M>))
109 return std::forward<M>(m);
110 else
111 return matrix_t{std::forward<M>(m)};
112 };
113
114 // MSAN has no way to know that we are calling with beta = 0, hence this is not necessary.
115 // Of course, in production code, we do NOT waste time to do this.
116#if defined(__has_feature)
117#if __has_feature(memory_sanitizer)
118 result = 0;
119#endif
120#endif
121
122 // check if we can call gemm directly
123 if constexpr (detail::is_valid_gemm_triple<decltype(as_container(a)), decltype(as_container(b)), matrix_t>) {
124 blas::gemm(1, as_container(a), as_container(b), 0, result);
125 } else {
126 // otherwise, turn the lhs and rhs first into regular matrices and then call gemm
127 blas::gemm(1, make_regular(as_container(a)), make_regular(as_container(b)), 0, result);
128 }
129
130 } else {
131 // for other value types we use a generic implementation
132 blas::gemm_generic(1, a, b, 0, result);
133 }
134 return result;
135 }
136
137 /**
138 * @brief Perform a matrix-vector multiplication.
139 *
140 * @details It is generic in the sense that it allows the input matrix and vector to belong to a different
141 * nda::mem::AddressSpace (as long as they are compatible).
142 *
143 * If possible, it uses nda::blas::gemv, otherwise it calls nda::blas::gemv_generic.
144 *
145 * @tparam A nda::Matrix type of lhs operand.
146 * @tparam X nda::Vector type of rhs operand.
147 * @param a Left hand side matrix operand.
148 * @param x Right hand side vector operand.
149 * @return Result of the matrix-vector multiplication.
150 */
151 template <Matrix A, Vector X>
152 auto matvecmul(A &&a, X &&x) { // NOLINT (temporary views are allowed here)
153 // check dimensions
154 EXPECTS_WITH_MESSAGE(a.shape()[1] == x.shape()[0], "Error in nda::matvecmul: Dimension mismatch in matrix-vector product");
155
156 // check address space compatibility
157 static constexpr auto L_adr_spc = mem::get_addr_space<A>;
158 static constexpr auto R_adr_spc = mem::get_addr_space<X>;
159 static_assert(L_adr_spc == R_adr_spc, "Error in nda::matvecmul: Matrix-vector product requires arguments with same address spaces");
160 static_assert(L_adr_spc != mem::None);
161
162 // get resulting value type and vector type
163 using value_t = decltype(get_value_t<A>{} * get_value_t<X>{});
164 using vector_t = vector<value_t, heap<L_adr_spc>>;
165
166 // perform matrix-matrix multiplication
167 auto result = vector_t(a.shape()[0]);
168 if constexpr (is_blas_lapack_v<value_t>) {
169 // for double or complex value types we use blas::gemv
170 // lambda to form a new array with the correct value type if necessary
171 auto as_container = []<Array B>(B &&b) -> decltype(auto) {
172 if constexpr (std::is_same_v<get_value_t<B>, value_t> and (MemoryMatrix<B> or (Matrix<B> and blas::is_conj_array_expr<B>)))
173 return std::forward<B>(b);
174 else
175 return basic_array<value_t, get_rank<B>, C_layout, 'A', heap<L_adr_spc>>{std::forward<B>(b)};
176 };
177
178 // MSAN has no way to know that we are calling with beta = 0, hence this is not necessary.
179 // Of course, in production code, we do NOT waste time to do this.
180#if defined(__has_feature)
181#if __has_feature(memory_sanitizer)
182 result = 0;
183#endif
184#endif
185
186 // for expressions of the kind 'conj(M) * V' with a Matrix in Fortran Layout, we have to explicitly
187 // form the conj operation in memory as gemv only provides op tags 'N', 'T' and 'C' (hermitian conjugate)
188 if constexpr (blas::is_conj_array_expr<decltype(as_container(a))> and blas::has_F_layout<decltype(as_container(a))>) {
189 blas::gemv(1, make_regular(as_container(a)), as_container(x), 0, result);
190 } else {
191 blas::gemv(1, as_container(a), as_container(x), 0, result);
192 }
193 } else {
194 // for other value types we use a generic implementation
195 blas::gemv_generic(1, a, x, 0, result);
196 }
197 return result;
198 }
199
200 /** @} */
201
202} // namespace nda
A generic multi-dimensional array.
auto matvecmul(A &&a, X &&x)
Perform a matrix-vector multiplication.
Definition matmul.hpp:152
auto matmul(A &&a, B &&b)
Perform a matrix-matrix multiplication.
Definition matmul.hpp:87
#define EXPECTS_WITH_MESSAGE(X,...)
Definition macros.hpp:75
Contiguous layout policy with C-order (row-major order).
Definition policies.hpp:47