27namespace nda::linalg {
37 template <Matrix A, Matrix B, MemoryMatrix C>
39 void gemm_generic(
auto alpha, A
const &a, B
const &b,
auto beta, C &&c) {
41 auto const [m, k] = a.shape();
42 auto const [l, n] = b.shape();
44 EXPECTS(m == c.extent(0));
45 EXPECTS(n == c.extent(1));
48 for (
int i = 0; i < m; ++i) {
49 for (
int j = 0; j < n; ++j) {
50 c(i, j) = beta * c(i, j);
51 for (
int r = 0; r < k; ++r) c(i, j) += alpha * a(i, r) * b(r, j);
58 template <
typename T,
typename LP,
typename CP, MemoryMatrix C, Matrix A>
59 decltype(
auto) get_gemm_matrix(A &&a) {
60 using namespace blas_lapack;
61 if constexpr (
requires {
get_array(a); } and std::is_same_v<get_value_t<A>, T>) {
62 if constexpr (MemoryMatrix<A>
63 or (is_conj_array_expr<A> and ((has_F_layout<C> and has_C_layout<A>) or (has_C_layout<C> and has_F_layout<A>)))) {
64 return std::forward<A>(a);
74 template <Matrix A, Matrix B, MemoryMatrix C>
75 void make_gemm_call(A
const &a, B
const &b, C &c) {
93 using get_layout_policy =
typename std::remove_cvref_t<decltype(make_regular(std::declval<A>()))>::layout_policy_t;
135 template <Matrix A, Matrix B>
139 using value_t =
decltype(a(0, 0) * b(0, 0));
140 using layout_pol = std::conditional_t<get_layout_info<A>.stride_order ==
get_layout_info<B>.stride_order, detail::get_layout_policy<A>,
C_layout>;
145 auto res = return_t(a.shape()[0], b.shape()[1]);
146#if defined(__has_feature)
147#if __has_feature(memory_sanitizer)
155 auto &&a_mat = detail::get_gemm_matrix<value_t, layout_pol, cont_pol, return_t>(a);
156 auto &&b_mat = detail::get_gemm_matrix<value_t, layout_pol, cont_pol, return_t>(b);
159 detail::make_gemm_call(a_mat, b_mat, res);
161 detail::gemm_generic(1, a, b, 0, res);
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a generic interface to the BLAS/cuBLAS gemm routine.
decltype(auto) make_regular(A &&a)
Make a given object regular.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression, or return the array itself in case it is an nda::MemoryArray.
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS/cuBLAS gemm routine.
auto matmul(A &&a, B &&b)
Compute the matrix-matrix product of two nda::Matrix objects.
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
constexpr bool is_blas_lapack_v
Constexpr variable that is true if type T is one of `float`, `double`, `std::complex<float>` or `std::complex<double>`.
Provides definitions of various layout policies.
Defines various memory handling policies.
Contiguous layout policy with C-order (row-major order).
Provides type traits for the nda library.