26namespace nda::blas::device {
29 inline cublasHandle_t &get_handle() {
30 struct handle_storage_t {
31 handle_storage_t() { cublasCreate(&handle); }
32 ~handle_storage_t() { cublasDestroy(handle); }
33 cublasHandle_t handle = {};
35 static auto sto = handle_storage_t{};
41 constexpr magma_trans_t get_magma_op(
char op) {
43 case 'N':
return MagmaNoTrans;
break;
44 case 'T':
return MagmaTrans;
break;
45 case 'C':
return MagmaConjTrans;
break;
46 default: std::terminate();
return {};
51 auto &get_magma_queue() {
55 magma_getdevice(&device);
56 magma_queue_create(device, &q);
58 ~queue_t() { magma_queue_destroy(q); }
59 operator magma_queue_t() {
return q; }
64 static queue_t q = {};
// Flag handed to cuda_device_sync by CUBLAS_CHECK and tested before magma_queue_sync in
// magma_gemm_vbatch. It is declared thread_local, so each thread owns an independent copy
// that starts out as true.
thread_local bool synchronize{true};

// Store the requested setting in the calling thread's copy of the flag.
void set_synchronization(bool do_sync) noexcept {
  synchronize = do_sync;
}

// Report the calling thread's current setting of the flag.
bool get_synchronization() noexcept {
  return synchronize;
}
75#define CUBLAS_CHECK(X, ...) \
77 auto err = X(get_handle(), __VA_ARGS__); \
78 if (err != CUBLAS_STATUS_SUCCESS) { \
79 NDA_RUNTIME_ERROR << AS_STRING(X) << " failed \n" \
80 << " cublasGetStatusName: " << cublasGetStatusName(err) << "\n" \
81 << " cublasGetStatusString: " << cublasGetStatusString(err) << "\n"; \
83 cuda_device_sync(synchronize, AS_STRING(X)); \
91 constexpr auto cuda_data_type() {
92 if constexpr (std::is_same_v<T, float>) {
94 }
else if constexpr (std::is_same_v<T, double>) {
96 }
else if constexpr (std::is_same_v<T, std::complex<float>>) {
98 }
else if constexpr (std::is_same_v<T, std::complex<double>>) {
104 template <
typename T>
105 constexpr auto cuda_compute_type() {
106 if constexpr (std::is_same_v<T, float> or std::is_same_v<T, std::complex<float>>) {
107 return CUBLAS_COMPUTE_32F;
108 }
else if constexpr (std::is_same_v<T, double> or std::is_same_v<T, std::complex<double>>) {
109 return CUBLAS_COMPUTE_64F;
114 template <
typename T>
115 void cuda_gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k, T alpha,
const T **a,
int *lda,
const T **b,
int *ldb, T beta, T **c,
116 int *ldc,
int batch_count) {
117 auto data_t = cuda_data_type<T>();
118 auto compute_t = cuda_compute_type<T>();
119 auto vec_op_a = std::vector<cublasOperation_t>(batch_count, get_cublas_op(op_a));
120 auto vec_op_b = std::vector<cublasOperation_t>(batch_count, get_cublas_op(op_b));
121 auto vec_alpha = std::vector<T>(batch_count, alpha);
122 auto vec_beta = std::vector<T>(batch_count, beta);
123 auto vec_sizes = std::vector<int>(batch_count, 1);
124 CUBLAS_CHECK(cublasGemmGroupedBatchedEx, vec_op_a.data(), vec_op_b.data(), m, n, k, vec_alpha.data(), (
const void **)a, data_t, lda,
125 (
const void **)b, data_t, ldb, vec_beta.data(), (
void **)c, data_t, ldc, batch_count, vec_sizes.data(), compute_t);
130 template <
typename T>
131 void magma_gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k, T alpha,
const T **a,
int *lda,
const T **b,
int *ldb, T beta, T **c,
132 int *ldc,
int batch_count) {
133 if constexpr (std::is_same_v<T, std::complex<float>>) {
134 magmablas_cgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), m, n, k, cucplx(alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(beta),
135 cucplx(c), ldc, batch_count, get_magma_queue());
137 magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), m, n, k, cucplx(alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(beta),
138 cucplx(c), ldc, batch_count, get_magma_queue());
140 if (synchronize) magma_queue_sync(get_magma_queue());
144 template <
typename T>
145 void magma_gemm_vbatch(
char,
char,
int *,
int *,
int *, T,
const T **,
int *,
const T **,
int *, T, T **,
int *,
int) {
146 NDA_RUNTIME_ERROR <<
"nda::blas::device::gemmv_batch with complex types requires Magma. Configure nda with -DMagmaSupport=ON";
153 void axpy(
int n,
float alpha,
const float *x,
int incx,
float *y,
int incy) { cublasSaxpy(get_handle(), n, &alpha, x, incx, y, incy); }
154 void axpy(
int n, std::complex<float> alpha,
const std::complex<float> *x,
int incx, std::complex<float> *y,
int incy) {
155 CUBLAS_CHECK(cublasCaxpy, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy);
157 void axpy(
int n,
double alpha,
const double *x,
int incx,
double *y,
int incy) { cublasDaxpy(get_handle(), n, &alpha, x, incx, y, incy); }
158 void axpy(
int n, std::complex<double> alpha,
const std::complex<double> *x,
int incx, std::complex<double> *y,
int incy) {
159 CUBLAS_CHECK(cublasZaxpy, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy);
163 void copy(
int n,
const float *x,
int incx,
float *y,
int incy) { cublasScopy(get_handle(), n, x, incx, y, incy); }
164 void copy(
int n,
const std::complex<float> *x,
int incx, std::complex<float> *y,
int incy) {
165 CUBLAS_CHECK(cublasCcopy, n, cucplx(x), incx, cucplx(y), incy);
167 void copy(
int n,
const double *x,
int incx,
double *y,
int incy) { cublasDcopy(get_handle(), n, x, incx, y, incy); }
168 void copy(
int n,
const std::complex<double> *x,
int incx, std::complex<double> *y,
int incy) {
169 CUBLAS_CHECK(cublasZcopy, n, cucplx(x), incx, cucplx(y), incy);
173 float dot(
int m,
const float *x,
int incx,
const float *y,
int incy) {
175 CUBLAS_CHECK(cublasSdot, m, x, incx, y, incy, &res);
178 std::complex<float> dot(
int m,
const std::complex<float> *x,
int incx,
const std::complex<float> *y,
int incy) {
180 CUBLAS_CHECK(cublasCdotu, m, cucplx(x), incx, cucplx(y), incy, &res);
181 return {res.x, res.y};
183 std::complex<float> dotc(
int m,
const std::complex<float> *x,
int incx,
const std::complex<float> *y,
int incy) {
185 CUBLAS_CHECK(cublasCdotc, m, cucplx(x), incx, cucplx(y), incy, &res);
186 return {res.x, res.y};
188 double dot(
int m,
const double *x,
int incx,
const double *y,
int incy) {
190 CUBLAS_CHECK(cublasDdot, m, x, incx, y, incy, &res);
193 std::complex<double> dot(
int m,
const std::complex<double> *x,
int incx,
const std::complex<double> *y,
int incy) {
195 CUBLAS_CHECK(cublasZdotu, m, cucplx(x), incx, cucplx(y), incy, &res);
196 return {res.x, res.y};
198 std::complex<double> dotc(
int m,
const std::complex<double> *x,
int incx,
const std::complex<double> *y,
int incy) {
200 CUBLAS_CHECK(cublasZdotc, m, cucplx(x), incx, cucplx(y), incy, &res);
201 return {res.x, res.y};
205 void gemm(
char op_a,
char op_b,
int m,
int n,
int k,
float alpha,
const float *a,
int lda,
const float *b,
int ldb,
float beta,
float *c,
int ldc) {
206 CUBLAS_CHECK(cublasSgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
208 void gemm(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<float> alpha,
const std::complex<float> *a,
int lda,
const std::complex<float> *b,
209 int ldb, std::complex<float> beta, std::complex<float> *c,
int ldc) {
210 CUBLAS_CHECK(cublasCgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(&beta),
213 void gemm(
char op_a,
char op_b,
int m,
int n,
int k,
double alpha,
const double *a,
int lda,
const double *b,
int ldb,
double beta,
double *c,
215 CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
217 void gemm(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<double> alpha,
const std::complex<double> *a,
int lda,
218 const std::complex<double> *b,
int ldb, std::complex<double> beta, std::complex<double> *c,
int ldc) {
219 CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(&beta),
224 void gemm_batch(
char op_a,
char op_b,
int m,
int n,
int k,
float alpha,
const float **a,
int lda,
const float **b,
int ldb,
float beta,
float **c,
225 int ldc,
int batch_count) {
226 CUBLAS_CHECK(cublasSgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc, batch_count);
228 void gemm_batch(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<float> alpha,
const std::complex<float> **a,
int lda,
229 const std::complex<float> **b,
int ldb, std::complex<float> beta, std::complex<float> **c,
int ldc,
int batch_count) {
230 CUBLAS_CHECK(cublasCgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(&beta),
231 cucplx(c), ldc, batch_count);
233 void gemm_batch(
char op_a,
char op_b,
int m,
int n,
int k,
double alpha,
const double **a,
int lda,
const double **b,
int ldb,
double beta,
234 double **c,
int ldc,
int batch_count) {
235 CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc, batch_count);
237 void gemm_batch(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<double> alpha,
const std::complex<double> **a,
int lda,
238 const std::complex<double> **b,
int ldb, std::complex<double> beta, std::complex<double> **c,
int ldc,
int batch_count) {
239 CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, cucplx(b), ldb, cucplx(&beta),
240 cucplx(c), ldc, batch_count);
244 void gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k,
float alpha,
const float **a,
int *lda,
const float **b,
int *ldb,
float beta,
245 float **c,
int *ldc,
int batch_count) {
246 cuda_gemm_vbatch(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
248 void gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k, std::complex<float> alpha,
const std::complex<float> **a,
int *lda,
249 const std::complex<float> **b,
int *ldb, std::complex<float> beta, std::complex<float> **c,
int *ldc,
int batch_count) {
250 magma_gemm_vbatch(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
252 void gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k,
double alpha,
const double **a,
int *lda,
const double **b,
int *ldb,
double beta,
253 double **c,
int *ldc,
int batch_count) {
254 cuda_gemm_vbatch(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
256 void gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k, std::complex<double> alpha,
const std::complex<double> **a,
int *lda,
257 const std::complex<double> **b,
int *ldb, std::complex<double> beta, std::complex<double> **c,
int *ldc,
int batch_count) {
258 magma_gemm_vbatch(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
262 void gemm_batch_strided(
char op_a,
char op_b,
int m,
int n,
int k,
float alpha,
const float *a,
int lda,
int stride_a,
const float *b,
int ldb,
263 int stride_b,
float beta,
float *c,
int ldc,
int stride_c,
int batch_count) {
264 CUBLAS_CHECK(cublasSgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, stride_a, b, ldb, stride_b, &beta, c,
265 ldc, stride_c, batch_count);
267 void gemm_batch_strided(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<float> alpha,
const std::complex<float> *a,
int lda,
int stride_a,
268 const std::complex<float> *b,
int ldb,
int stride_b, std::complex<float> beta, std::complex<float> *c,
int ldc,
269 int stride_c,
int batch_count) {
270 CUBLAS_CHECK(cublasCgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, stride_a, cucplx(b),
271 ldb, stride_b, cucplx(&beta), cucplx(c), ldc, stride_c, batch_count);
273 void gemm_batch_strided(
char op_a,
char op_b,
int m,
int n,
int k,
double alpha,
const double *a,
int lda,
int stride_a,
const double *b,
int ldb,
274 int stride_b,
double beta,
double *c,
int ldc,
int stride_c,
int batch_count) {
275 CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, stride_a, b, ldb, stride_b, &beta, c,
276 ldc, stride_c, batch_count);
278 void gemm_batch_strided(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<double> alpha,
const std::complex<double> *a,
int lda,
int stride_a,
279 const std::complex<double> *b,
int ldb,
int stride_b, std::complex<double> beta, std::complex<double> *c,
int ldc,
280 int stride_c,
int batch_count) {
281 CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, cucplx(&alpha), cucplx(a), lda, stride_a, cucplx(b),
282 ldb, stride_b, cucplx(&beta), cucplx(c), ldc, stride_c, batch_count);
286 void gemv(
char op,
int m,
int n,
float alpha,
const float *a,
int lda,
const float *x,
int incx,
float beta,
float *y,
int incy) {
287 CUBLAS_CHECK(cublasSgemv, get_cublas_op(op), m, n, &alpha, a, lda, x, incx, &beta, y, incy);
289 void gemv(
char op,
int m,
int n, std::complex<float> alpha,
const std::complex<float> *a,
int lda,
const std::complex<float> *x,
int incx,
290 std::complex<float> beta, std::complex<float> *y,
int incy) {
291 CUBLAS_CHECK(cublasCgemv, get_cublas_op(op), m, n, cucplx(&alpha), cucplx(a), lda, cucplx(x), incx, cucplx(&beta), cucplx(y), incy);
293 void gemv(
char op,
int m,
int n,
double alpha,
const double *a,
int lda,
const double *x,
int incx,
double beta,
double *y,
int incy) {
294 CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), m, n, &alpha, a, lda, x, incx, &beta, y, incy);
296 void gemv(
char op,
int m,
int n, std::complex<double> alpha,
const std::complex<double> *a,
int lda,
const std::complex<double> *x,
int incx,
297 std::complex<double> beta, std::complex<double> *y,
int incy) {
298 CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), m, n, cucplx(&alpha), cucplx(a), lda, cucplx(x), incx, cucplx(&beta), cucplx(y), incy);
302 void ger(
int m,
int n,
float alpha,
const float *x,
int incx,
const float *y,
int incy,
float *a,
int lda) {
303 CUBLAS_CHECK(cublasSger, m, n, &alpha, x, incx, y, incy, a, lda);
305 void ger(
int m,
int n, std::complex<float> alpha,
const std::complex<float> *x,
int incx,
const std::complex<float> *y,
int incy,
306 std::complex<float> *a,
int lda) {
307 CUBLAS_CHECK(cublasCgeru, m, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
309 void gerc(
int m,
int n, std::complex<float> alpha,
const std::complex<float> *x,
int incx,
const std::complex<float> *y,
int incy,
310 std::complex<float> *a,
int lda) {
311 CUBLAS_CHECK(cublasCgerc, m, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
313 void ger(
int m,
int n,
double alpha,
const double *x,
int incx,
const double *y,
int incy,
double *a,
int lda) {
314 CUBLAS_CHECK(cublasDger, m, n, &alpha, x, incx, y, incy, a, lda);
316 void ger(
int m,
int n, std::complex<double> alpha,
const std::complex<double> *x,
int incx,
const std::complex<double> *y,
int incy,
317 std::complex<double> *a,
int lda) {
318 CUBLAS_CHECK(cublasZgeru, m, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
320 void gerc(
int m,
int n, std::complex<double> alpha,
const std::complex<double> *x,
int incx,
const std::complex<double> *y,
int incy,
321 std::complex<double> *a,
int lda) {
322 CUBLAS_CHECK(cublasZgerc, m, n, cucplx(&alpha), cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
326 void scal(
int m,
float alpha,
float *x,
int incx) { CUBLAS_CHECK(cublasSscal, m, &alpha, x, incx); }
327 void scal(
int m, std::complex<float> alpha, std::complex<float> *x,
int incx) { CUBLAS_CHECK(cublasCscal, m, cucplx(&alpha), cucplx(x), incx); }
328 void scal(
int m,
double alpha,
double *x,
int incx) { CUBLAS_CHECK(cublasDscal, m, &alpha, x, incx); }
329 void scal(
int m, std::complex<double> alpha, std::complex<double> *x,
int incx) { CUBLAS_CHECK(cublasZscal, m, cucplx(&alpha), cucplx(x), incx); }
332 void swap(
int n,
float *x,
int incx,
float *y,
int incy) { CUBLAS_CHECK(cublasSswap, n, x, incx, y, incy); }
333 void swap(
int n, std::complex<float> *x,
int incx, std::complex<float> *y,
int incy) {
334 CUBLAS_CHECK(cublasCswap, n, cucplx(x), incx, cucplx(y), incy);
336 void swap(
int n,
double *x,
int incx,
double *y,
int incy) { CUBLAS_CHECK(cublasDswap, n, x, incx, y, incy); }
337 void swap(
int n, std::complex<double> *x,
int incx, std::complex<double> *y,
int incy) {
338 CUBLAS_CHECK(cublasZswap, n, cucplx(x), incx, cucplx(y), incy);
342 void getrf_batch(
int n,
float **a_array,
int lda,
int *ipiv_array,
int *info_array,
int batch_size) {
343 CUBLAS_CHECK(cublasSgetrfBatched, n, a_array, lda, ipiv_array, info_array, batch_size);
345 void getrf_batch(
int n, std::complex<float> **a_array,
int lda,
int *ipiv_array,
int *info_array,
int batch_size) {
346 CUBLAS_CHECK(cublasCgetrfBatched, n, cucplx(a_array), lda, ipiv_array, info_array, batch_size);
348 void getrf_batch(
int n,
double **a_array,
int lda,
int *ipiv_array,
int *info_array,
int batch_size) {
349 CUBLAS_CHECK(cublasDgetrfBatched, n, a_array, lda, ipiv_array, info_array, batch_size);
351 void getrf_batch(
int n, std::complex<double> **a_array,
int lda,
int *ipiv_array,
int *info_array,
int batch_size) {
352 CUBLAS_CHECK(cublasZgetrfBatched, n, cucplx(a_array), lda, ipiv_array, info_array, batch_size);
356 void getri_batch(
int n,
float **a_array,
int lda,
int const *ipiv_array,
float **c_array,
int ldc,
int *info_array,
int batch_size) {
357 CUBLAS_CHECK(cublasSgetriBatched, n, a_array, lda, ipiv_array, c_array, ldc, info_array, batch_size);
359 void getri_batch(
int n, std::complex<float> **a_array,
int lda,
int const *ipiv_array, std::complex<float> **c_array,
int ldc,
int *info_array,
361 CUBLAS_CHECK(cublasCgetriBatched, n, cucplx(a_array), lda, ipiv_array, cucplx(c_array), ldc, info_array, batch_size);
363 void getri_batch(
int n,
double **a_array,
int lda,
int const *ipiv_array,
double **c_array,
int ldc,
int *info_array,
int batch_size) {
364 CUBLAS_CHECK(cublasDgetriBatched, n, a_array, lda, ipiv_array, c_array, ldc, info_array, batch_size);
366 void getri_batch(
int n, std::complex<double> **a_array,
int lda,
int const *ipiv_array, std::complex<double> **c_array,
int ldc,
int *info_array,
368 CUBLAS_CHECK(cublasZgetriBatched, n, cucplx(a_array), lda, ipiv_array, cucplx(c_array), ldc, info_array, batch_size);
372 void getrs_batch(
char op,
int n,
int nrhs,
const float **a_array,
int lda,
int const *ipiv_array,
float **b_array,
int ldb,
int &info,
374 CUBLAS_CHECK(cublasSgetrsBatched, get_cublas_op(op), n, nrhs, a_array, lda, ipiv_array, b_array, ldb, &info, batch_size);
376 void getrs_batch(
char op,
int n,
int nrhs,
const std::complex<float> **a_array,
int lda,
int const *ipiv_array, std::complex<float> **b_array,
377 int ldb,
int &info,
int batch_size) {
378 CUBLAS_CHECK(cublasCgetrsBatched, get_cublas_op(op), n, nrhs, cucplx(a_array), lda, ipiv_array, cucplx(b_array), ldb, &info, batch_size);
380 void getrs_batch(
char op,
int n,
int nrhs,
const double **a_array,
int lda,
int const *ipiv_array,
double **b_array,
int ldb,
int &info,
382 CUBLAS_CHECK(cublasDgetrsBatched, get_cublas_op(op), n, nrhs, a_array, lda, ipiv_array, b_array, ldb, &info, batch_size);
384 void getrs_batch(
char op,
int n,
int nrhs,
const std::complex<double> **a_array,
int lda,
int const *ipiv_array, std::complex<double> **b_array,
385 int ldb,
int &info,
int batch_size) {
386 CUBLAS_CHECK(cublasZgetrsBatched, get_cublas_op(op), n, nrhs, cucplx(a_array), lda, ipiv_array, cucplx(b_array), ldb, &info, batch_size);
390 void geqrf_batch(
int n,
int m,
float **a_array,
int lda,
float **tau_array,
int &info,
int batch_size) {
391 CUBLAS_CHECK(cublasSgeqrfBatched, n, m, a_array, lda, tau_array, &info, batch_size);
393 void geqrf_batch(
int n,
int m, std::complex<float> **a_array,
int lda, std::complex<float> **tau_array,
int &info,
int batch_size) {
394 CUBLAS_CHECK(cublasCgeqrfBatched, n, m, cucplx(a_array), lda, cucplx(tau_array), &info, batch_size);
396 void geqrf_batch(
int n,
int m,
double **a_array,
int lda,
double **tau_array,
int &info,
int batch_size) {
397 CUBLAS_CHECK(cublasDgeqrfBatched, n, m, a_array, lda, tau_array, &info, batch_size);
399 void geqrf_batch(
int n,
int m, std::complex<double> **a_array,
int lda, std::complex<double> **tau_array,
int &info,
int batch_size) {
400 CUBLAS_CHECK(cublasZgeqrfBatched, n, m, cucplx(a_array), lda, cucplx(tau_array), &info, batch_size);
Provides a C++ interface for the GPU versions of various BLAS routines.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
void cuda_device_sync(bool do_sync=true, std::string_view func="")
Empty function if CudaSupport is not enabled.
Provides type traits for the nda library.