TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
matmul.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
14#include "../blas/gemm.hpp"
15#include "../blas/gemv.hpp"
16#include "../blas/tools.hpp"
17#include "../concepts.hpp"
18#include "../declarations.hpp"
21#include "../mem/policies.hpp"
22#include "../traits.hpp"
23
24#include <type_traits>
25#include <utility>
26
27namespace nda {
28
33
34 namespace detail {
35
36 // Helper variable template to check if the three matrix types can be passed to gemm.
37 // The following combinations are allowed (gemm can only be called with 'N', 'T' or 'C' op tags):
38 // - C in Fortran layout:
39 // -- A/B is not a conj expression and has Fortran layout
40 // -- A/B is a conj expression and has C layout
41 // - C in C layout:
42 // -- A/B is not a conj expression and has C layout
43 // -- A/B is a conj expression and has Fortran layout
44 template <Matrix A, Matrix B, MemoryMatrix C, bool conj_A = blas::is_conj_array_expr<A>, bool conj_B = blas::is_conj_array_expr<B>>
45 requires((MemoryMatrix<A> or conj_A) and (MemoryMatrix<B> or conj_B))
46 static constexpr bool is_valid_gemm_triple = []() {
48 if constexpr (has_F_layout<C>) {
49 return !(conj_A and has_F_layout<A>)and!(conj_B and has_F_layout<B>);
50 } else {
51 return !(conj_B and !has_F_layout<B>)and!(conj_A and !has_F_layout<A>);
52 }
53 }();
54
55 // Get the layout policy for a given array type.
56 template <Array A>
57 using get_layout_policy = typename std::remove_reference_t<decltype(make_regular(std::declval<A>()))>::layout_policy_t;
58
59 } // namespace detail
60
75 template <Matrix A, Matrix B>
76 auto matmul(A &&a, B &&b) { // NOLINT (temporary views are allowed here)
77 // check dimensions
78 EXPECTS_WITH_MESSAGE(a.shape()[1] == b.shape()[0], "Error in nda::matmul: Dimension mismatch in matrix-matrix product");
79
80 // check address space compatibility
81 static constexpr auto L_adr_spc = mem::get_addr_space<A>;
82 static constexpr auto R_adr_spc = mem::get_addr_space<B>;
84
85 // get resulting value type, layout policy and matrix type
86 using value_t = decltype(get_value_t<A>{} * get_value_t<B>{});
87 using layout_policy =
88 std::conditional_t<get_layout_info<A>.stride_order == get_layout_info<B>.stride_order, detail::get_layout_policy<A>, C_layout>;
89 using matrix_t = basic_array<value_t, 2, layout_policy, 'M', nda::heap<mem::combine<L_adr_spc, R_adr_spc>>>;
90
91 // perform matrix-matrix multiplication
92 auto result = matrix_t(a.shape()[0], b.shape()[1]);
93 if constexpr (is_blas_lapack_v<value_t>) {
94 // for double or complex value types we use blas::gemm
95 // lambda to form a new matrix with the correct value type if necessary
96 auto as_container = []<Matrix M>(M &&m) -> decltype(auto) {
97 if constexpr (std::is_same_v<get_value_t<M>, value_t> and (MemoryMatrix<M> or blas::is_conj_array_expr<M>))
98 return std::forward<M>(m);
99 else
100 return matrix_t{std::forward<M>(m)};
101 };
102
103 // MSAN has no way to know that we are calling with beta = 0, hence this is not necessary.
104 // Of course, in production code, we do NOT waste time to do this.
105#if defined(__has_feature)
106#if __has_feature(memory_sanitizer)
107 result = 0;
108#endif
109#endif
110
111 // check if we can call gemm directly
112 if constexpr (detail::is_valid_gemm_triple<decltype(as_container(a)), decltype(as_container(b)), matrix_t>) {
113 blas::gemm(1, as_container(a), as_container(b), 0, result);
114 } else {
115 // otherwise, turn the lhs and rhs first into regular matrices and then call gemm
116 blas::gemm(1, make_regular(as_container(a)), make_regular(as_container(b)), 0, result);
117 }
118
119 } else {
120 // for other value types we use a generic implementation
121 blas::gemm_generic(1, a, b, 0, result);
122 }
123 return result;
124 }
125
140 template <Matrix A, Vector X>
141 auto matvecmul(A &&a, X &&x) { // NOLINT (temporary views are allowed here)
142 // check dimensions
143 EXPECTS_WITH_MESSAGE(a.shape()[1] == x.shape()[0], "Error in nda::matvecmul: Dimension mismatch in matrix-vector product");
144
145 // check address space compatibility
146 static constexpr auto L_adr_spc = mem::get_addr_space<A>;
147 static constexpr auto R_adr_spc = mem::get_addr_space<X>;
148 static_assert(L_adr_spc == R_adr_spc, "Error in nda::matvecmul: Matrix-vector product requires arguments with same address spaces");
149 static_assert(L_adr_spc != mem::None);
150
151 // get resulting value type and vector type
152 using value_t = decltype(get_value_t<A>{} * get_value_t<X>{});
153 using vector_t = vector<value_t, heap<L_adr_spc>>;
154
155 // perform matrix-matrix multiplication
156 auto result = vector_t(a.shape()[0]);
157 if constexpr (is_blas_lapack_v<value_t>) {
158 // for double or complex value types we use blas::gemv
159 // lambda to form a new array with the correct value type if necessary
160 auto as_container = []<Array B>(B &&b) -> decltype(auto) {
161 if constexpr (std::is_same_v<get_value_t<B>, value_t> and (MemoryMatrix<B> or (Matrix<B> and blas::is_conj_array_expr<B>)))
162 return std::forward<B>(b);
163 else
164 return basic_array<value_t, get_rank<B>, C_layout, 'A', heap<L_adr_spc>>{std::forward<B>(b)};
165 };
166
167 // MSAN has no way to know that we are calling with beta = 0, hence this is not necessary.
168 // Of course, in production code, we do NOT waste time to do this.
169#if defined(__has_feature)
170#if __has_feature(memory_sanitizer)
171 result = 0;
172#endif
173#endif
174
175 // for expressions of the kind 'conj(M) * V' with a Matrix in Fortran Layout, we have to explicitly
176 // form the conj operation in memory as gemv only provides op tags 'N', 'T' and 'C' (hermitian conjugate)
177 if constexpr (blas::is_conj_array_expr<decltype(as_container(a))> and blas::has_F_layout<decltype(as_container(a))>) {
178 blas::gemv(1, make_regular(as_container(a)), as_container(x), 0, result);
179 } else {
180 blas::gemv(1, as_container(a), as_container(x), 0, result);
181 }
182 } else {
183 // for other value types we use a generic implementation
184 blas::gemv_generic(1, a, x, 0, result);
185 }
186 return result;
187 }
188
190
191} // namespace nda
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides basic functions to create and manipulate arrays and views.
A generic multi-dimensional array.
Check if a given type satisfies the array concept.
Definition concepts.hpp:230
Check if a given type is a matrix, i.e. an nda::ArrayOfRank<2>.
Definition concepts.hpp:290
Check if a given type is a memory matrix, i.e. an nda::MemoryArrayOfRank<2>.
Definition concepts.hpp:306
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a generic interface to the BLAS gemm routine.
Provides a generic interface to the BLAS gemv routine.
decltype(auto) make_regular(A &&a)
Make a given object regular.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:181
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
Definition traits.hpp:310
void gemm_generic(typename A::value_type alpha, A const &a, B const &b, typename A::value_type beta, C &&c)
Generic nda::blas::gemm implementation for types not supported by BLAS/LAPACK.
Definition gemm.hpp:48
void gemv_generic(get_value_t< A > alpha, A const &a, X const &x, get_value_t< A > beta, Y &&y)
Generic nda::blas::gemv implementation for types not supported by BLAS/LAPACK.
Definition gemv.hpp:46
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
Definition tools.hpp:41
void gemv(get_value_t< A > alpha, A const &a, X const &x, get_value_t< A > beta, Y &&y)
Interface to the BLAS gemv routine.
Definition gemv.hpp:78
static constexpr bool has_F_layout
Constexpr variable that is true if the given nda::Array type has a Fortran memory layout.
Definition tools.hpp:55
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS gemm routine.
Definition gemm.hpp:90
auto matvecmul(A &&a, X &&x)
Perform a matrix-vector multiplication.
Definition matmul.hpp:141
auto matmul(A &&a, B &&b)
Perform a matrix-matrix multiplication.
Definition matmul.hpp:76
static constexpr AddressSpace get_addr_space
Variable template providing the address space for different types.
static const auto check_adr_sp_valid
Check validity of a set of nda::mem::AddressSpace values.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:91
Provides definitions of various layout policies.
Defines various memory handling policies.
Contiguous layout policy with C-order (row-major order).
Definition policies.hpp:36
Provides various traits and utilities for the BLAS interface.
Provides type traits for the nda library.