TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
matmul.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../basic_array.hpp"
15#include "../blas/gemm.hpp"
16#include "../blas/tools.hpp"
17#include "../concepts.hpp"
18#include "../declarations.hpp"
21#include "../mem/policies.hpp"
22#include "../traits.hpp"
23
24#include <type_traits>
25#include <utility>
26
27namespace nda::linalg {
28
33
34 namespace detail {
35
36 // Generic matrix-matrix multiplication for types not supported by BLAS.
37 template <Matrix A, Matrix B, MemoryMatrix C>
39 void gemm_generic(auto alpha, A const &a, B const &b, auto beta, C &&c) { // NOLINT (temporary views are allowed here)
40 // check the dimensions of the input/output arrays/views
41 auto const [m, k] = a.shape();
42 auto const [l, n] = b.shape();
43 EXPECTS(k == l);
44 EXPECTS(m == c.extent(0));
45 EXPECTS(n == c.extent(1));
46
47 // perform the matrix-matrix multiplication
48 for (int i = 0; i < m; ++i) {
49 for (int j = 0; j < n; ++j) {
50 c(i, j) = beta * c(i, j);
51 for (int r = 0; r < k; ++r) c(i, j) += alpha * a(i, r) * b(r, j);
52 }
53 }
54 }
55
56 // Make compile time checks if blas::gemm can handle the given input matrix. If it can, simply forward the matrix.
57 // Otherwise, return a copy with the given value type T, layout policy LP and container policy CP.
58 template <typename T, typename LP, typename CP, MemoryMatrix C, Matrix A>
59 decltype(auto) get_gemm_matrix(A &&a) {
60 using namespace blas_lapack;
61 if constexpr (requires { get_array(a); } and std::is_same_v<get_value_t<A>, T>) {
62 if constexpr (MemoryMatrix<A>
63 or (is_conj_array_expr<A> and ((has_F_layout<C> and has_C_layout<A>) or (has_C_layout<C> and has_F_layout<A>)))) {
64 return std::forward<A>(a);
65 } else {
66 return matrix<T, LP, CP>{a};
67 }
68 } else {
69 return matrix<T, LP, CP>{a};
70 }
71 }
72
73 // Make the call to nda::blas::gemm (with copies of the matrices if they are not contiguous).
74 template <Matrix A, Matrix B, MemoryMatrix C>
75 void make_gemm_call(A const &a, B const &b, C &c) {
76 if (blas_lapack::get_array(a).is_contiguous()) {
77 if (blas_lapack::get_array(b).is_contiguous()) {
78 blas::gemm(1, a, b, 0, c);
79 } else {
80 blas::gemm(1, a, make_regular(b), 0, c);
81 }
82 } else {
83 if (blas_lapack::get_array(b).is_contiguous()) {
84 blas::gemm(1, make_regular(a), b, 0, c);
85 } else {
86 blas::gemm(1, make_regular(a), make_regular(b), 0, c);
87 }
88 }
89 }
90
91 // Get the layout policy for a given array type.
92 template <Array A>
93 using get_layout_policy = typename std::remove_cvref_t<decltype(make_regular(std::declval<A>()))>::layout_policy_t;
94
95 } // namespace detail
96
135 template <Matrix A, Matrix B>
137 auto matmul(A &&a, B &&b) { // NOLINT (temporary views are allowed here)
138 // get the return type
139 using value_t = decltype(a(0, 0) * b(0, 0));
140 using layout_pol = std::conditional_t<get_layout_info<A>.stride_order == get_layout_info<B>.stride_order, detail::get_layout_policy<A>, C_layout>;
141 using cont_pol = heap<mem::common_addr_space<A, B>>;
143
144 // result matrix (MSAN complains if it is not initialized)
145 auto res = return_t(a.shape()[0], b.shape()[1]);
146#if defined(__has_feature)
147#if __has_feature(memory_sanitizer)
148 res = 0;
149#endif
150#endif
151
152 // perform matrix-matrix multiplication (if possible we try to call blas::gemm even if this requires making copies)
153 if constexpr (is_blas_lapack_v<value_t>) {
154 // check at compile time if we need to make a copy of the input matrices
155 auto &&a_mat = detail::get_gemm_matrix<value_t, layout_pol, cont_pol, return_t>(a);
156 auto &&b_mat = detail::get_gemm_matrix<value_t, layout_pol, cont_pol, return_t>(b);
157
158 // check at runtime if the input matrices are contiguous, make copies if not and call blas::gemm
159 detail::make_gemm_call(a_mat, b_mat, res);
160 } else {
161 detail::gemm_generic(1, a, b, 0, res);
162 }
163 return res;
164 }
165
167
168} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_vi...
Provides a generic interface to the BLAS/cuBLAS gemm routine.
decltype(auto) make_regular(A &&a)
Make a given object regular.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
Definition traits.hpp:350
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an n...
Definition tools.hpp:68
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS/cuBLAS gemm routine.
Definition gemm.hpp:57
auto matmul(A &&a, B &&b)
Compute the matrix-matrix product of two nda::Matrix objects.
Definition matmul.hpp:137
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr bool is_blas_lapack_v
Constexpr variable that is true if type T is either of type 'float', 'double', 'std::complex<float>' or ...
Definition traits.hpp:95
Provides definitions of various layout policies.
Defines various memory handling policies.
Contiguous layout policy with C-order (row-major order).
Definition policies.hpp:36
Provides type traits for the nda library.