20namespace nda::blas::device {
23 inline cublasHandle_t &get_handle() {
24 struct handle_storage_t {
25 handle_storage_t() { cublasCreate(&handle); }
26 ~handle_storage_t() { cublasDestroy(handle); }
27 cublasHandle_t handle = {};
29 static auto sto = handle_storage_t{};
35 constexpr magma_trans_t get_magma_op(
char op) {
37 case 'N':
return MagmaNoTrans;
break;
38 case 'T':
return MagmaTrans;
break;
39 case 'C':
return MagmaConjTrans;
break;
40 default: std::terminate();
return {};
45 auto &get_magma_queue() {
49 magma_getdevice(&device);
50 magma_queue_create(device, &q);
52 ~queue_t() { magma_queue_destroy(q); }
53 operator magma_queue_t() {
return q; }
58 static queue_t q = {};
// When true, device synchronization is forced after each wrapped call
// (cudaDeviceSynchronize in CUBLAS_CHECK, magma_queue_sync in the vbatch
// wrappers) so that asynchronous errors surface immediately.
static bool synchronize = true;
// Invoke the cuBLAS routine X with the shared handle and the given arguments,
// throw (via NDA_RUNTIME_ERROR) if the call reports a failure, and — when the
// `synchronize` flag is set — synchronize the whole device afterwards so that
// asynchronous execution errors are attributed to the call that caused them.
#define CUBLAS_CHECK(X, ...)                                                                                                                         \
  {                                                                                                                                                  \
    auto err = X(get_handle(), __VA_ARGS__);                                                                                                         \
    if (err != CUBLAS_STATUS_SUCCESS) {                                                                                                              \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed \n"                                                                                              \
                        << " cublasGetStatusName: " << cublasGetStatusName(err) << "\n"                                                              \
                        << " cublasGetStatusString: " << cublasGetStatusString(err) << "\n";                                                         \
    }                                                                                                                                                \
    if (synchronize) {                                                                                                                               \
      auto errsync = cudaDeviceSynchronize();                                                                                                        \
      if (errsync != cudaSuccess) {                                                                                                                  \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) << "\n"                                                 \
                          << " cudaGetErrorName: " << cudaGetErrorName(errsync) << "\n"                                                              \
                          << " cudaGetErrorString: " << cudaGetErrorString(errsync) << "\n";                                                         \
      }                                                                                                                                              \
    }                                                                                                                                                \
  }
85 void gemm(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double *A,
int LDA,
const double *B,
int LDB,
double beta,
double *C,
87 CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC);
89 void gemm(
char op_a,
char op_b,
int M,
int N,
int K,
dcomplex alpha,
const dcomplex *A,
int LDA,
const dcomplex *B,
int LDB,
dcomplex beta,
91 auto alpha_cu = cucplx(alpha);
92 auto beta_cu = cucplx(beta);
93 CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu, cucplx(C), LDC);
96 void gemm_batch(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double **A,
int LDA,
const double **B,
int LDB,
double beta,
97 double **C,
int LDC,
int batch_count) {
98 CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC, batch_count);
100 void gemm_batch(
char op_a,
char op_b,
int M,
int N,
int K,
dcomplex alpha,
const dcomplex **A,
int LDA,
const dcomplex **B,
int LDB,
dcomplex beta,
101 dcomplex **C,
int LDC,
int batch_count) {
102 auto alpha_cu = cucplx(alpha);
103 auto beta_cu = cucplx(beta);
104 CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu,
105 cucplx(C), LDC, batch_count);
109 void gemm_vbatch(
char op_a,
char op_b,
int *M,
int *N,
int *K,
double alpha,
const double **A,
int *LDA,
const double **B,
int *LDB,
double beta,
110 double **C,
int *LDC,
int batch_count) {
111 magmablas_dgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha, A, LDA, B, LDB, beta, C, LDC, batch_count, get_magma_queue());
112 if (synchronize) magma_queue_sync(get_magma_queue());
113 if (synchronize) cudaDeviceSynchronize();
115 void gemm_vbatch(
char op_a,
char op_b,
int *M,
int *N,
int *K,
dcomplex alpha,
const dcomplex **A,
int *LDA,
const dcomplex **B,
int *LDB,
117 auto alpha_cu = cucplx(alpha);
118 auto beta_cu = cucplx(beta);
119 magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha_cu, cucplx(A), LDA, cucplx(B), LDB, beta_cu, cucplx(C), LDC,
120 batch_count, get_magma_queue());
121 if (synchronize) magma_queue_sync(get_magma_queue());
122 if (synchronize) cudaDeviceSynchronize();
126 void gemm_batch_strided(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double *A,
int LDA,
int strideA,
const double *B,
int LDB,
127 int strideB,
double beta,
double *C,
int LDC,
int strideC,
int batch_count) {
128 CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, strideA, B, LDB, strideB, &beta, C,
129 LDC, strideC, batch_count);
131 void gemm_batch_strided(
char op_a,
char op_b,
int M,
int N,
int K,
dcomplex alpha,
const dcomplex *A,
int LDA,
int strideA,
const dcomplex *B,
132 int LDB,
int strideB,
dcomplex beta,
dcomplex *C,
int LDC,
int strideC,
int batch_count) {
133 auto alpha_cu = cucplx(alpha);
134 auto beta_cu = cucplx(beta);
135 CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, strideA, cucplx(B), LDB,
136 strideB, &beta_cu, cucplx(C), LDC, strideC, batch_count);
139 void axpy(
int N,
double alpha,
const double *x,
int incx,
double *Y,
int incy) { cublasDaxpy(get_handle(), N, &alpha, x, incx, Y, incy); }
141 CUBLAS_CHECK(cublasZaxpy, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy);
144 void copy(
int N,
const double *x,
int incx,
double *Y,
int incy) { cublasDcopy(get_handle(), N, x, incx, Y, incy); }
145 void copy(
int N,
const dcomplex *x,
int incx,
dcomplex *Y,
int incy) { CUBLAS_CHECK(cublasZcopy, N, cucplx(x), incx, cucplx(Y), incy); }
147 double dot(
int M,
const double *x,
int incx,
const double *Y,
int incy) {
149 CUBLAS_CHECK(cublasDdot, M, x, incx, Y, incy, &res);
154 CUBLAS_CHECK(cublasZdotu, M, cucplx(x), incx, cucplx(Y), incy, &res);
155 return {res.x, res.y};
159 CUBLAS_CHECK(cublasZdotc, M, cucplx(x), incx, cucplx(Y), incy, &res);
160 return {res.x, res.y};
163 void gemv(
char op,
int M,
int N,
double alpha,
const double *A,
int LDA,
const double *x,
int incx,
double beta,
double *Y,
int incy) {
164 CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), M, N, &alpha, A, LDA, x, incx, &beta, Y, incy);
166 void gemv(
char op,
int M,
int N,
dcomplex alpha,
const dcomplex *A,
int LDA,
const dcomplex *x,
int incx,
dcomplex beta,
dcomplex *Y,
int incy) {
167 CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), M, N, cucplx(&alpha), cucplx(A), LDA, cucplx(x), incx, cucplx(&beta), cucplx(Y), incy);
170 void ger(
int M,
int N,
double alpha,
const double *x,
int incx,
const double *Y,
int incy,
double *A,
int LDA) {
171 CUBLAS_CHECK(cublasDger, M, N, &alpha, x, incx, Y, incy, A, LDA);
174 CUBLAS_CHECK(cublasZgeru, M, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy, cucplx(A), LDA);
177 void scal(
int M,
double alpha,
double *x,
int incx) { CUBLAS_CHECK(cublasDscal, M, &alpha, x, incx); }
178 void scal(
int M,
dcomplex alpha,
dcomplex *x,
int incx) { CUBLAS_CHECK(cublasZscal, M, cucplx(&alpha), cucplx(x), incx); }
180 void swap(
int N,
double *x,
int incx,
double *Y,
int incy) { CUBLAS_CHECK(cublasDswap, N, x, incx, Y, incy); }
182 CUBLAS_CHECK(cublasZswap, N, cucplx(x), incx, cucplx(Y), incy);
Provides a C++ interface for the GPU versions of various BLAS routines.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
dcomplex: alias for std::complex<double>.