TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
matmul.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell
16
22#pragma once
23
25#include "../blas/gemm.hpp"
26#include "../blas/gemv.hpp"
27#include "../blas/tools.hpp"
28#include "../concepts.hpp"
29#include "../declarations.hpp"
32#include "../mem/policies.hpp"
33#include "../traits.hpp"
34
35#include <type_traits>
36#include <utility>
37
38namespace nda {
39
45 namespace detail {
46
47 // Helper variable template to check if the three matrix types can be passed to gemm.
48 // The following combinations are allowed (gemm can only be called with 'N', 'T' or 'C' op tags):
49 // - C in Fortran layout:
50 // -- A/B is not a conj expression and has Fortran layout
51 // -- A/B is a conj expression and has C layout
52 // - C in C layout:
53 // -- A/B is not a conj expression and has C layout
54 // -- A/B is a conj expression and has Fortran layout
55 template <Matrix A, Matrix B, MemoryMatrix C, bool conj_A = blas::is_conj_array_expr<A>, bool conj_B = blas::is_conj_array_expr<B>>
56 requires((MemoryMatrix<A> or conj_A) and (MemoryMatrix<B> or conj_B))
57 static constexpr bool is_valid_gemm_triple = []() {
59 if constexpr (has_F_layout<C>) {
60 return !(conj_A and has_F_layout<A>)and!(conj_B and has_F_layout<B>);
61 } else {
62 return !(conj_B and !has_F_layout<B>)and!(conj_A and !has_F_layout<A>);
63 }
64 }();
65
66 // Get the layout policy for a given array type.
67 template <Array A>
68 using get_layout_policy = typename std::remove_reference_t<decltype(make_regular(std::declval<A>()))>::layout_policy_t;
69
70 } // namespace detail
71
86 template <Matrix A, Matrix B>
87 auto matmul(A &&a, B &&b) { // NOLINT (temporary views are allowed here)
88 // check dimensions
89 EXPECTS_WITH_MESSAGE(a.shape()[1] == b.shape()[0], "Error in nda::matmul: Dimension mismatch in matrix-matrix product");
90
91 // check address space compatibility
92 static constexpr auto L_adr_spc = mem::get_addr_space<A>;
93 static constexpr auto R_adr_spc = mem::get_addr_space<B>;
95
96 // get resulting value type, layout policy and matrix type
97 using value_t = decltype(get_value_t<A>{} * get_value_t<B>{});
98 using layout_policy =
99 std::conditional_t<get_layout_info<A>.stride_order == get_layout_info<B>.stride_order, detail::get_layout_policy<A>, C_layout>;
100 using matrix_t = basic_array<value_t, 2, layout_policy, 'M', nda::heap<mem::combine<L_adr_spc, R_adr_spc>>>;
101
102 // perform matrix-matrix multiplication
103 auto result = matrix_t(a.shape()[0], b.shape()[1]);
104 if constexpr (is_blas_lapack_v<value_t>) {
105 // for double or complex value types we use blas::gemm
106 // lambda to form a new matrix with the correct value type if necessary
107 auto as_container = []<Matrix M>(M &&m) -> decltype(auto) {
108 if constexpr (std::is_same_v<get_value_t<M>, value_t> and (MemoryMatrix<M> or blas::is_conj_array_expr<M>))
109 return std::forward<M>(m);
110 else
111 return matrix_t{std::forward<M>(m)};
112 };
113
114 // MSAN has no way to know that we are calling with beta = 0, hence this is not necessary.
115 // Of course, in production code, we do NOT waste time to do this.
116#if defined(__has_feature)
117#if __has_feature(memory_sanitizer)
118 result = 0;
119#endif
120#endif
121
122 // check if we can call gemm directly
123 if constexpr (detail::is_valid_gemm_triple<decltype(as_container(a)), decltype(as_container(b)), matrix_t>) {
124 blas::gemm(1, as_container(a), as_container(b), 0, result);
125 } else {
126 // otherwise, turn the lhs and rhs first into regular matrices and then call gemm
127 blas::gemm(1, make_regular(as_container(a)), make_regular(as_container(b)), 0, result);
128 }
129
130 } else {
131 // for other value types we use a generic implementation
132 blas::gemm_generic(1, a, b, 0, result);
133 }
134 return result;
135 }
136
151 template <Matrix A, Vector X>
152 auto matvecmul(A &&a, X &&x) { // NOLINT (temporary views are allowed here)
153 // check dimensions
154 EXPECTS_WITH_MESSAGE(a.shape()[1] == x.shape()[0], "Error in nda::matvecmul: Dimension mismatch in matrix-vector product");
155
156 // check address space compatibility
157 static constexpr auto L_adr_spc = mem::get_addr_space<A>;
158 static constexpr auto R_adr_spc = mem::get_addr_space<X>;
159 static_assert(L_adr_spc == R_adr_spc, "Error in nda::matvecmul: Matrix-vector product requires arguments with same address spaces");
160 static_assert(L_adr_spc != mem::None);
161
162 // get resulting value type and vector type
163 using value_t = decltype(get_value_t<A>{} * get_value_t<X>{});
164 using vector_t = vector<value_t, heap<L_adr_spc>>;
165
166 // perform matrix-matrix multiplication
167 auto result = vector_t(a.shape()[0]);
168 if constexpr (is_blas_lapack_v<value_t>) {
169 // for double or complex value types we use blas::gemv
170 // lambda to form a new array with the correct value type if necessary
171 auto as_container = []<Array B>(B &&b) -> decltype(auto) {
172 if constexpr (std::is_same_v<get_value_t<B>, value_t> and (MemoryMatrix<B> or (Matrix<B> and blas::is_conj_array_expr<B>)))
173 return std::forward<B>(b);
174 else
175 return basic_array<value_t, get_rank<B>, C_layout, 'A', heap<L_adr_spc>>{std::forward<B>(b)};
176 };
177
178 // MSAN has no way to know that we are calling with beta = 0, hence this is not necessary.
179 // Of course, in production code, we do NOT waste time to do this.
180#if defined(__has_feature)
181#if __has_feature(memory_sanitizer)
182 result = 0;
183#endif
184#endif
185
186 // for expressions of the kind 'conj(M) * V' with a Matrix in Fortran Layout, we have to explicitly
187 // form the conj operation in memory as gemv only provides op tags 'N', 'T' and 'C' (hermitian conjugate)
188 if constexpr (blas::is_conj_array_expr<decltype(as_container(a))> and blas::has_F_layout<decltype(as_container(a))>) {
189 blas::gemv(1, make_regular(as_container(a)), as_container(x), 0, result);
190 } else {
191 blas::gemv(1, as_container(a), as_container(x), 0, result);
192 }
193 } else {
194 // for other value types we use a generic implementation
195 blas::gemv_generic(1, a, x, 0, result);
196 }
197 return result;
198 }
199
202} // namespace nda
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides basic functions to create and manipulate arrays and views.
A generic multi-dimensional array.
Check if a given type satisfies the array concept.
Definition concepts.hpp:230
Check if a given type is a matrix, i.e. an nda::ArrayOfRank<2>.
Definition concepts.hpp:290
Check if a given type is a memory matrix, i.e. an nda::MemoryArrayOfRank<2>.
Definition concepts.hpp:306
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a generic interface to the BLAS gemm routine.
Provides a generic interface to the BLAS gemv routine.
decltype(auto) make_regular(A &&a)
Make a given object regular.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:192
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
Definition traits.hpp:321
void gemm_generic(typename A::value_type alpha, A const &a, B const &b, typename A::value_type beta, C &&c)
Generic nda::blas::gemm implementation for types not supported by BLAS/LAPACK.
Definition gemm.hpp:59
void gemv_generic(get_value_t< A > alpha, A const &a, X const &x, get_value_t< A > beta, Y &&y)
Generic nda::blas::gemv implementation for types not supported by BLAS/LAPACK.
Definition gemv.hpp:57
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
Definition tools.hpp:52
void gemv(get_value_t< A > alpha, A const &a, X const &x, get_value_t< A > beta, Y &&y)
Interface to the BLAS gemv routine.
Definition gemv.hpp:89
static constexpr bool has_F_layout
Constexpr variable that is true if the given nda::Array type has a Fortran memory layout.
Definition tools.hpp:66
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS gemm routine.
Definition gemm.hpp:101
auto matvecmul(A &&a, X &&x)
Perform a matrix-vector multiplication.
Definition matmul.hpp:152
auto matmul(A &&a, B &&b)
Perform a matrix-matrix multiplication.
Definition matmul.hpp:87
static constexpr AddressSpace get_addr_space
Variable template providing the address space for different types.
static constexpr AddressSpace get_addr_space< A >
Specialization of nda::mem::get_addr_space for nda::MemoryArray types.
static const auto check_adr_sp_valid
Check validity of a set of nda::mem::AddressSpace values.
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:102
Provides definitions of various layout policies.
Defines various memory handling policies.
Contiguous layout policy with C-order (row-major order).
Definition policies.hpp:47
Memory policy using an nda::mem::handle_heap.
Definition policies.hpp:44
uint64_t stride_order
Stride order of the array/view.
Definition traits.hpp:297
Provides various traits and utilities for the BLAS interface.
Provides type traits for the nda library.