27namespace nda::linalg {
37 template <Matrix A, Matrix B, MemoryMatrix C>
39 void gemm_generic(
auto alpha, A
const &a, B
const &b,
auto beta, C &&c) {
41 auto const [m, k] = a.shape();
42 auto const [l, n] = b.shape();
44 EXPECTS(m == c.extent(0));
45 EXPECTS(n == c.extent(1));
48 for (
int i = 0; i < m; ++i) {
49 for (
int j = 0; j < n; ++j) {
50 c(i, j) = beta * c(i, j);
51 for (
int r = 0; r < k; ++r) c(i, j) += alpha * a(i, r) * b(r, j);
58 template <
typename T,
typename LP,
typename CP, MemoryMatrix C, Matrix A>
59 decltype(
auto) get_gemm_matrix(A &&a) {
60 if constexpr (
requires {
blas::get_array(a); } and std::is_same_v<get_value_t<A>, T>) {
64 return std::forward<A>(a);
74 template <Matrix A, Matrix B, MemoryMatrix C>
75 void make_gemm_call(A
const &a, B
const &b, C &c) {
93 using get_layout_policy =
typename std::remove_cvref_t<decltype(make_regular(std::declval<A>()))>::layout_policy_t;
127 template <Matrix A, Matrix B>
130 using value_t =
decltype(a(0, 0) * b(0, 0));
131 using layout_pol = std::conditional_t<get_layout_info<A>.stride_order ==
get_layout_info<B>.stride_order, detail::get_layout_policy<A>,
C_layout>;
136 auto res = return_t(a.shape()[0], b.shape()[1]);
137#if defined(__has_feature)
138#if __has_feature(memory_sanitizer)
146 auto &&a_mat = detail::get_gemm_matrix<value_t, layout_pol, cont_pol, return_t>(a);
147 auto &&b_mat = detail::get_gemm_matrix<value_t, layout_pol, cont_pol, return_t>(b);
150 detail::make_gemm_call(a_mat, b_mat, res);
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Check if a given type is a memory matrix, i.e. an nda::MemoryArrayOfRank<2>.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a generic interface to the BLAS gemm routine.
decltype(auto) make_regular(A &&a)
Make a given object regular.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
static constexpr bool has_C_layout
Constexpr variable that is true if the given nda::Array type has nda::C_layout.
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
static constexpr bool has_F_layout
Constexpr variable that is true if the given nda::Array type has nda::F_layout.
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an nda::MemoryArray.
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS gemm routine.
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Provides definitions of various layout policies.
void gemm_generic(auto alpha, A const &a, B const &b, auto beta, C &&c)
Generic matrix-matrix multiplication for types not supported by BLAS.
Defines various memory handling policies.
Contiguous layout policy with C-order (row-major order).
Provides type traits for the nda library.