TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
12#include "../tools.hpp"
13#include "../../device.hpp"
14#include "../../exceptions.hpp"
15
16#ifdef NDA_HAVE_MAGMA
17#include "magma_v2.h"
18#endif
19
20namespace nda::blas::device {
21
22 // Local function to get unique CuBlas handle.
23 inline cublasHandle_t &get_handle() {
24 struct handle_storage_t { // RAII for the handle
25 handle_storage_t() { cublasCreate(&handle); }
26 ~handle_storage_t() { cublasDestroy(handle); }
27 cublasHandle_t handle = {};
28 };
29 static auto sto = handle_storage_t{};
30 return sto.handle;
31 }
32
33#ifdef NDA_HAVE_MAGMA
34 // Local function to get Magma op.
35 constexpr magma_trans_t get_magma_op(char op) {
36 switch (op) {
37 case 'N': return MagmaNoTrans; break;
38 case 'T': return MagmaTrans; break;
39 case 'C': return MagmaConjTrans; break;
40 default: std::terminate(); return {};
41 }
42 }
43
44 // Local function to get Magma queue.
45 auto &get_magma_queue() {
46 struct queue_t {
47 queue_t() {
48 int device{};
49 magma_getdevice(&device);
50 magma_queue_create(device, &q);
51 }
52 ~queue_t() { magma_queue_destroy(q); }
53 operator magma_queue_t() { return q; }
54
55 private:
56 magma_queue_t q = {};
57 };
58 static queue_t q = {};
59 return q;
60 }
61#endif
62
  // Global option to turn on/off the cudaDeviceSynchronize after cublas library calls.
  // When true (the default), every wrapper below blocks until the device has finished,
  // which surfaces asynchronous execution errors at the call site at the cost of
  // serializing host and device work.
  static bool synchronize = true; // NOLINT (global option is on purpose)
65
// Macro to check cublas calls.
// Expands to a block that
//   1. invokes X(get_handle(), __VA_ARGS__) and throws via NDA_RUNTIME_ERROR if the
//      returned status is not CUBLAS_STATUS_SUCCESS, and
//   2. if the global `synchronize` flag is set, calls cudaDeviceSynchronize and throws
//      if that fails, so that asynchronous execution errors surface at the offending call.
#define CUBLAS_CHECK(X, ...)                                                                                                                         \
  {                                                                                                                                                  \
    auto err = X(get_handle(), __VA_ARGS__);                                                                                                         \
    if (err != CUBLAS_STATUS_SUCCESS) {                                                                                                              \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed \n"                                                                                              \
                        << " cublasGetStatusName: " << cublasGetStatusName(err) << "\n"                                                              \
                        << " cublasGetStatusString: " << cublasGetStatusString(err) << "\n";                                                         \
    }                                                                                                                                                \
    if (synchronize) {                                                                                                                               \
      auto errsync = cudaDeviceSynchronize();                                                                                                        \
      if (errsync != cudaSuccess) {                                                                                                                  \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) << "\n"                                                 \
                          << " cudaGetErrorName: " << cudaGetErrorName(errsync) << "\n"                                                              \
                          << " cudaGetErrorString: " << cudaGetErrorString(errsync) << "\n";                                                         \
      }                                                                                                                                              \
    }                                                                                                                                                \
  }
84
85 void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
86 int LDC) {
87 CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC);
88 }
89 void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
90 dcomplex *C, int LDC) {
91 auto alpha_cu = cucplx(alpha);
92 auto beta_cu = cucplx(beta);
93 CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu, cucplx(C), LDC);
94 }
95
96 void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
97 double **C, int LDC, int batch_count) {
98 CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, B, LDB, &beta, C, LDC, batch_count);
99 }
100 void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
101 dcomplex **C, int LDC, int batch_count) {
102 auto alpha_cu = cucplx(alpha);
103 auto beta_cu = cucplx(beta);
104 CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, cucplx(B), LDB, &beta_cu,
105 cucplx(C), LDC, batch_count);
106 }
107
108#ifdef NDA_HAVE_MAGMA
109 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
110 double **C, int *LDC, int batch_count) {
111 magmablas_dgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha, A, LDA, B, LDB, beta, C, LDC, batch_count, get_magma_queue());
112 if (synchronize) magma_queue_sync(get_magma_queue());
113 if (synchronize) cudaDeviceSynchronize();
114 }
115 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
116 dcomplex beta, dcomplex **C, int *LDC, int batch_count) {
117 auto alpha_cu = cucplx(alpha);
118 auto beta_cu = cucplx(beta);
119 magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), M, N, K, alpha_cu, cucplx(A), LDA, cucplx(B), LDB, beta_cu, cucplx(C), LDC,
120 batch_count, get_magma_queue());
121 if (synchronize) magma_queue_sync(get_magma_queue());
122 if (synchronize) cudaDeviceSynchronize();
123 }
124#endif
125
126 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
127 int strideB, double beta, double *C, int LDC, int strideC, int batch_count) {
128 CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha, A, LDA, strideA, B, LDB, strideB, &beta, C,
129 LDC, strideC, batch_count);
130 }
131 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
132 int LDB, int strideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count) {
133 auto alpha_cu = cucplx(alpha);
134 auto beta_cu = cucplx(beta);
135 CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), M, N, K, &alpha_cu, cucplx(A), LDA, strideA, cucplx(B), LDB,
136 strideB, &beta_cu, cucplx(C), LDC, strideC, batch_count);
137 }
138
139 void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy) { cublasDaxpy(get_handle(), N, &alpha, x, incx, Y, incy); }
140 void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy) {
141 CUBLAS_CHECK(cublasZaxpy, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy);
142 }
143
144 void copy(int N, const double *x, int incx, double *Y, int incy) { cublasDcopy(get_handle(), N, x, incx, Y, incy); }
145 void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy) { CUBLAS_CHECK(cublasZcopy, N, cucplx(x), incx, cucplx(Y), incy); }
146
147 double dot(int M, const double *x, int incx, const double *Y, int incy) {
148 double res{};
149 CUBLAS_CHECK(cublasDdot, M, x, incx, Y, incy, &res);
150 return res;
151 }
152 dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
153 cuDoubleComplex res;
154 CUBLAS_CHECK(cublasZdotu, M, cucplx(x), incx, cucplx(Y), incy, &res);
155 return {res.x, res.y};
156 }
157 dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
158 cuDoubleComplex res;
159 CUBLAS_CHECK(cublasZdotc, M, cucplx(x), incx, cucplx(Y), incy, &res);
160 return {res.x, res.y};
161 }
162
163 void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy) {
164 CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), M, N, &alpha, A, LDA, x, incx, &beta, Y, incy);
165 }
166 void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy) {
167 CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), M, N, cucplx(&alpha), cucplx(A), LDA, cucplx(x), incx, cucplx(&beta), cucplx(Y), incy);
168 }
169
170 void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA) {
171 CUBLAS_CHECK(cublasDger, M, N, &alpha, x, incx, Y, incy, A, LDA);
172 }
173 void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA) {
174 CUBLAS_CHECK(cublasZgeru, M, N, cucplx(&alpha), cucplx(x), incx, cucplx(Y), incy, cucplx(A), LDA);
175 }
176
177 void scal(int M, double alpha, double *x, int incx) { CUBLAS_CHECK(cublasDscal, M, &alpha, x, incx); }
178 void scal(int M, dcomplex alpha, dcomplex *x, int incx) { CUBLAS_CHECK(cublasZscal, M, cucplx(&alpha), cucplx(x), incx); }
179
180 void swap(int N, double *x, int incx, double *Y, int incy) { CUBLAS_CHECK(cublasDswap, N, x, incx, Y, incy); } // NOLINT (this is a BLAS swap)
181 void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy) { // NOLINT (this is a BLAS swap)
182 CUBLAS_CHECK(cublasZswap, N, cucplx(x), incx, cucplx(Y), incy);
183 }
184
185} // namespace nda::blas::device
Provides a C++ interface for the GPU versions of various BLAS routines.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:28
Provides various traits and utilities for the BLAS interface.