TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
matmul.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../basic_array.hpp"
15#include "../blas/gemm.hpp"
16#include "../blas/tools.hpp"
17#include "../concepts.hpp"
18#include "../declarations.hpp"
21#include "../mem/policies.hpp"
22#include "../traits.hpp"
23
24#include <type_traits>
25#include <utility>
26
27namespace nda::linalg {
28
33
34 namespace detail {
35
37 template <Matrix A, Matrix B, MemoryMatrix C>
39 void gemm_generic(auto alpha, A const &a, B const &b, auto beta, C &&c) { // NOLINT (temporary views are allowed here)
40 // check the dimensions of the input/output arrays/views
41 auto const [m, k] = a.shape();
42 auto const [l, n] = b.shape();
43 EXPECTS(k == l);
44 EXPECTS(m == c.extent(0));
45 EXPECTS(n == c.extent(1));
46
47 // perform the matrix-matrix multiplication
48 for (int i = 0; i < m; ++i) {
49 for (int j = 0; j < n; ++j) {
50 c(i, j) = beta * c(i, j);
51 for (int r = 0; r < k; ++r) c(i, j) += alpha * a(i, r) * b(r, j);
52 }
53 }
54 }
55
56 // Make compile time checks if blas::gemm can handle the given input matrix. If it can, simply forward the matrix.
57 // Otherwise, return a copy with the given value type T, layout policy LP and container policy CP.
//
// Template parameters:
//   T  - value type required for the matrix that is passed on to blas::gemm
//   LP - layout policy of the copy (only used if a copy has to be made)
//   CP - container policy of the copy (only used if a copy has to be made)
//   C  - type of the output matrix (its use is not visible in this listing; presumably it appears in the
//        truncated part of the inner condition below -- TODO confirm against the original header)
//   A  - type of the input matrix
//
// Returns either the forwarded input `a` or a new matrix<T, LP, CP> constructed from `a`; decltype(auto) keeps
// the value category of the chosen return expression.
58 template <typename T, typename LP, typename CP, MemoryMatrix C, Matrix A>
59 decltype(auto) get_gemm_matrix(A &&a) {
// outer check: blas::get_array(a) has to be well-formed and the value type of A has to be exactly T
60 if constexpr (requires { blas::get_array(a); } and std::is_same_v<get_value_t<A>, T>) {
// inner check: memory matrices are forwarded as they are
// NOTE(review): the condition below is cut off in this listing (the original lines 62-63 are missing), so the
// remaining alternatives of the check cannot be documented from here -- restore them from the original header.
61 if constexpr (MemoryMatrix<A>
64 return std::forward<A>(a);
65 } else {
// the inner check failed: make a copy
66 return matrix<T, LP, CP>{a};
67 }
68 } else {
// no underlying array available or a different value type: make a copy with value type T
69 return matrix<T, LP, CP>{a};
70 }
71 }
72
73 // Make the call to nda::blas::gemm (with copies of the matrices if they are not contiguous).
74 template <Matrix A, Matrix B, MemoryMatrix C>
75 void make_gemm_call(A const &a, B const &b, C &c) {
76 if (blas::get_array(a).is_contiguous()) {
77 if (blas::get_array(b).is_contiguous()) {
78 blas::gemm(1, a, b, 0, c);
79 } else {
80 blas::gemm(1, a, nda::make_regular(b), 0, c);
81 }
82 } else {
83 if (blas::get_array(b).is_contiguous()) {
84 blas::gemm(1, nda::make_regular(a), b, 0, c);
85 } else {
87 }
88 }
89 }
90
91 // Get the layout policy for a given array type.
// The policy is read from the `layout_policy_t` member of the (cvref-stripped) type that nda::make_regular
// returns for an object of type A. std::declval is used so that no object of type A has to be constructed.
92 template <Array A>
93 using get_layout_policy = typename std::remove_cvref_t<decltype(make_regular(std::declval<A>()))>::layout_policy_t;
94
95 } // namespace detail
96
127 template <Matrix A, Matrix B>
128 auto matmul(A &&a, B &&b) { // NOLINT (temporary views are allowed here)
129 // get the return type
130 using value_t = decltype(a(0, 0) * b(0, 0));
131 using layout_pol = std::conditional_t<get_layout_info<A>.stride_order == get_layout_info<B>.stride_order, detail::get_layout_policy<A>, C_layout>;
132 using cont_pol = heap<mem::common_addr_space<A, B>>;
134
135 // result matrix (MSAN complains if it is not initialized)
136 auto res = return_t(a.shape()[0], b.shape()[1]);
137#if defined(__has_feature)
138#if __has_feature(memory_sanitizer)
139 res = 0;
140#endif
141#endif
142
143 // perform matrix-matrix multiplication (if possible we try to call blas::gemm even if this requires making copies)
144 if constexpr (is_blas_lapack_v<value_t>) {
145 // check at compile time if we need to make a copy of the input matrices
146 auto &&a_mat = detail::get_gemm_matrix<value_t, layout_pol, cont_pol, return_t>(a);
147 auto &&b_mat = detail::get_gemm_matrix<value_t, layout_pol, cont_pol, return_t>(b);
148
149 // check at runtime if the input matrices are contiguous, make copies if not and call blas::gemm
150 detail::make_gemm_call(a_mat, b_mat, res);
151 } else {
152 detail::gemm_generic(1, a, b, 0, res);
153 }
154 return res;
155 }
156
158
159} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Check if a given type is a memory matrix, i.e. an nda::MemoryArrayOfRank<2>.
Definition concepts.hpp:281
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a generic interface to the BLAS gemm routine.
decltype(auto) make_regular(A &&a)
Make a given object regular.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
Definition traits.hpp:311
static constexpr bool has_C_layout
Constexpr variable that is true if the given nda::Array type has nda::C_layout.
Definition tools.hpp:83
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
Definition tools.hpp:41
static constexpr bool has_F_layout
Constexpr variable that is true if the given nda::Array type has nda::F_layout.
Definition tools.hpp:73
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an nda::MemoryArray.
Definition tools.hpp:62
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS gemm routine.
Definition gemm.hpp:63
auto matmul(A &&a, B &&b)
Compute the matrix-matrix product of two nda::matrix objects.
Definition matmul.hpp:128
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:92
Provides definitions of various layout policies.
void gemm_generic(auto alpha, A const &a, B const &b, auto beta, C &&c)
Generic matrix-matrix multiplication for types not supported by BLAS.
Definition matmul.hpp:39
Defines various memory handling policies.
Contiguous layout policy with C-order (row-major order).
Definition policies.hpp:36
Provides various traits and utilities for the BLAS interface.
Provides type traits for the nda library.