18
19
20
22#include "./cublas_interface.hpp"
23#include "../tools.hpp"
24#include "../../device.hpp"
25#include "../../exceptions.hpp"
31namespace nda::blas::device {
34 inline cublasHandle_t &get_handle() {
35 struct handle_storage_t {
36 handle_storage_t() { cublasCreate(&handle); }
37 ~handle_storage_t() { cublasDestroy(handle); }
38 cublasHandle_t handle = {};
40 static auto sto = handle_storage_t{};
46 constexpr magma_trans_t get_magma_op(
char op) {
48 case 'N':
return MagmaNoTrans;
break;
49 case 'T':
return MagmaTrans;
break;
50 case 'C':
return MagmaConjTrans;
break;
51 default: std::terminate();
return {};
56 auto &get_magma_queue() {
60 magma_getdevice(&device);
61 magma_queue_create(device, &q);
63 ~queue_t() { magma_queue_destroy(q); }
64 operator magma_queue_t() {
return q; }
69 static queue_t q = {};
75 static bool synchronize =
true;
// Invoke the cuBLAS function X on the shared handle, forwarding the remaining
// arguments. Throws via NDA_RUNTIME_ERROR if the call reports an error, and — if
// `synchronize` is set — synchronizes the device afterwards so that asynchronous
// execution errors are caught here rather than at some later, unrelated call.
#define CUBLAS_CHECK(X, ...)                                                                                                                         \
  {                                                                                                                                                  \
    auto err = X(get_handle(), __VA_ARGS__);                                                                                                         \
    if (err != CUBLAS_STATUS_SUCCESS) {                                                                                                              \
      NDA_RUNTIME_ERROR << #X " failed\n"                                                                                                            \
                        << " cublasGetStatusName: " << cublasGetStatusName(err) << "\n"                                                              \
                        << " cublasGetStatusString: " << cublasGetStatusString(err) << "\n";                                                         \
    }                                                                                                                                                \
    if (synchronize) {                                                                                                                               \
      auto errsync = cudaDeviceSynchronize();                                                                                                        \
      if (errsync != cudaSuccess) {                                                                                                                  \
        NDA_RUNTIME_ERROR << "cudaDeviceSynchronize failed after call to " #X "\n"                                                                   \
                          << " cudaGetErrorName: " << cudaGetErrorName(errsync) << "\n"                                                              \
                          << " cudaGetErrorString: " << cudaGetErrorString(errsync) << "\n";                                                         \
      }                                                                                                                                              \
    }                                                                                                                                                \
  }
96 void gemm(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double *A,
int LDA,
const double *B,
int LDB,
double beta,
double *C,
98 CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC);
100 void gemm(
char op_a,
char op_b,
int M,
int N,
int K, dcomplex alpha,
const dcomplex *A,
int LDA,
const dcomplex *B,
int LDB, dcomplex beta,
101 dcomplex *C,
int LDC) {
102 auto alpha_cu = cucplx(alpha);
103 auto beta_cu = cucplx(beta);
104 CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu, cucplx(C), LDC);
107 void gemm_batch(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double **A,
int LDA,
const double **B,
int LDB,
double beta,
108 double **C,
int LDC,
int batch_count) {
109 CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC, batch_count);
111 void gemm_batch(
char op_a,
char op_b,
int M,
int N,
int K, dcomplex alpha,
const dcomplex **A,
int LDA,
const dcomplex **B,
int LDB, dcomplex beta,
112 dcomplex **C,
int LDC,
int batch_count) {
113 auto alpha_cu = cucplx(alpha);
114 auto beta_cu = cucplx(beta);
115 CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu,
116 cucplx(C), LDC, batch_count);
120 void gemm_vbatch(
char op_a,
char op_b,
int *M,
int *N,
int *K,
double alpha,
const double **A,
int *LDA,
const double **B,
int *LDB,
double beta,
121 double **C,
int *LDC,
int batch_count) {
122 magmablas_dgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha, A, LDA, B, LDB, beta, C, LDC, batch_count, get_magma_queue());
123 if (synchronize) magma_queue_sync(get_magma_queue());
124 if (synchronize) cudaDeviceSynchronize();
126 void gemm_vbatch(
char op_a,
char op_b,
int *M,
int *N,
int *K, dcomplex alpha,
const dcomplex **A,
int *LDA,
const dcomplex **B,
int *LDB,
127 dcomplex beta, dcomplex **C,
int *LDC,
int batch_count) {
128 auto alpha_cu = cucplx(alpha);
129 auto beta_cu = cucplx(beta);
130 magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha_cu, cucplx(A), LDA, cucplx(B), LDB, beta_cu, cucplx(C), LDC,
131 batch_count, get_magma_queue());
132 if (synchronize) magma_queue_sync(get_magma_queue());
133 if (synchronize) cudaDeviceSynchronize();
137 void gemm_batch_strided(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double *A,
int LDA,
int strideA,
const double *B,
int LDB,
138 int strideB,
double beta,
double *C,
int LDC,
int strideC,
int batch_count) {
139 CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, strideA, B, LDB, strideB, &beta, C,
140 LDC, strideC, batch_count);
142 void gemm_batch_strided(
char op_a,
char op_b,
int M,
int N,
int K, dcomplex alpha,
const dcomplex *A,
int LDA,
int strideA,
const dcomplex *B,
143 int LDB,
int strideB, dcomplex beta, dcomplex *C,
int LDC,
int strideC,
int batch_count) {
144 auto alpha_cu = cucplx(alpha);
145 auto beta_cu = cucplx(beta);
146 CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, strideA, cucplx(B), LDB,
147 strideB, &beta_cu, cucplx(C), LDC, strideC, batch_count);
150 void axpy(
int N,
double alpha,
const double *x,
int incx,
double *Y,
int incy) { cublasDaxpy(get_handle(), N, &alpha, x, incx, Y, incy); }
151 void axpy(
int N, dcomplex alpha,
const dcomplex *x,
int incx, dcomplex *Y,
int incy) {
152 CUBLAS_CHECK(cublasZaxpy, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy);
155 void copy(
int N,
const double *x,
int incx,
double *Y,
int incy) { cublasDcopy(get_handle(), N, x, incx, Y, incy); }
156 void copy(
int N,
const dcomplex *x,
int incx, dcomplex *Y,
int incy) {
CUBLAS_CHECK(cublasZcopy, N, cucplx(x), incx, cucplx(Y), incy); }
158 double dot(
int M,
const double *x,
int incx,
const double *Y,
int incy) {
163 dcomplex dot(
int M,
const dcomplex *x,
int incx,
const dcomplex *Y,
int incy) {
165 CUBLAS_CHECK(cublasZdotu, M, cucplx(x), incx, cucplx(Y), incy, &res);
166 return {res.x, res.y};
168 dcomplex dotc(
int M,
const dcomplex *x,
int incx,
const dcomplex *Y,
int incy) {
170 CUBLAS_CHECK(cublasZdotc, M, cucplx(x), incx, cucplx(Y), incy, &res);
171 return {res.x, res.y};
174 void gemv(
char op,
int M,
int N,
double alpha,
const double *A,
int LDA,
const double *x,
int incx,
double beta,
double *Y,
int incy) {
175 CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), M, N, &alpha, A, LDA, x, incx, &beta, Y, incy);
177 void gemv(
char op,
int M,
int N, dcomplex alpha,
const dcomplex *A,
int LDA,
const dcomplex *x,
int incx, dcomplex beta, dcomplex *Y,
int incy) {
178 CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), M, N, cucplx(&alpha), cucplx(A), LDA, cucplx(x), incx, cucplx(&beta), cucplx(Y), incy);
181 void ger(
int M,
int N,
double alpha,
const double *x,
int incx,
const double *Y,
int incy,
double *A,
int LDA) {
182 CUBLAS_CHECK(cublasDger, M, N, &alpha, x, incx, Y, incy, A, LDA);
184 void ger(
int M,
int N, dcomplex alpha,
const dcomplex *x,
int incx,
const dcomplex *Y,
int incy, dcomplex *A,
int LDA) {
185 CUBLAS_CHECK(cublasZgeru, M, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy, cucplx(A), LDA);
188 void scal(
int M,
double alpha,
double *x,
int incx) {
CUBLAS_CHECK(cublasDscal, M, &alpha, x, incx); }
189 void scal(
int M, dcomplex alpha, dcomplex *x,
int incx) {
CUBLAS_CHECK(cublasZscal, M, cucplx(&alpha), cucplx(x), incx); }
191 void swap(
int N,
double *x,
int incx,
double *Y,
int incy) {
CUBLAS_CHECK(cublasDswap, N, x, incx, Y, incy); }
192 void swap(
int N, dcomplex *x,
int incx, dcomplex *Y,
int incy) {
193 CUBLAS_CHECK(cublasZswap, N, cucplx(x), incx, cucplx(Y), incy);
#define CUBLAS_CHECK(X,...)
#define NDA_RUNTIME_ERROR