TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
12#include "../tools.hpp"
13#include "../../device.hpp"
14#include "../../exceptions.hpp"
15#include "../../traits.hpp"
16
17#ifdef NDA_HAVE_MAGMA
18#include "magma_v2.h"
19
20#include <exception>
21#endif
22
23#include <vector>
24#include <type_traits>
25
26namespace nda::blas::device {
27
28 // Local function to get unique CuBlas handle.
29 inline cublasHandle_t &get_handle() {
30 struct handle_storage_t { // RAII for the handle
31 handle_storage_t() { cublasCreate(&handle); }
32 ~handle_storage_t() { cublasDestroy(handle); }
33 cublasHandle_t handle = {};
34 };
35 static auto sto = handle_storage_t{};
36 return sto.handle;
37 }
38
39#ifdef NDA_HAVE_MAGMA
40 // Local function to get Magma op.
41 constexpr magma_trans_t get_magma_op(char op) {
42 switch (op) {
43 case 'N': return MagmaNoTrans; break;
44 case 'T': return MagmaTrans; break;
45 case 'C': return MagmaConjTrans; break;
46 default: std::terminate(); return {};
47 }
48 }
49
50 // Local function to get Magma queue.
51 auto &get_magma_queue() {
52 struct queue_t {
53 queue_t() {
54 int device{};
55 magma_getdevice(&device);
56 magma_queue_create(device, &q);
57 }
58 ~queue_t() { magma_queue_destroy(q); }
59 operator magma_queue_t() { return q; }
60
61 private:
62 magma_queue_t q = {};
63 };
64 static queue_t q = {};
65 return q;
66 }
67#endif
68
// Per-thread switch that controls whether a device synchronization follows the library calls in this file
// (it is handed to cuda_device_sync by CUBLAS_CHECK and by the MAGMA helper). Defaults to true.
// Each host thread has its own copy because the variable is thread_local.
thread_local bool synchronize = true; // NOLINT (per-thread option is on purpose)

// Turns post-call synchronization on or off for the calling thread only.
void set_synchronization(bool do_sync) noexcept {
  synchronize = do_sync;
}

// Reports whether post-call synchronization is currently enabled for the calling thread.
bool get_synchronization() noexcept {
  bool const enabled = synchronize;
  return enabled;
}
73
// Macro to check cublas calls.
//
// CUBLAS_CHECK(X, args...) calls the cuBLAS routine X as X(get_handle(), args...).
// - If the returned status is not CUBLAS_STATUS_SUCCESS, it throws via NDA_RUNTIME_ERROR. The message
//   includes the routine name and cuBLAS's status name and description.
// - Otherwise it calls cuda_device_sync with the per-thread `synchronize` flag.
// The body is wrapped in do { ... } while (0), so `CUBLAS_CHECK(...);` is a single statement and is safe
// in an unbraced if/else. The status variable has a distinctive name so it cannot shadow identifiers used
// in the macro arguments.
#define CUBLAS_CHECK(X, ...)                                                                                \
  do {                                                                                                      \
    auto const nda_cublas_status_ = X(get_handle(), __VA_ARGS__);                                           \
    if (nda_cublas_status_ != CUBLAS_STATUS_SUCCESS) {                                                      \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed \n"                                                     \
                        << " cublasGetStatusName: " << cublasGetStatusName(nda_cublas_status_) << "\n"      \
                        << " cublasGetStatusString: " << cublasGetStatusString(nda_cublas_status_) << "\n"; \
    }                                                                                                       \
    cuda_device_sync(synchronize, AS_STRING(X));                                                            \
  } while (0)
85
86 // Anonymous namespace for some file local helper functions.
87 namespace {
88
89 // Cuda data type conversion.
90 template <typename T>
91 constexpr auto cuda_data_type() {
92 if constexpr (std::is_same_v<T, float>) {
93 return CUDA_R_32F;
94 } else if constexpr (std::is_same_v<T, double>) {
95 return CUDA_R_64F;
96 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
97 return CUDA_C_32F;
98 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
99 return CUDA_C_64F;
100 }
101 }
102
103 // Cuda compute type conversion.
104 template <typename T>
105 constexpr auto cuda_compute_type() {
106 if constexpr (std::is_same_v<T, float> or std::is_same_v<T, std::complex<float>>) {
107 return CUBLAS_COMPUTE_32F;
108 } else if constexpr (std::is_same_v<T, double> or std::is_same_v<T, std::complex<double>>) {
109 return CUBLAS_COMPUTE_64F;
110 }
111 }
112
113 // Helper function to call CUDA's cublasGemmGroupedBatchedEx routine.
114 template <typename T>
115 void cuda_gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, T alpha, const T **a, int *lda, const T **b, int *ldb, T beta, T **c,
116 int *ldc, int batch_count) {
117 auto data_t = cuda_data_type<T>();
118 auto compute_t = cuda_compute_type<T>();
119 auto vec_op_a = std::vector<cublasOperation_t>(batch_count, get_cublas_op(op_a));
120 auto vec_op_b = std::vector<cublasOperation_t>(batch_count, get_cublas_op(op_b));
121 auto vec_alpha = std::vector<T>(batch_count, alpha);
122 auto vec_beta = std::vector<T>(batch_count, beta);
123 auto vec_sizes = std::vector<int>(batch_count, 1);
124 CUBLAS_CHECK(cublasGemmGroupedBatchedEx, vec_op_a.data(), vec_op_b.data(), m, n, k, vec_alpha.data(), (const void **)a, data_t, lda,
125 (const void **)b, data_t, ldb, vec_beta.data(), (void **)c, data_t, ldc, batch_count, vec_sizes.data(), compute_t);
126 }
127
128 // Helper function to call Magma's magma_gemm_vbatched routine.
129#ifdef NDA_HAVE_MAGMA
130 template <typename T>
131 void magma_gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, T alpha, const T **a, int *lda, const T **b, int *ldb, T beta, T **c,
132 int *ldc, int batch_count) {
133 if constexpr (std::is_same_v<T, std::complex<float>>) {
134 magmablas_cgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), m, n, k, cucplx(alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(beta),
135 cucplx(c), ldc, batch_count, get_magma_queue());
136 } else {
137 magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), m, n, k, cucplx(alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(beta),
138 cucplx(c), ldc, batch_count, get_magma_queue());
139 }
140 if (synchronize) magma_queue_sync(get_magma_queue());
141 cuda_device_sync(synchronize, "magma_gemm_vbatch");
142 }
143#else
144 template <typename T>
145 void magma_gemm_vbatch(char, char, int *, int *, int *, T, const T **, int *, const T **, int *, T, T **, int *, int) {
146 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch with complex types requires Magma. Configure nda with -DMagmaSupport=ON";
147 }
148#endif
149
150 } // namespace
151
152 // axpy
153 void axpy(int n, float alpha, const float *x, int incx, float *y, int incy) { cublasSaxpy(get_handle(), n, &alpha, x, incx, y, incy); }
154 void axpy(int n, std::complex<float> alpha, const std::complex<float> *x, int incx, std::complex<float> *y, int incy) {
155 CUBLAS_CHECK(cublasCaxpy, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy);
156 }
157 void axpy(int n, double alpha, const double *x, int incx, double *y, int incy) { cublasDaxpy(get_handle(), n, &alpha, x, incx, y, incy); }
158 void axpy(int n, std::complex<double> alpha, const std::complex<double> *x, int incx, std::complex<double> *y, int incy) {
159 CUBLAS_CHECK(cublasZaxpy, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy);
160 }
161
162 // copy
163 void copy(int n, const float *x, int incx, float *y, int incy) { cublasScopy(get_handle(), n, x, incx, y, incy); }
164 void copy(int n, const std::complex<float> *x, int incx, std::complex<float> *y, int incy) {
165 CUBLAS_CHECK(cublasCcopy, n, cucplx(x), incx, cucplx(y), incy);
166 }
167 void copy(int n, const double *x, int incx, double *y, int incy) { cublasDcopy(get_handle(), n, x, incx, y, incy); }
168 void copy(int n, const std::complex<double> *x, int incx, std::complex<double> *y, int incy) {
169 CUBLAS_CHECK(cublasZcopy, n, cucplx(x), incx, cucplx(y), incy);
170 }
171
172 // dot and dotc
173 float dot(int m, const float *x, int incx, const float *y, int incy) {
174 float res{};
175 CUBLAS_CHECK(cublasSdot, m, x, incx, y, incy, &res);
176 return res;
177 }
178 std::complex<float> dot(int m, const std::complex<float> *x, int incx, const std::complex<float> *y, int incy) {
179 cuComplex res;
180 CUBLAS_CHECK(cublasCdotu, m, cucplx(x), incx, cucplx(y), incy, &res);
181 return {res.x, res.y};
182 }
183 std::complex<float> dotc(int m, const std::complex<float> *x, int incx, const std::complex<float> *y, int incy) {
184 cuComplex res;
185 CUBLAS_CHECK(cublasCdotc, m, cucplx(x), incx, cucplx(y), incy, &res);
186 return {res.x, res.y};
187 }
188 double dot(int m, const double *x, int incx, const double *y, int incy) {
189 double res{};
190 CUBLAS_CHECK(cublasDdot, m, x, incx, y, incy, &res);
191 return res;
192 }
193 std::complex<double> dot(int m, const std::complex<double> *x, int incx, const std::complex<double> *y, int incy) {
194 cuDoubleComplex res;
195 CUBLAS_CHECK(cublasZdotu, m, cucplx(x), incx, cucplx(y), incy, &res);
196 return {res.x, res.y};
197 }
198 std::complex<double> dotc(int m, const std::complex<double> *x, int incx, const std::complex<double> *y, int incy) {
199 cuDoubleComplex res;
200 CUBLAS_CHECK(cublasZdotc, m, cucplx(x), incx, cucplx(y), incy, &res);
201 return {res.x, res.y};
202 }
203
204 // gemm
205 void gemm(char op_a, char op_b, int m, int n, int k, float alpha, const float *a, int lda, const float *b, int ldb, float beta, float *c, int ldc) {
206 CUBLAS_CHECK(cublasSgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
207 }
208 void gemm(char op_a, char op_b, int m, int n, int k, std::complex<float> alpha, const std::complex<float> *a, int lda, const std::complex<float> *b,
209 int ldb, std::complex<float> beta, std::complex<float> *c, int ldc) {
210 CUBLAS_CHECK(cublasCgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(&beta),
211 cucplx(c), ldc);
212 }
213 void gemm(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, const double *b, int ldb, double beta, double *c,
214 int ldc) {
215 CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
216 }
217 void gemm(char op_a, char op_b, int m, int n, int k, std::complex<double> alpha, const std::complex<double> *a, int lda,
218 const std::complex<double> *b, int ldb, std::complex<double> beta, std::complex<double> *c, int ldc) {
219 CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(&beta),
220 cucplx(c), ldc);
221 }
222
223 // gemm_batch
224 void gemm_batch(char op_a, char op_b, int m, int n, int k, float alpha, const float **a, int lda, const float **b, int ldb, float beta, float **c,
225 int ldc, int batch_count) {
226 CUBLAS_CHECK(cublasSgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc, batch_count);
227 }
228 void gemm_batch(char op_a, char op_b, int m, int n, int k, std::complex<float> alpha, const std::complex<float> **a, int lda,
229 const std::complex<float> **b, int ldb, std::complex<float> beta, std::complex<float> **c, int ldc, int batch_count) {
230 CUBLAS_CHECK(cublasCgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(&beta),
231 cucplx(c), ldc, batch_count);
232 }
233 void gemm_batch(char op_a, char op_b, int m, int n, int k, double alpha, const double **a, int lda, const double **b, int ldb, double beta,
234 double **c, int ldc, int batch_count) {
235 CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc, batch_count);
236 }
237 void gemm_batch(char op_a, char op_b, int m, int n, int k, std::complex<double> alpha, const std::complex<double> **a, int lda,
238 const std::complex<double> **b, int ldb, std::complex<double> beta, std::complex<double> **c, int ldc, int batch_count) {
239 CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(&beta),
240 cucplx(c), ldc, batch_count);
241 }
242
243 // gemm_vbatch
244 void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, float alpha, const float **a, int *lda, const float **b, int *ldb, float beta,
245 float **c, int *ldc, int batch_count) {
246 cuda_gemm_vbatch(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
247 }
248 void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, std::complex<float> alpha, const std::complex<float> **a, int *lda,
249 const std::complex<float> **b, int *ldb, std::complex<float> beta, std::complex<float> **c, int *ldc, int batch_count) {
250 magma_gemm_vbatch(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
251 }
252 void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, double alpha, const double **a, int *lda, const double **b, int *ldb, double beta,
253 double **c, int *ldc, int batch_count) {
254 cuda_gemm_vbatch(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
255 }
256 void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, std::complex<double> alpha, const std::complex<double> **a, int *lda,
257 const std::complex<double> **b, int *ldb, std::complex<double> beta, std::complex<double> **c, int *ldc, int batch_count) {
258 magma_gemm_vbatch(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
259 }
260
261 // gemm_batch_strided
262 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, float alpha, const float *a, int lda, int stride_a, const float *b, int ldb,
263 int stride_b, float beta, float *c, int ldc, int stride_c, int batch_count) {
264 CUBLAS_CHECK(cublasSgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, stride_a, b, ldb, stride_b, &beta, c,
265 ldc, stride_c, batch_count);
266 }
267 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, std::complex<float> alpha, const std::complex<float> *a, int lda, int stride_a,
268 const std::complex<float> *b, int ldb, int stride_b, std::complex<float> beta, std::complex<float> *c, int ldc,
269 int stride_c, int batch_count) {
270 CUBLAS_CHECK(cublasCgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, stride_a, cucplx(b),
271 ldb, stride_b, cucplx(&beta), cucplx(c), ldc, stride_c, batch_count);
272 }
273 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, int stride_a, const double *b, int ldb,
274 int stride_b, double beta, double *c, int ldc, int stride_c, int batch_count) {
275 CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, stride_a, b, ldb, stride_b, &beta, c,
276 ldc, stride_c, batch_count);
277 }
278 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, std::complex<double> alpha, const std::complex<double> *a, int lda, int stride_a,
279 const std::complex<double> *b, int ldb, int stride_b, std::complex<double> beta, std::complex<double> *c, int ldc,
280 int stride_c, int batch_count) {
281 CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, stride_a, cucplx(b),
282 ldb, stride_b, cucplx(&beta), cucplx(c), ldc, stride_c, batch_count);
283 }
284
285 // gemv
286 void gemv(char op, int m, int n, float alpha, const float *a, int lda, const float *x, int incx, float beta, float *y, int incy) {
287 CUBLAS_CHECK(cublasSgemv, get_cublas_op(op), m, n, &alpha, a, lda, x, incx, &beta, y, incy);
288 }
289 void gemv(char op, int m, int n, std::complex<float> alpha, const std::complex<float> *a, int lda, const std::complex<float> *x, int incx,
290 std::complex<float> beta, std::complex<float> *y, int incy) {
291 CUBLAS_CHECK(cublasCgemv, get_cublas_op(op), m, n, cucplx(&alpha), cucplx(a), lda, cucplx(x), incx, cucplx(&beta), cucplx(y), incy);
292 }
293 void gemv(char op, int m, int n, double alpha, const double *a, int lda, const double *x, int incx, double beta, double *y, int incy) {
294 CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), m, n, &alpha, a, lda, x, incx, &beta, y, incy);
295 }
296 void gemv(char op, int m, int n, std::complex<double> alpha, const std::complex<double> *a, int lda, const std::complex<double> *x, int incx,
297 std::complex<double> beta, std::complex<double> *y, int incy) {
298 CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), m, n, cucplx(&alpha), cucplx(a), lda, cucplx(x), incx, cucplx(&beta), cucplx(y), incy);
299 }
300
301 // ger and gerc
302 void ger(int m, int n, float alpha, const float *x, int incx, const float *y, int incy, float *a, int lda) {
303 CUBLAS_CHECK(cublasSger, m, n, &alpha, x, incx, y, incy, a, lda);
304 }
305 void ger(int m, int n, std::complex<float> alpha, const std::complex<float> *x, int incx, const std::complex<float> *y, int incy,
306 std::complex<float> *a, int lda) {
307 CUBLAS_CHECK(cublasCgeru, m, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
308 }
309 void gerc(int m, int n, std::complex<float> alpha, const std::complex<float> *x, int incx, const std::complex<float> *y, int incy,
310 std::complex<float> *a, int lda) {
311 CUBLAS_CHECK(cublasCgerc, m, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
312 }
313 void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda) {
314 CUBLAS_CHECK(cublasDger, m, n, &alpha, x, incx, y, incy, a, lda);
315 }
316 void ger(int m, int n, std::complex<double> alpha, const std::complex<double> *x, int incx, const std::complex<double> *y, int incy,
317 std::complex<double> *a, int lda) {
318 CUBLAS_CHECK(cublasZgeru, m, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
319 }
320 void gerc(int m, int n, std::complex<double> alpha, const std::complex<double> *x, int incx, const std::complex<double> *y, int incy,
321 std::complex<double> *a, int lda) {
322 CUBLAS_CHECK(cublasZgerc, m, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
323 }
324
325 // scal
326 void scal(int m, float alpha, float *x, int incx) { CUBLAS_CHECK(cublasSscal, m, &alpha, x, incx); }
327 void scal(int m, std::complex<float> alpha, std::complex<float> *x, int incx) { CUBLAS_CHECK(cublasCscal, m, cucplx(&alpha), cucplx(x), incx); }
328 void scal(int m, double alpha, double *x, int incx) { CUBLAS_CHECK(cublasDscal, m, &alpha, x, incx); }
329 void scal(int m, std::complex<double> alpha, std::complex<double> *x, int incx) { CUBLAS_CHECK(cublasZscal, m, cucplx(&alpha), cucplx(x), incx); }
330
331 // swap
332 void swap(int n, float *x, int incx, float *y, int incy) { CUBLAS_CHECK(cublasSswap, n, x, incx, y, incy); } // NOLINT (this is a BLAS swap)
333 void swap(int n, std::complex<float> *x, int incx, std::complex<float> *y, int incy) { // NOLINT (this is a BLAS swap)
334 CUBLAS_CHECK(cublasCswap, n, cucplx(x), incx, cucplx(y), incy);
335 }
336 void swap(int n, double *x, int incx, double *y, int incy) { CUBLAS_CHECK(cublasDswap, n, x, incx, y, incy); } // NOLINT (this is a BLAS swap)
337 void swap(int n, std::complex<double> *x, int incx, std::complex<double> *y, int incy) { // NOLINT (this is a BLAS swap)
338 CUBLAS_CHECK(cublasZswap, n, cucplx(x), incx, cucplx(y), incy);
339 }
340
341 // getrf_batch
342 void getrf_batch(int n, float **a_array, int lda, int *ipiv_array, int *info_array, int batch_size) {
343 CUBLAS_CHECK(cublasSgetrfBatched, n, a_array, lda, ipiv_array, info_array, batch_size);
344 }
345 void getrf_batch(int n, std::complex<float> **a_array, int lda, int *ipiv_array, int *info_array, int batch_size) {
346 CUBLAS_CHECK(cublasCgetrfBatched, n, cucplx(a_array), lda, ipiv_array, info_array, batch_size);
347 }
348 void getrf_batch(int n, double **a_array, int lda, int *ipiv_array, int *info_array, int batch_size) {
349 CUBLAS_CHECK(cublasDgetrfBatched, n, a_array, lda, ipiv_array, info_array, batch_size);
350 }
351 void getrf_batch(int n, std::complex<double> **a_array, int lda, int *ipiv_array, int *info_array, int batch_size) {
352 CUBLAS_CHECK(cublasZgetrfBatched, n, cucplx(a_array), lda, ipiv_array, info_array, batch_size);
353 }
354
355 // getri_batch
356 void getri_batch(int n, float **a_array, int lda, int const *ipiv_array, float **c_array, int ldc, int *info_array, int batch_size) {
357 CUBLAS_CHECK(cublasSgetriBatched, n, a_array, lda, ipiv_array, c_array, ldc, info_array, batch_size);
358 }
359 void getri_batch(int n, std::complex<float> **a_array, int lda, int const *ipiv_array, std::complex<float> **c_array, int ldc, int *info_array,
360 int batch_size) {
361 CUBLAS_CHECK(cublasCgetriBatched, n, cucplx(a_array), lda, ipiv_array, cucplx(c_array), ldc, info_array, batch_size);
362 }
363 void getri_batch(int n, double **a_array, int lda, int const *ipiv_array, double **c_array, int ldc, int *info_array, int batch_size) {
364 CUBLAS_CHECK(cublasDgetriBatched, n, a_array, lda, ipiv_array, c_array, ldc, info_array, batch_size);
365 }
366 void getri_batch(int n, std::complex<double> **a_array, int lda, int const *ipiv_array, std::complex<double> **c_array, int ldc, int *info_array,
367 int batch_size) {
368 CUBLAS_CHECK(cublasZgetriBatched, n, cucplx(a_array), lda, ipiv_array, cucplx(c_array), ldc, info_array, batch_size);
369 }
370
371 // getrs_batch
372 void getrs_batch(char op, int n, int nrhs, const float **a_array, int lda, int const *ipiv_array, float **b_array, int ldb, int &info,
373 int batch_size) {
374 CUBLAS_CHECK(cublasSgetrsBatched, get_cublas_op(op), n, nrhs, a_array, lda, ipiv_array, b_array, ldb, &info, batch_size);
375 }
376 void getrs_batch(char op, int n, int nrhs, const std::complex<float> **a_array, int lda, int const *ipiv_array, std::complex<float> **b_array,
377 int ldb, int &info, int batch_size) {
378 CUBLAS_CHECK(cublasCgetrsBatched, get_cublas_op(op), n, nrhs, cucplx(a_array), lda, ipiv_array, cucplx(b_array), ldb, &info, batch_size);
379 }
380 void getrs_batch(char op, int n, int nrhs, const double **a_array, int lda, int const *ipiv_array, double **b_array, int ldb, int &info,
381 int batch_size) {
382 CUBLAS_CHECK(cublasDgetrsBatched, get_cublas_op(op), n, nrhs, a_array, lda, ipiv_array, b_array, ldb, &info, batch_size);
383 }
384 void getrs_batch(char op, int n, int nrhs, const std::complex<double> **a_array, int lda, int const *ipiv_array, std::complex<double> **b_array,
385 int ldb, int &info, int batch_size) {
386 CUBLAS_CHECK(cublasZgetrsBatched, get_cublas_op(op), n, nrhs, cucplx(a_array), lda, ipiv_array, cucplx(b_array), ldb, &info, batch_size);
387 }
388
389 // geqrf_batch
390 void geqrf_batch(int n, int m, float **a_array, int lda, float **tau_array, int &info, int batch_size) {
391 CUBLAS_CHECK(cublasSgeqrfBatched, n, m, a_array, lda, tau_array, &info, batch_size);
392 }
393 void geqrf_batch(int n, int m, std::complex<float> **a_array, int lda, std::complex<float> **tau_array, int &info, int batch_size) {
394 CUBLAS_CHECK(cublasCgeqrfBatched, n, m, cucplx(a_array), lda, cucplx(tau_array), &info, batch_size);
395 }
396 void geqrf_batch(int n, int m, double **a_array, int lda, double **tau_array, int &info, int batch_size) {
397 CUBLAS_CHECK(cublasDgeqrfBatched, n, m, a_array, lda, tau_array, &info, batch_size);
398 }
399 void geqrf_batch(int n, int m, std::complex<double> **a_array, int lda, std::complex<double> **tau_array, int &info, int batch_size) {
400 CUBLAS_CHECK(cublasZgeqrfBatched, n, m, cucplx(a_array), lda, cucplx(tau_array), &info, batch_size);
401 }
402
403} // namespace nda::blas::device
Provides various traits and utilities for the BLAS interface.
Provides a C++ interface for the GPU versions of various BLAS routines.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
void cuda_device_sync(bool do_sync=true, std::string_view func="")
Empty function if CudaSupport is not enabled.
Definition device.hpp:205
Provides type traits for the nda library.