namespace nda::blas::device {
  // Get a unique cuBLAS handle for the library (RAII-managed, created on first use).
  inline cublasHandle_t &get_handle() {
    struct handle_storage_t { // RAII wrapper: create the handle on construction, destroy it on exit
      handle_storage_t() { cublasCreate(&handle); }
      ~handle_storage_t() { cublasDestroy(handle); }
      cublasHandle_t handle = {};
    };
    static auto sto = handle_storage_t{};
    return sto.handle;
  }
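  // Usage sketch (illustrative, not part of the library): since the handle is a lazily
  // created singleton, callers can configure it once, e.g. attach a CUDA stream
  // (the stream itself is assumed to be created elsewhere):
  //
  //   cudaStream_t stream; // assumed created via cudaStreamCreate(&stream)
  //   cublasSetStream(get_handle(), stream);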
  // Map a BLAS transpose character ('N', 'T' or 'C') to the corresponding MAGMA operation.
  constexpr magma_trans_t get_magma_op(char op) {
    switch (op) {
      case 'N': return MagmaNoTrans;
      case 'T': return MagmaTrans;
      case 'C': return MagmaConjTrans;
      default: std::terminate(); return {};
    }
  }
  // Get the MAGMA queue for the current device (RAII-managed, created on first use).
  auto &get_magma_queue() {
    struct queue_t { // RAII wrapper: create the queue on construction, destroy it on exit
      queue_t() {
        int device{};
        magma_getdevice(&device);
        magma_queue_create(device, &q);
      }
      ~queue_t() { magma_queue_destroy(q); }
      operator magma_queue_t() { return q; }

      private:
      magma_queue_t q = {};
    };
    static queue_t q = {};
    return q;
  }
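  // Usage sketch: the implicit conversion to magma_queue_t lets the singleton be passed
  // directly to MAGMA routines, e.g.
  //
  //   magma_queue_sync(get_magma_queue());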
  // Global option to turn the cudaDeviceSynchronize after each cuBLAS call on or off.
  static bool synchronize = true;
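  // Note: `synchronize` has internal linkage, so it can only be toggled within this
  // translation unit. A sketch of skipping the per-call synchronization for a burst of
  // launches, with a single sync at the end:
  //
  //   synchronize = false;
  //   // ... several gemm/gemv calls ...
  //   synchronize = true;
  //   cudaDeviceSynchronize();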
// Expand to a cuBLAS call (with the handle injected as first argument), check its status
// and optionally synchronize the device afterwards.
#define CUBLAS_CHECK(X, ...)                                                                          \
  {                                                                                                   \
    auto err = X(get_handle(), __VA_ARGS__);                                                          \
    if (err != CUBLAS_STATUS_SUCCESS) {                                                               \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed\n"                                                \
                        << " cublasGetStatusName: " << cublasGetStatusName(err) << "\n"               \
                        << " cublasGetStatusString: " << cublasGetStatusString(err) << "\n";          \
    }                                                                                                 \
    if (synchronize) {                                                                                \
      auto errsync = cudaDeviceSynchronize();                                                         \
      if (errsync != cudaSuccess) {                                                                   \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) << "\n"  \
                          << " cudaGetErrorName: " << cudaGetErrorName(errsync) << "\n"               \
                          << " cudaGetErrorString: " << cudaGetErrorString(errsync) << "\n";          \
      }                                                                                               \
    }                                                                                                 \
  }
  void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
            int LDC) {
    CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC);
  }
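  // Usage sketch (column-major device arrays, assumed allocated with cudaMalloc):
  //
  //   // C (m x n) = 1.0 * A (m x k) * B (k x n) + 0.0 * C
  //   gemm('N', 'N', m, n, k, 1.0, d_a, m, d_b, k, 0.0, d_c, m);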
  void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
            dcomplex *C, int LDC) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu, cucplx(C), LDC);
  }
  void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
                  double **C, int LDC, int batch_count) {
    CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC, batch_count);
  }
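  // Note (sketch, names assumed): A, B and C are *device* arrays of device pointers, one
  // entry per matrix in the batch, e.g. built on the host and copied over:
  //
  //   const double *h_a[] = {d_a0, d_a1};                             // host array of device pointers
  //   cudaMemcpy(d_a_ptrs, h_a, sizeof(h_a), cudaMemcpyHostToDevice); // d_a_ptrs: device allocation
  //   gemm_batch('N', 'N', m, n, k, 1.0, d_a_ptrs, m, d_b_ptrs, k, 0.0, d_c_ptrs, m, 2);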
  void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
                  dcomplex **C, int LDC, int batch_count) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu,
                 cucplx(C), LDC, batch_count);
  }
  void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
                   double **C, int *LDC, int batch_count) {
    magmablas_dgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha, A, LDA, B, LDB, beta, C, LDC, batch_count, get_magma_queue());
    if (synchronize) magma_queue_sync(get_magma_queue());
    if (synchronize) cudaDeviceSynchronize();
  }
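  // In contrast to gemm_batch, the vbatched variant takes *arrays* of dimensions and
  // leading dimensions, so every matrix in the batch may have a different shape. The
  // MAGMA routine runs asynchronously on the queue, hence the explicit syncs above.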
  void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
                   dcomplex beta, dcomplex **C, int *LDC, int batch_count) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha_cu, cucplx(A), LDA, cucplx(B), LDB, beta_cu, cucplx(C), LDC,
                             batch_count, get_magma_queue());
    if (synchronize) magma_queue_sync(get_magma_queue());
    if (synchronize) cudaDeviceSynchronize();
  }
  void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
                          int strideB, double beta, double *C, int LDC, int strideC, int batch_count) {
    CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, strideA, B, LDB, strideB, &beta, C,
                 LDC, strideC, batch_count);
  }
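  // The strided variant avoids the pointer arrays of gemm_batch: the i-th matrices live at
  // A + i * strideA, and so on. A sketch for contiguous batches of column-major matrices:
  //
  //   // batch of nb products C_i (m x n) = A_i (m x k) * B_i (k x n)
  //   gemm_batch_strided('N', 'N', m, n, k, 1.0, d_a, m, m * k, d_b, k, k * n, 0.0, d_c, m, m * n, nb);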
  void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
                          int LDB, int strideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu  = cucplx(beta);
    CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, strideA, cucplx(B), LDB,
                 strideB, &beta_cu, cucplx(C), LDC, strideC, batch_count);
  }
  void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy) { CUBLAS_CHECK(cublasDaxpy, N, &alpha, x, incx, Y, incy); }
  void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy) {
    CUBLAS_CHECK(cublasZaxpy, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy);
  }
  void copy(int N, const double *x, int incx, double *Y, int incy) { CUBLAS_CHECK(cublasDcopy, N, x, incx, Y, incy); }
  void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy) { CUBLAS_CHECK(cublasZcopy, N, cucplx(x), incx, cucplx(Y), incy); }
  double dot(int M, const double *x, int incx, const double *Y, int incy) {
    double res{};
    CUBLAS_CHECK(cublasDdot, M, x, incx, Y, incy, &res);
    return res;
  }
  dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
    cuDoubleComplex res;
    CUBLAS_CHECK(cublasZdotu, M, cucplx(x), incx, cucplx(Y), incy, &res);
    return {res.x, res.y};
  }

  dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
    cuDoubleComplex res;
    CUBLAS_CHECK(cublasZdotc, M, cucplx(x), incx, cucplx(Y), incy, &res);
    return {res.x, res.y};
  }
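  // Note: for complex vectors, dot wraps cublasZdotu (no conjugation) while dotc wraps
  // cublasZdotc (x conjugated). A sketch with device vectors d_x, d_y of length n:
  //
  //   auto d1 = dot(n, d_x, 1, d_y, 1);  // sum_i x_i * y_i
  //   auto d2 = dotc(n, d_x, 1, d_y, 1); // sum_i conj(x_i) * y_i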
  void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy) {
    CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), M, N, &alpha, A, LDA, x, incx, &beta, Y, incy);
  }
  void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy) {
    CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), M, N, cucplx(&alpha), cucplx(A), LDA, cucplx(x), incx, cucplx(&beta), cucplx(Y), incy);
  }
  void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA) {
    CUBLAS_CHECK(cublasDger, M, N, &alpha, x, incx, Y, incy, A, LDA);
  }
  void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA) {
    CUBLAS_CHECK(cublasZgeru, M, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy, cucplx(A), LDA);
  }
  void scal(int M, double alpha, double *x, int incx) { CUBLAS_CHECK(cublasDscal, M, &alpha, x, incx); }
  void scal(int M, dcomplex alpha, dcomplex *x, int incx) { CUBLAS_CHECK(cublasZscal, M, cucplx(&alpha), cucplx(x), incx); }
  void swap(int N, double *x, int incx, double *Y, int incy) { CUBLAS_CHECK(cublasDswap, N, x, incx, Y, incy); }
  void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy) { CUBLAS_CHECK(cublasZswap, N, cucplx(x), incx, cucplx(Y), incy); }

} // namespace nda::blas::device