namespace nda::blas::device {
  // Get a unique cuBLAS handle (lazily created on first use, destroyed at program exit).
  inline cublasHandle_t &get_handle() {
    struct handle_storage_t { // RAII wrapper so the handle is created once and destroyed automatically
      handle_storage_t() { cublasCreate(&handle); }
      ~handle_storage_t() { cublasDestroy(handle); }
      cublasHandle_t handle = {};
    };
    static auto sto = handle_storage_t{};
    return sto.handle;
  }
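  // Illustrative note (not part of the original source): since the handle is a
  // function-local static, the first BLAS call pays the cublasCreate() cost and all
  // later calls reuse the same handle, e.g. to attach a user-managed CUDA stream:
  //   cublasSetStream(get_handle(), my_stream); // my_stream is a hypothetical cudaStream_t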
  // Map a BLAS transpose character ('N', 'T', 'C') to the corresponding MAGMA operation.
  constexpr magma_trans_t get_magma_op(char op) {
    switch (op) {
      case 'N': return MagmaNoTrans; break;
      case 'T': return MagmaTrans; break;
      case 'C': return MagmaConjTrans; break;
      default: std::terminate(); return {};
    }
  }
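  // For example, get_magma_op('T') yields MagmaTrans; any character outside
  // {'N', 'T', 'C'} terminates the program rather than silently mis-dispatching.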
  // Get a unique MAGMA queue on the current device (used by all MAGMA routines below).
  auto &get_magma_queue() {
    struct queue_t { // RAII wrapper around the queue
      queue_t() {
        int device{};
        magma_getdevice(&device);
        magma_queue_create(device, &q);
      }
      ~queue_t() { magma_queue_destroy(q); }
      operator magma_queue_t() { return q; }
      magma_queue_t q = {};
    };
    static queue_t q = {};
    return q;
  }
  // Global option: if true, call cudaDeviceSynchronize() after every cuBLAS/MAGMA call.
  static bool synchronize = true;
// Call a cuBLAS routine on the shared handle, check its status, and (optionally) synchronize.
#define CUBLAS_CHECK(X, ...)                                                                         \
  {                                                                                                  \
    auto err = X(get_handle(), __VA_ARGS__);                                                         \
    if (err != CUBLAS_STATUS_SUCCESS) {                                                              \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed \n"                                              \
                        << " cublasGetStatusName: " << cublasGetStatusName(err) << "\n"              \
                        << " cublasGetStatusString: " << cublasGetStatusString(err) << "\n";         \
    }                                                                                                \
    if (synchronize) {                                                                               \
      auto errsync = cudaDeviceSynchronize();                                                        \
      if (errsync != cudaSuccess) {                                                                  \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) << "\n" \
                          << " cudaGetErrorName: " << cudaGetErrorName(errsync) << "\n"              \
                          << " cudaGetErrorString: " << cudaGetErrorString(errsync) << "\n";         \
      }                                                                                              \
    }                                                                                                \
  }
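// Illustrative expansion (not in the original source): a call such as
//   CUBLAS_CHECK(cublasDscal, m, &alpha, x, incx);
// becomes roughly
//   auto err = cublasDscal(get_handle(), m, &alpha, x, incx);
//   if (err != CUBLAS_STATUS_SUCCESS) { /* throw via NDA_RUNTIME_ERROR */ }
//   if (synchronize) { /* cudaDeviceSynchronize() and check the result */ }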
  void gemm(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, const double *b, int ldb, double beta, double *c,
            int ldc) {
    CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
  }
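  // Usage sketch (illustrative, not part of the original source): with column-major
  // device-resident matrices A (m x k), B (k x n) and C (m x n), this computes C = A * B.
  // All pointers are assumed to be device pointers (e.g. allocated via cudaMalloc);
  // example_gemm is a hypothetical helper, not part of the library interface.
  inline void example_gemm(int m, int n, int k, const double *A, const double *B, double *C) {
    gemm('N', 'N', m, n, k, /*alpha=*/1.0, A, /*lda=*/m, B, /*ldb=*/k, /*beta=*/0.0, C, /*ldc=*/m);
  }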
  void gemm(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *b, int ldb, dcomplex beta,
            dcomplex *c, int ldc) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha_cu, cucplx(a), lda, cucplx(b), ldb, &beta_cu, cucplx(c), ldc);
  }
  void gemm_batch(char op_a, char op_b, int m, int n, int k, double alpha, const double **a, int lda, const double **b, int ldb, double beta,
                  double **c, int ldc, int batch_count) {
    CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc, batch_count);
  }
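  // Note (added for clarity): for cublasDgemmBatched the arguments a, b and c are
  // device-side arrays of batch_count device pointers, one per matrix in the batch.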
  void gemm_batch(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex **a, int lda, const dcomplex **b, int ldb, dcomplex beta,
                  dcomplex **c, int ldc, int batch_count) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha_cu, cucplx(a), lda, cucplx(b), ldb, &beta_cu,
                 cucplx(c), ldc, batch_count);
  }
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, double alpha, const double **a, int *lda, const double **b, int *ldb, double beta,
                   double **c, int *ldc, int batch_count) {
    magmablas_dgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, get_magma_queue());
    if (synchronize) magma_queue_sync(get_magma_queue());
    if (synchronize) cudaDeviceSynchronize();
  }
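  // Note (added for clarity): the MAGMA vbatched routines are not funneled through
  // CUBLAS_CHECK, so the synchronize flag is honored explicitly here: the MAGMA queue
  // is drained first, then the whole device, mirroring the macro's behavior.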
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, dcomplex alpha, const dcomplex **a, int *lda, const dcomplex **b, int *ldb,
                   dcomplex beta, dcomplex **c, int *ldc, int batch_count) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), m, n, k, alpha_cu, cucplx(a), lda, cucplx(b), ldb, beta_cu, cucplx(c), ldc,
                             batch_count, get_magma_queue());
    if (synchronize) magma_queue_sync(get_magma_queue());
    if (synchronize) cudaDeviceSynchronize();
  }
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, int stride_a, const double *b, int ldb,
                          int stride_b, double beta, double *c, int ldc, int stride_c, int batch_count) {
    CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, stride_a, b, ldb, stride_b, &beta, c,
                 ldc, stride_c, batch_count);
  }
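  // Note (added for clarity): in the strided variant, matrix i of the batch starts at
  // a + i * stride_a (likewise for b and c), so one contiguous allocation can hold the
  // whole batch instead of a device array of pointers as in gemm_batch above.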
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, int stride_a, const dcomplex *b,
                          int ldb, int stride_b, dcomplex beta, dcomplex *c, int ldc, int stride_c, int batch_count) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha_cu, cucplx(a), lda, stride_a, cucplx(b), ldb,
                 stride_b, &beta_cu, cucplx(c), ldc, stride_c, batch_count);
  }
  void axpy(int n, double alpha, const double *x, int incx, double *y, int incy) { cublasDaxpy(get_handle(), n, &alpha, x, incx, y, incy); }
  void axpy(int n, dcomplex alpha, const dcomplex *x, int incx, dcomplex *y, int incy) {
    auto alpha_cu = cucplx(alpha);
    CUBLAS_CHECK(cublasZaxpy, n, &alpha_cu, cucplx(x), incx, cucplx(y), incy);
  }
  void copy(int n, const double *x, int incx, double *y, int incy) { cublasDcopy(get_handle(), n, x, incx, y, incy); }
  void copy(int n, const dcomplex *x, int incx, dcomplex *y, int incy) { CUBLAS_CHECK(cublasZcopy, n, cucplx(x), incx, cucplx(y), incy); }
  double dot(int m, const double *x, int incx, const double *y, int incy) {
    double res = 0;
    CUBLAS_CHECK(cublasDdot, m, x, incx, y, incy, &res);
    return res;
  }
  dcomplex dot(int m, const dcomplex *x, int incx, const dcomplex *y, int incy) {
    cuDoubleComplex res;
    CUBLAS_CHECK(cublasZdotu, m, cucplx(x), incx, cucplx(y), incy, &res);
    return {res.x, res.y};
  }
  dcomplex dotc(int m, const dcomplex *x, int incx, const dcomplex *y, int incy) {
    cuDoubleComplex res;
    CUBLAS_CHECK(cublasZdotc, m, cucplx(x), incx, cucplx(y), incy, &res);
    return {res.x, res.y};
  }
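  // Note (added for clarity): for complex vectors the dot overload maps to cublasZdotu
  // (no conjugation), while the conjugated variant maps to cublasZdotc, which
  // conjugates x before accumulating.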
  void gemv(char op, int m, int n, double alpha, const double *a, int lda, const double *x, int incx, double beta, double *y, int incy) {
    CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), m, n, &alpha, a, lda, x, incx, &beta, y, incy);
  }
  void gemv(char op, int m, int n, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *x, int incx, dcomplex beta, dcomplex *y, int incy) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), m, n, &alpha_cu, cucplx(a), lda, cucplx(x), incx, &beta_cu, cucplx(y), incy);
  }
  void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda) {
    CUBLAS_CHECK(cublasDger, m, n, &alpha, x, incx, y, incy, a, lda);
  }
  void ger(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda) {
    auto alpha_cu = cucplx(alpha);
    CUBLAS_CHECK(cublasZgeru, m, n, &alpha_cu, cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
  }
  void gerc(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda) {
    auto alpha_cu = cucplx(alpha);
    CUBLAS_CHECK(cublasZgerc, m, n, &alpha_cu, cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
  }
  void scal(int m, double alpha, double *x, int incx) { CUBLAS_CHECK(cublasDscal, m, &alpha, x, incx); }
  void scal(int m, dcomplex alpha, dcomplex *x, int incx) {
    auto alpha_cu = cucplx(alpha);
    CUBLAS_CHECK(cublasZscal, m, &alpha_cu, cucplx(x), incx);
  }
  void swap(int n, double *x, int incx, double *y, int incy) { CUBLAS_CHECK(cublasDswap, n, x, incx, y, incy); }
  void swap(int n, dcomplex *x, int incx, dcomplex *y, int incy) { CUBLAS_CHECK(cublasZswap, n, cucplx(x), incx, cucplx(y), incy); }

} // namespace nda::blas::device