55 template <Matrix A, Matrix B, MemoryMatrix C,
bool conj_A = blas::is_conj_array_expr<A>,
bool conj_B = blas::is_conj_array_expr<B>>
89 EXPECTS_WITH_MESSAGE(a.shape()[1] == b.shape()[0],
"Error in nda::matmul: Dimension mismatch in matrix-matrix product");
92 static constexpr auto L_adr_spc = mem::get_addr_space<A>;
93 static constexpr auto R_adr_spc = mem::get_addr_space<B>;
94 mem::check_adr_sp_valid<L_adr_spc, R_adr_spc>();
97 using value_t =
decltype(get_value_t<A>{} * get_value_t<B>{});
99 std::conditional_t<get_layout_info<A>.stride_order == get_layout_info<B>.stride_order, detail::get_layout_policy<A>,
C_layout>;
100 using matrix_t =
basic_array<value_t, 2, layout_policy,
'M', nda::heap<mem::combine<L_adr_spc, R_adr_spc>>>;
103 auto result = matrix_t(a.shape()[0], b.shape()[1]);
104 if constexpr (is_blas_lapack_v<value_t>) {
107 auto as_container = []<Matrix M>(M &&m) ->
decltype(
auto) {
108 if constexpr (std::is_same_v<get_value_t<M>, value_t>
and (MemoryMatrix<M>
or blas::is_conj_array_expr<M>))
109 return std::forward<M>(m);
111 return matrix_t{std::forward<M>(m)};
116#if defined(__has_feature
)
117#if __has_feature
(memory_sanitizer)
123 if constexpr (detail::is_valid_gemm_triple<
decltype(as_container(a)),
decltype(as_container(b)), matrix_t>) {
124 blas::gemm(1, as_container(a), as_container(b), 0, result);
127 blas::gemm(1, make_regular(as_container(a)), make_regular(as_container(b)), 0, result);
132 blas::gemm_generic(1, a, b, 0, result);
154 EXPECTS_WITH_MESSAGE(a.shape()[1] == x.shape()[0],
"Error in nda::matvecmul: Dimension mismatch in matrix-vector product");
157 static constexpr auto L_adr_spc = mem::get_addr_space<A>;
158 static constexpr auto R_adr_spc = mem::get_addr_space<X>;
159 static_assert(L_adr_spc == R_adr_spc,
"Error in nda::matvecmul: Matrix-vector product requires arguments with same address spaces");
160 static_assert(L_adr_spc != mem::None);
163 using value_t =
decltype(get_value_t<A>{} * get_value_t<X>{});
164 using vector_t = vector<value_t, heap<L_adr_spc>>;
167 auto result = vector_t(a.shape()[0]);
168 if constexpr (is_blas_lapack_v<value_t>) {
171 auto as_container = []<Array B>(B &&b) ->
decltype(
auto) {
172 if constexpr (std::is_same_v<get_value_t<B>, value_t>
and (MemoryMatrix<B>
or (Matrix<B>
and blas::is_conj_array_expr<B>)))
173 return std::forward<B>(b);
175 return basic_array<value_t, get_rank<B>,
C_layout,
'A', heap<L_adr_spc>>{std::forward<B>(b)};
180#if defined(__has_feature
)
181#if __has_feature
(memory_sanitizer)
188 if constexpr (blas::is_conj_array_expr<
decltype(as_container(a))>
and blas::has_F_layout<
decltype(as_container(a))>) {
189 blas::gemv(1, make_regular(as_container(a)), as_container(x), 0, result);
191 blas::gemv(1, as_container(a), as_container(x), 0, result);
195 blas::gemv_generic(1, a, x, 0, result);