TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
gemm_batch.hpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
14#include "./tools.hpp"
15#include "../concepts.hpp"
16#include "../declarations.hpp"
17#include "../device.hpp"
19#include "../macros.hpp"
21#include "../traits.hpp"
22
23#include <algorithm>
24#include <iterator>
25#include <tuple>
26#include <type_traits>
27#include <utility>
28#include <vector>
29
30namespace nda::blas {
31
36
37 namespace detail {
38
// Build a std::vector that contains transpose(x) for every element x of the given container, in the same order.
auto get_transpose_vector(auto &&v) {
  // element type of the result: whatever transpose returns for an element of v (decayed)
  using transposed_t = std::decay_t<decltype(transpose(v[0]))>;

  std::vector<transposed_t> result;
  result.reserve(v.size());
  for (auto &mat : v) result.push_back(transpose(mat));
  return result;
}
46
47 // Get a vector of pointers to the memory of matrices from a given vector of matrices.
48 template <bool is_vbatch, nda::mem::AddressSpace vec_addr_spc>
49 auto get_ptr_vector(auto &&v) {
50 EXPECTS(std::ranges::all_of(v, [&v](auto &A) { return is_vbatch or A.shape() == v[0].shape(); }));
51 EXPECTS(std::ranges::all_of(v, [](auto &A) { return get_array(A).indexmap().min_stride() == 1; }));
52 using ptr_t = std::remove_reference_t<decltype(get_first_element(v[0]))> *;
53 auto v_ptrs = nda::vector<ptr_t, heap<vec_addr_spc>>(v.size());
54 std::transform(v.begin(), v.end(), v_ptrs.begin(), [](auto &z) { return get_array(z).data(); });
55 return v_ptrs;
56 }
57
58 } // namespace detail
59
100 template <bool is_vbatch = false, BlasArrayOrConj<2> A, BlasArrayOrConjFor<A, 2> B, BlasArrayFor<A, 2> C>
101 void gemm_batch(get_value_t<A> alpha, std::vector<A> const &va, std::vector<B> const &vb, get_value_t<A> beta, std::vector<C> &vc) {
102 auto const n_b = va.size();
103
104 // check sizes of input vectors and return if they are empty
105 EXPECTS(n_b == vb.size() and n_b == vc.size());
106 if (va.empty()) return;
107
108 // if C is in C-layout, compute the transpose of the product in Fortran order
109 if constexpr (has_C_layout<C>) {
110 auto vcT = detail::get_transpose_vector(vc);
111 return gemm_batch<is_vbatch>(alpha, detail::get_transpose_vector(vb), detail::get_transpose_vector(va), beta, vcT);
112 } else {
113 // for operations on the device, use unified memory for vector of ints or ptrs
114 auto constexpr vec_addr_spc = []() { return mem::on_host<C> ? mem::Host : mem::Unified; }();
115
116 // convert the vector of matrices to the corresponding vector of pointers
117 auto a_ptrs = detail::get_ptr_vector<is_vbatch, vec_addr_spc>(va);
118 auto b_ptrs = detail::get_ptr_vector<is_vbatch, vec_addr_spc>(vb);
119 auto c_ptrs = detail::get_ptr_vector<is_vbatch, vec_addr_spc>(vc);
120
121 // either call gemm_vbatch or gemm_batch
122 if constexpr (is_vbatch) {
123 // create vectors to store shapes and leading dimensions of size 'batch_count + 1' as required by Magma
124 nda::vector<int, heap<vec_addr_spc>> vm(n_b + 1), vk(n_b + 1), vn(n_b + 1), vlda(n_b + 1), vldb(n_b + 1), vldc(n_b + 1);
125
126 for (auto i : range(n_b)) {
127 auto &&mat_a = get_array(va[i]);
128 auto &&mat_b = get_array(vb[i]);
129 auto &&mat_c = get_array(vc[i]);
130
131 // check the dimensions of the input/output arrays/views
132 auto const [m, k] = mat_a.shape();
133 auto const [l, n] = mat_b.shape();
134 EXPECTS(k == l);
135 EXPECTS(m == mat_c.extent(0));
136 EXPECTS(n == mat_c.extent(1));
137
138 // store shapes and leading dimensions
139 vm[i] = m;
140 vk[i] = k;
141 vn[i] = n;
142 vlda[i] = get_ld(mat_a);
143 vldb[i] = get_ld(mat_b);
144 vldc[i] = get_ld(mat_c);
145 }
146
147 // perform the actual library call
149 device::gemm_vbatch(get_op<A>, get_op<B>, vm.data(), vn.data(), vk.data(), alpha, a_ptrs.data(), vlda.data(), b_ptrs.data(), vldb.data(),
150 beta, c_ptrs.data(), vldc.data(), n_b);
151 } else {
152 f77::gemm_vbatch(get_op<A>, get_op<B>, vm.data(), vn.data(), vk.data(), alpha, a_ptrs.data(), vlda.data(), b_ptrs.data(), vldb.data(), beta,
153 c_ptrs.data(), vldc.data(), n_b);
154 }
155 } else {
156 auto &&mat_a = get_array(va[0]);
157 auto &&mat_b = get_array(vb[0]);
158 auto &&mat_c = get_array(vc[0]);
159
160 // check the dimensions of the input/output arrays/views
161 auto const [m, k] = mat_a.shape();
162 auto const [l, n] = mat_b.shape();
163 EXPECTS(k == l);
164 EXPECTS(m == mat_c.extent(0));
165 EXPECTS(n == mat_c.extent(1));
166
167 // perform the actual library call
169 device::gemm_batch(get_op<A>, get_op<B>, m, n, k, alpha, a_ptrs.data(), get_ld(mat_a), b_ptrs.data(), get_ld(mat_b), beta, c_ptrs.data(),
170 get_ld(mat_c), n_b);
171 } else {
172 f77::gemm_batch(get_op<A>, get_op<B>, m, n, k, alpha, a_ptrs.data(), get_ld(mat_a), b_ptrs.data(), get_ld(mat_b), beta, c_ptrs.data(),
173 get_ld(mat_c), n_b);
174 }
175 }
176 }
177 }
178
193 template <BlasArrayOrConj<2> A, BlasArrayOrConjFor<A, 2> B, BlasArrayFor<A, 2> C>
194 void gemm_vbatch(get_value_t<A> alpha, std::vector<A> const &va, std::vector<B> const &vb, get_value_t<A> beta, std::vector<C> &vc) {
195 gemm_batch<true>(alpha, va, vb, beta, vc);
196 }
197
232 template <BlasArrayOrConj<3> A, BlasArrayOrConjFor<A, 3> B, BlasArrayFor<A, 3> C>
234 void gemm_batch_strided(get_value_t<A> alpha, A const &a, B const &b, get_value_t<A> beta, C &&c) {
235 // if C is in C-layout, compute the transpose of the product in Fortran order
236 if constexpr (has_C_layout<C>) {
237 gemm_batch_strided(alpha, transpose(b), transpose(a), beta, transpose(std::forward<C>(c)));
238 } else {
239 // get array info: batch count, matrix dims, leading dims, slowest stride
240 auto array_info = [](auto &arr) {
241 if constexpr (has_C_layout<decltype(arr)>) {
242 auto mat = arr(0, nda::ellipsis{});
243 return std::array<long, 5>{arr.extent(0), mat.extent(0), mat.extent(1), get_ld(mat), arr.strides()[0]};
244 } else {
245 auto mat = arr(nda::ellipsis{}, 0);
246 return std::array<long, 5>{arr.extent(2), mat.extent(0), mat.extent(1), get_ld(mat), arr.strides()[2]};
247 }
248 };
249
250 // get underlying array in case it is given as a conjugate expression
251 auto arr_a = get_array(a);
252 auto arr_b = get_array(b);
253
254 // check the dimensions of the input/output arrays/views
255 auto const [nb_a, m_a, k_a, ld_a, s_a] = array_info(arr_a);
256 auto const [nb_b, k_b, n_b, ld_b, s_b] = array_info(arr_b);
257 auto const [nb_c, m_c, n_c, ld_c, s_c] = array_info(c);
258 EXPECTS(k_a == k_b);
259 EXPECTS(m_a == m_c);
260 EXPECTS(n_b == n_c);
261 EXPECTS(nb_a == nb_b and nb_a == nb_c);
262
263 // arrays/views must be BLAS compatible
264 EXPECTS(arr_a.indexmap().min_stride() == 1);
265 EXPECTS(arr_b.indexmap().min_stride() == 1);
266 EXPECTS(c.indexmap().min_stride() == 1);
267
268 // perform the actual library call
270 device::gemm_batch_strided(get_op<A>, get_op<B>, m_c, n_c, k_a, alpha, arr_a.data(), ld_a, s_a, arr_b.data(), ld_b, s_b, beta, c.data(), ld_c,
271 s_c, nb_c);
272 } else {
273 f77::gemm_batch_strided(get_op<A>, get_op<B>, m_c, n_c, k_a, alpha, arr_a.data(), ld_a, s_a, arr_b.data(), ld_b, s_b, beta, c.data(), ld_c,
274 s_c, nb_c);
275 }
276 }
277 }
278
285 template <BlasArrayOrConj<3> A, BlasArrayOrConjFor<A, 3> B, BlasArrayFor<A, 3> C>
287 void gemm(get_value_t<A> alpha, A const &a, B const &b, get_value_t<A> beta, C &&c) {
288 gemm_batch_strided(alpha, a, b, beta, std::forward<C>(c));
289 }
290
292
293} // namespace nda::blas
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various BLAS routines.
Provides various traits and utilities for the BLAS interface.
ValueType const * data() const noexcept
Get a pointer to the actual data (in general this is not the beginning of the memory block for a view).
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
decltype(auto) get_first_element(A &&a)
Get the first element of an array/view or simply return the scalar if a scalar is given.
Definition traits.hpp:197
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an nda::MemoryArray.
Definition tools.hpp:68
static constexpr char get_op
Variable template that determines the BLAS matrix operation tag ('N','T','C') based on the given bool...
Definition tools.hpp:104
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition tools.hpp:128
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
Definition tools.hpp:89
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
Definition tools.hpp:79
void gemm_vbatch(get_value_t< A > alpha, std::vector< A > const &va, std::vector< B > const &vb, get_value_t< A > beta, std::vector< C > &vc)
Interface to batched versions of the BLAS/cuBLAS gemm routine for variable sized matrices.
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS/cuBLAS gemm routine.
Definition gemm.hpp:57
void gemm_batch(get_value_t< A > alpha, std::vector< A > const &va, std::vector< B > const &vb, get_value_t< A > beta, std::vector< C > &vc)
Interface to batched versions of the BLAS/cuBLAS gemm routine.
void gemm_batch_strided(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to batched-strided versions of the BLAS/cuBLAS gemm routine.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
static constexpr bool on_host
Constexpr variable that is true if all given types have a Host address space.
Provides functions to transform the memory layout of an nda::basic_array or nda::basic_array_view.
Macros used in the nda library.
Mimics Python's ... syntax.
Definition range.hpp:36
Provides type traits for the nda library.