40 auto get_transpose_vector(
auto &&v) {
41 auto v_t = std::vector<std::decay_t<
decltype(
transpose(v[0]))>>{};
42 v_t.reserve(v.size());
43 std::transform(v.begin(), v.end(), std::back_inserter(v_t), [](
auto &x) { return transpose(x); });
48 template <
bool is_vbatch, nda::mem::AddressSpace vec_addr_spc>
49 auto get_ptr_vector(
auto &&v) {
50 EXPECTS(std::ranges::all_of(v, [&v](
auto &A) {
return is_vbatch or A.shape() == v[0].shape(); }));
51 EXPECTS(std::ranges::all_of(v, [](
auto &A) {
return get_array(A).indexmap().min_stride() == 1; }));
54 std::transform(v.begin(), v.end(), v_ptrs.begin(), [](
auto &z) { return get_array(z).data(); });
100 template <
bool is_vbatch = false, BlasArrayOrConj<2> A, BlasArrayOrConjFor<A, 2> B, BlasArrayFor<A, 2> C>
102 auto const n_b = va.size();
105 EXPECTS(n_b == vb.size() and n_b == vc.size());
106 if (va.empty())
return;
110 auto vcT = detail::get_transpose_vector(vc);
111 return gemm_batch<is_vbatch>(alpha, detail::get_transpose_vector(vb), detail::get_transpose_vector(va), beta, vcT);
114 auto constexpr vec_addr_spc = []() {
return mem::on_host<C> ? mem::Host : mem::Unified; }();
117 auto a_ptrs = detail::get_ptr_vector<is_vbatch, vec_addr_spc>(va);
118 auto b_ptrs = detail::get_ptr_vector<is_vbatch, vec_addr_spc>(vb);
119 auto c_ptrs = detail::get_ptr_vector<is_vbatch, vec_addr_spc>(vc);
122 if constexpr (is_vbatch) {
126 for (
auto i : range(n_b)) {
132 auto const [m, k] = mat_a.shape();
133 auto const [l, n] = mat_b.shape();
135 EXPECTS(m == mat_c.extent(0));
136 EXPECTS(n == mat_c.extent(1));
149 device::gemm_vbatch(
get_op<A>,
get_op<B>, vm.data(), vn.data(), vk.data(), alpha, a_ptrs.data(), vlda.data(), b_ptrs.data(), vldb.data(),
150 beta, c_ptrs.data(), vldc.
data(), n_b);
152 f77::gemm_vbatch(
get_op<A>,
get_op<B>, vm.data(), vn.data(), vk.data(), alpha, a_ptrs.data(), vlda.data(), b_ptrs.data(), vldb.data(), beta,
153 c_ptrs.data(), vldc.
data(), n_b);
161 auto const [m, k] = mat_a.shape();
162 auto const [l, n] = mat_b.shape();
164 EXPECTS(m == mat_c.extent(0));
165 EXPECTS(n == mat_c.extent(1));
169 device::gemm_batch(
get_op<A>,
get_op<B>, m, n, k, alpha, a_ptrs.data(),
get_ld(mat_a), b_ptrs.data(),
get_ld(mat_b), beta, c_ptrs.data(),
172 f77::gemm_batch(
get_op<A>,
get_op<B>, m, n, k, alpha, a_ptrs.data(),
get_ld(mat_a), b_ptrs.data(),
get_ld(mat_b), beta, c_ptrs.data(),
193 template <BlasArrayOrConj<2> A, BlasArrayOrConjFor<A, 2> B, BlasArrayFor<A, 2> C>
232 template <BlasArrayOrConj<3> A, BlasArrayOrConjFor<A, 3> B, BlasArrayFor<A, 3> C>
240 auto array_info = [](
auto &arr) {
243 return std::array<long, 5>{arr.extent(0), mat.extent(0), mat.extent(1),
get_ld(mat), arr.strides()[0]};
246 return std::array<long, 5>{arr.extent(2), mat.extent(0), mat.extent(1),
get_ld(mat), arr.strides()[2]};
255 auto const [nb_a, m_a, k_a, ld_a, s_a] = array_info(arr_a);
256 auto const [nb_b, k_b, n_b, ld_b, s_b] = array_info(arr_b);
257 auto const [nb_c, m_c, n_c, ld_c, s_c] = array_info(c);
261 EXPECTS(nb_a == nb_b and nb_a == nb_c);
264 EXPECTS(arr_a.indexmap().min_stride() == 1);
265 EXPECTS(arr_b.indexmap().min_stride() == 1);
266 EXPECTS(c.indexmap().min_stride() == 1);
270 device::gemm_batch_strided(
get_op<A>,
get_op<B>, m_c, n_c, k_a, alpha, arr_a.data(), ld_a, s_a, arr_b.data(), ld_b, s_b, beta, c.data(), ld_c,
273 f77::gemm_batch_strided(
get_op<A>,
get_op<B>, m_c, n_c, k_a, alpha, arr_a.data(), ld_a, s_a, arr_b.data(), ld_b, s_b, beta, c.data(), ld_c,
285 template <BlasArrayOrConj<3> A, BlasArrayOrConjFor<A, 3> B, BlasArrayFor<A, 3> C>
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various BLAS routines.
ValueType const * data() const noexcept
Get a pointer to the actual data (in general, this is not the beginning of the memory block for a view).
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
decltype(auto) get_first_element(A &&a)
Get the first element of an array/view or simply return the scalar if a scalar is given.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an nda::MemoryArray.
static constexpr char get_op
Variable template that determines the BLAS matrix operation tag ('N', 'T', 'C') based on the given boolean flags.
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
void gemm_vbatch(get_value_t< A > alpha, std::vector< A > const &va, std::vector< B > const &vb, get_value_t< A > beta, std::vector< C > &vc)
Interface to batched versions of the BLAS/cuBLAS gemm routine for variable sized matrices.
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS/cuBLAS gemm routine.
void gemm_batch(get_value_t< A > alpha, std::vector< A > const &va, std::vector< B > const &vb, get_value_t< A > beta, std::vector< C > &vc)
Interface to batched versions of the BLAS/cuBLAS gemm routine.
void gemm_batch_strided(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to batched-strided versions of the BLAS/cuBLAS gemm routine.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
static constexpr bool on_host
Constexpr variable that is true if all given types have a Host address space.
Macros used in the nda library.
Mimics Python's `...` (ellipsis) syntax.
Provides type traits for the nda library.