TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2022-2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Miguel Morales, Nils Wentzell
16
17/**
18 * @file
19 * @brief Implementation details for blas/interface/cublas_interface.hpp.
20 */
21
22#include "./cublas_interface.hpp"
23#include "../tools.hpp"
24#include "../../device.hpp"
25#include "../../exceptions.hpp"
26
27#ifdef NDA_HAVE_MAGMA
28#include "magma_v2.h"
29#endif
30
31namespace nda::blas::device {
32
33 // Local function to get unique CuBlas handle.
34 inline cublasHandle_t &get_handle() {
35 struct handle_storage_t { // RAII for the handle
36 handle_storage_t() { cublasCreate(&handle); }
37 ~handle_storage_t() { cublasDestroy(handle); }
38 cublasHandle_t handle = {};
39 };
40 static auto sto = handle_storage_t{};
41 return sto.handle;
42 }
43
44#ifdef NDA_HAVE_MAGMA
45 // Local function to get Magma op.
46 constexpr magma_trans_t get_magma_op(char op) {
47 switch (op) {
48 case 'N': return MagmaNoTrans; break;
49 case 'T': return MagmaTrans; break;
50 case 'C': return MagmaConjTrans; break;
51 default: std::terminate(); return {};
52 }
53 }
54
55 // Local function to get Magma queue.
56 auto &get_magma_queue() {
57 struct queue_t {
58 queue_t() {
59 int device{};
60 magma_getdevice(&device);
61 magma_queue_create(device, &q);
62 }
63 ~queue_t() { magma_queue_destroy(q); }
64 operator magma_queue_t() { return q; }
65
66 private:
67 magma_queue_t q = {};
68 };
69 static queue_t q = {};
70 return q;
71 }
72#endif
73
// Global option to turn on/off the cudaDeviceSynchronize after cublas library calls.
static bool synchronize = true; // NOLINT (global option is on purpose)

// Macro to check cublas calls.
//
// Expands to a call X(get_handle(), args...). If the returned status is not CUBLAS_STATUS_SUCCESS, the failure is reported via
// NDA_RUNTIME_ERROR together with the cuBLAS status name and string. If the global `synchronize` option is set, cudaDeviceSynchronize()
// is called afterwards and a failure of that call is reported in the same way.
//
// Every line of the definition except the last ends in a line continuation so that the whole body belongs to the macro, and the body is
// wrapped in do { ... } while (0) so that `CUBLAS_CHECK(...);` is a single statement (safe inside an unbraced if/else).
#define CUBLAS_CHECK(X, ...) \
  do { \
    auto err = X(get_handle(), __VA_ARGS__); \
    if (err != CUBLAS_STATUS_SUCCESS) { \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed \n" \
                        << " cublasGetStatusName: " << cublasGetStatusName(err) << "\n" \
                        << " cublasGetStatusString: " << cublasGetStatusString(err) << "\n"; \
    } \
    if (synchronize) { \
      auto errsync = cudaDeviceSynchronize(); \
      if (errsync != cudaSuccess) { \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) << "\n" \
                          << " cudaGetErrorName: " << cudaGetErrorName(errsync) << "\n" \
                          << " cudaGetErrorString: " << cudaGetErrorString(errsync) << "\n"; \
      } \
    } \
  } while (0)
95
96 void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
97 int LDC) {
98 CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC);
99 }
100 void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
101 dcomplex *C, int LDC) {
102 auto alpha_cu = cucplx(alpha);
103 auto beta_cu = cucplx(beta);
104 CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu, cucplx(C), LDC);
105 }
106
107 void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
108 double **C, int LDC, int batch_count) {
109 CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC, batch_count);
110 }
111 void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
112 dcomplex **C, int LDC, int batch_count) {
113 auto alpha_cu = cucplx(alpha);
114 auto beta_cu = cucplx(beta);
115 CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu,
116 cucplx(C), LDC, batch_count);
117 }
118
119#ifdef NDA_HAVE_MAGMA
120 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
121 double **C, int *LDC, int batch_count) {
122 magmablas_dgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha, A, LDA, B, LDB, beta, C, LDC, batch_count, get_magma_queue());
123 if (synchronize) magma_queue_sync(get_magma_queue());
124 if (synchronize) cudaDeviceSynchronize();
125 }
126 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
127 dcomplex beta, dcomplex **C, int *LDC, int batch_count) {
128 auto alpha_cu = cucplx(alpha);
129 auto beta_cu = cucplx(beta);
130 magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha_cu, cucplx(A), LDA, cucplx(B), LDB, beta_cu, cucplx(C), LDC,
131 batch_count, get_magma_queue());
132 if (synchronize) magma_queue_sync(get_magma_queue());
133 if (synchronize) cudaDeviceSynchronize();
134 }
135#endif
136
137 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
138 int strideB, double beta, double *C, int LDC, int strideC, int batch_count) {
139 CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, strideA, B, LDB, strideB, &beta, C,
140 LDC, strideC, batch_count);
141 }
142 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
143 int LDB, int strideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count) {
144 auto alpha_cu = cucplx(alpha);
145 auto beta_cu = cucplx(beta);
146 CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, strideA, cucplx(B), LDB,
147 strideB, &beta_cu, cucplx(C), LDC, strideC, batch_count);
148 }
149
150 void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy) { cublasDaxpy(get_handle(), N, &alpha, x, incx, Y, incy); }
151 void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy) {
152 CUBLAS_CHECK(cublasZaxpy, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy);
153 }
154
155 void copy(int N, const double *x, int incx, double *Y, int incy) { cublasDcopy(get_handle(), N, x, incx, Y, incy); }
156 void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy) { CUBLAS_CHECK(cublasZcopy, N, cucplx(x), incx, cucplx(Y), incy); }
157
158 double dot(int M, const double *x, int incx, const double *Y, int incy) {
159 double res{};
160 CUBLAS_CHECK(cublasDdot, M, x, incx, Y, incy, &res);
161 return res;
162 }
163 dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
164 cuDoubleComplex res;
165 CUBLAS_CHECK(cublasZdotu, M, cucplx(x), incx, cucplx(Y), incy, &res);
166 return {res.x, res.y};
167 }
168 dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
169 cuDoubleComplex res;
170 CUBLAS_CHECK(cublasZdotc, M, cucplx(x), incx, cucplx(Y), incy, &res);
171 return {res.x, res.y};
172 }
173
174 void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy) {
175 CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), M, N, &alpha, A, LDA, x, incx, &beta, Y, incy);
176 }
177 void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy) {
178 CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), M, N, cucplx(&alpha), cucplx(A), LDA, cucplx(x), incx, cucplx(&beta), cucplx(Y), incy);
179 }
180
181 void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA) {
182 CUBLAS_CHECK(cublasDger, M, N, &alpha, x, incx, Y, incy, A, LDA);
183 }
184 void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA) {
185 CUBLAS_CHECK(cublasZgeru, M, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy, cucplx(A), LDA);
186 }
187
188 void scal(int M, double alpha, double *x, int incx) { CUBLAS_CHECK(cublasDscal, M, &alpha, x, incx); }
189 void scal(int M, dcomplex alpha, dcomplex *x, int incx) { CUBLAS_CHECK(cublasZscal, M, cucplx(&alpha), cucplx(x), incx); }
190
191 void swap(int N, double *x, int incx, double *Y, int incy) { CUBLAS_CHECK(cublasDswap, N, x, incx, Y, incy); } // NOLINT (this is a BLAS swap)
192 void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy) { // NOLINT (this is a BLAS swap)
193 CUBLAS_CHECK(cublasZswap, N, cucplx(x), incx, cucplx(Y), incy);
194 }
195
196} // namespace nda::blas::device
#define CUBLAS_CHECK(X,...)
#define NDA_RUNTIME_ERROR
#define AS_STRING(...)
Definition macros.hpp:31