TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
12#include "../tools.hpp"
13#include "../../device.hpp"
14#include "../../exceptions.hpp"
15
16#ifdef NDA_HAVE_MAGMA
17#include "magma_v2.h"
18#endif
19
20namespace nda::blas::device {
21
22 // Local function to get unique CuBlas handle.
23 inline cublasHandle_t &get_handle() {
24 struct handle_storage_t { // RAII for the handle
25 handle_storage_t() { cublasCreate(&handle); }
26 ~handle_storage_t() { cublasDestroy(handle); }
27 cublasHandle_t handle = {};
28 };
29 static auto sto = handle_storage_t{};
30 return sto.handle;
31 }
32
33#ifdef NDA_HAVE_MAGMA
34 // Local function to get Magma op.
35 constexpr magma_trans_t get_magma_op(char op) {
36 switch (op) {
37 case 'N': return MagmaNoTrans; break;
38 case 'T': return MagmaTrans; break;
39 case 'C': return MagmaConjTrans; break;
40 default: std::terminate(); return {};
41 }
42 }
43
  // Local function to get the unique Magma queue (created on the current device on first
  // use, destroyed at program exit by the static object's destructor).
  auto &get_magma_queue() {
    struct queue_t {
      queue_t() {
        int device{};
        magma_getdevice(&device);
        magma_queue_create(device, &q);
      }
      ~queue_t() { magma_queue_destroy(q); }
      // Implicit conversion so the returned object can be passed directly to Magma calls.
      operator magma_queue_t() { return q; }

      private:
      magma_queue_t q = {};
    };
    static queue_t q = {};
    return q;
  }
61#endif
62
  // Global option to turn on/off the cudaDeviceSynchronize after cublas library calls
  // (synchronizing surfaces asynchronous execution errors at the call site, at the cost
  // of serializing host and device).
  static bool synchronize = true; // NOLINT (global option is on purpose)
65
// Macro to check cublas calls: invokes X with the shared handle and the given arguments,
// throws via NDA_RUNTIME_ERROR on a cublas error, and (if `synchronize` is set) performs
// a cudaDeviceSynchronize to surface asynchronous execution errors.
// Wrapped in do { ... } while (0) so the expansion is a single statement and is safe
// inside unbraced if/else bodies (the call-site semicolon completes the statement).
#define CUBLAS_CHECK(X, ...)                                                                                                                         \
  do {                                                                                                                                               \
    auto err = X(get_handle(), __VA_ARGS__);                                                                                                         \
    if (err != CUBLAS_STATUS_SUCCESS) {                                                                                                              \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed \n"                                                                                              \
                        << " cublasGetStatusName: " << cublasGetStatusName(err) << "\n"                                                              \
                        << " cublasGetStatusString: " << cublasGetStatusString(err) << "\n";                                                         \
    }                                                                                                                                                \
    if (synchronize) {                                                                                                                               \
      auto errsync = cudaDeviceSynchronize();                                                                                                        \
      if (errsync != cudaSuccess) {                                                                                                                  \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) << "\n"                                                 \
                          << " cudaGetErrorName: " << cudaGetErrorName(errsync) << "\n"                                                              \
                          << " cudaGetErrorString: " << cudaGetErrorString(errsync) << "\n";                                                         \
      }                                                                                                                                              \
    }                                                                                                                                                \
  } while (0)
84
  // Wrapper of cublasDgemm: C <- alpha * op(A) * op(B) + beta * C (double version).
  // a, b and c are device pointers; errors are reported via CUBLAS_CHECK.
  void gemm(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, const double *b, int ldb, double beta, double *c,
            int ldc) {
    CUBLAS_CHECK(cublasDgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc);
  }
89 void gemm(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *b, int ldb, dcomplex beta,
90 dcomplex *c, int ldc) {
91 auto alpha_cu = cucplx(alpha);
92 auto beta_cu = cucplx(beta);
93 CUBLAS_CHECK(cublasZgemm, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha_cu, cucplx(a), lda, cucplx(b), ldb, &beta_cu, cucplx(c), ldc);
94 }
95
  // Wrapper of cublasDgemmBatched: performs batch_count independent GEMMs given arrays of
  // device pointers (double version).
  void gemm_batch(char op_a, char op_b, int m, int n, int k, double alpha, const double **a, int lda, const double **b, int ldb, double beta,
                  double **c, int ldc, int batch_count) {
    CUBLAS_CHECK(cublasDgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc, batch_count);
  }
100 void gemm_batch(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex **a, int lda, const dcomplex **b, int ldb, dcomplex beta,
101 dcomplex **c, int ldc, int batch_count) {
102 auto alpha_cu = cucplx(alpha);
103 auto beta_cu = cucplx(beta);
104 CUBLAS_CHECK(cublasZgemmBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha_cu, cucplx(a), lda, cucplx(b), ldb, &beta_cu,
105 cucplx(c), ldc, batch_count);
106 }
107
108#ifdef NDA_HAVE_MAGMA
  // Wrapper of magmablas_dgemm_vbatched: batched GEMM where m, n, k, lda, ldb, ldc are
  // per-batch-entry arrays (double version). Runs on the shared Magma queue.
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, double alpha, const double **a, int *lda, const double **b, int *ldb, double beta,
                   double **c, int *ldc, int batch_count) {
    magmablas_dgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, get_magma_queue());
    // NOTE(review): both a queue sync and a device-wide sync are issued; the second looks
    // redundant after magma_queue_sync — confirm whether it is intentional.
    if (synchronize) magma_queue_sync(get_magma_queue());
    if (synchronize) cudaDeviceSynchronize();
  }
  // Wrapper of magmablas_zgemm_vbatched: dcomplex variant of the vbatched GEMM above.
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, dcomplex alpha, const dcomplex **a, int *lda, const dcomplex **b, int *ldb,
                   dcomplex beta, dcomplex **c, int *ldc, int batch_count) {
    auto alpha_cu = cucplx(alpha);
    auto beta_cu = cucplx(beta);
    magmablas_zgemm_vbatched(get_magma_op(op_a), get_magma_op(op_b), m, n, k, alpha_cu, cucplx(a), lda, cucplx(b), ldb, beta_cu, cucplx(c), ldc,
                             batch_count, get_magma_queue());
    if (synchronize) magma_queue_sync(get_magma_queue());
    if (synchronize) cudaDeviceSynchronize();
  }
124#endif
125
  // Wrapper of cublasDgemmStridedBatched: batched GEMM where consecutive matrices are
  // separated by fixed element strides within single device buffers (double version).
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, int stride_a, const double *b, int ldb,
                          int stride_b, double beta, double *c, int ldc, int stride_c, int batch_count) {
    CUBLAS_CHECK(cublasDgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha, a, lda, stride_a, b, ldb, stride_b, &beta, c,
                 ldc, stride_c, batch_count);
  }
131 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, int stride_a, const dcomplex *b,
132 int ldb, int stride_b, dcomplex beta, dcomplex *c, int ldc, int stride_c, int batch_count) {
133 auto alpha_cu = cucplx(alpha);
134 auto beta_cu = cucplx(beta);
135 CUBLAS_CHECK(cublasZgemmStridedBatched, get_cublas_op(op_a), get_cublas_op(op_b), m, n, k, &alpha_cu, cucplx(a), lda, stride_a, cucplx(b), ldb,
136 stride_b, &beta_cu, cucplx(c), ldc, stride_c, batch_count);
137 }
138
139 void axpy(int n, double alpha, const double *x, int incx, double *y, int incy) { cublasDaxpy(get_handle(), n, &alpha, x, incx, y, incy); }
140 void axpy(int n, dcomplex alpha, const dcomplex *x, int incx, dcomplex *y, int incy) {
141 auto alpha_cu = cucplx(alpha);
142 CUBLAS_CHECK(cublasZaxpy, n, &alpha_cu, cucplx(x), incx, cucplx(y), incy);
143 }
144
145 void copy(int n, const double *x, int incx, double *y, int incy) { cublasDcopy(get_handle(), n, x, incx, y, incy); }
  // Wrapper of cublasZcopy: y <- x (dcomplex version).
  void copy(int n, const dcomplex *x, int incx, dcomplex *y, int incy) { CUBLAS_CHECK(cublasZcopy, n, cucplx(x), incx, cucplx(y), incy); }
147
148 double dot(int m, const double *x, int incx, const double *y, int incy) {
149 double res{};
150 CUBLAS_CHECK(cublasDdot, m, x, incx, y, incy, &res);
151 return res;
152 }
153 dcomplex dot(int m, const dcomplex *x, int incx, const dcomplex *y, int incy) {
154 cuDoubleComplex res;
155 CUBLAS_CHECK(cublasZdotu, m, cucplx(x), incx, cucplx(y), incy, &res);
156 return {res.x, res.y};
157 }
158 dcomplex dotc(int m, const dcomplex *x, int incx, const dcomplex *y, int incy) {
159 cuDoubleComplex res;
160 CUBLAS_CHECK(cublasZdotc, m, cucplx(x), incx, cucplx(y), incy, &res);
161 return {res.x, res.y};
162 }
163
  // Wrapper of cublasDgemv: y <- alpha * op(A) * x + beta * y (double version).
  void gemv(char op, int m, int n, double alpha, const double *a, int lda, const double *x, int incx, double beta, double *y, int incy) {
    CUBLAS_CHECK(cublasDgemv, get_cublas_op(op), m, n, &alpha, a, lda, x, incx, &beta, y, incy);
  }
167 void gemv(char op, int m, int n, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *x, int incx, dcomplex beta, dcomplex *y, int incy) {
168 auto alpha_cu = cucplx(alpha);
169 auto beta_cu = cucplx(beta);
170 CUBLAS_CHECK(cublasZgemv, get_cublas_op(op), m, n, &alpha_cu, cucplx(a), lda, cucplx(x), incx, &beta_cu, cucplx(y), incy);
171 }
172
  // Wrapper of cublasDger: A <- alpha * x * y^T + A (real rank-1 update).
  void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda) {
    CUBLAS_CHECK(cublasDger, m, n, &alpha, x, incx, y, incy, a, lda);
  }
176 void ger(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda) {
177 auto alpha_cu = cucplx(alpha);
178 CUBLAS_CHECK(cublasZgeru, m, n, &alpha_cu, cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
179 }
180 void gerc(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda) {
181 auto alpha_cu = cucplx(alpha);
182 CUBLAS_CHECK(cublasZgerc, m, n, &alpha_cu, cucplx(x), incx, cucplx(y), incy, cucplx(a), lda);
183 }
184
  // Wrapper of cublasDscal: x <- alpha * x (double version).
  void scal(int m, double alpha, double *x, int incx) { CUBLAS_CHECK(cublasDscal, m, &alpha, x, incx); }
186 void scal(int m, dcomplex alpha, dcomplex *x, int incx) {
187 auto alpha_cu = cucplx(alpha);
188 CUBLAS_CHECK(cublasZscal, m, &alpha_cu, cucplx(x), incx);
189 }
190
  // Wrapper of cublasDswap: exchanges the contents of vectors x and y (double version).
  void swap(int n, double *x, int incx, double *y, int incy) { CUBLAS_CHECK(cublasDswap, n, x, incx, y, incy); } // NOLINT (this is a BLAS swap)
  // Wrapper of cublasZswap: exchanges the contents of vectors x and y (dcomplex version).
  void swap(int n, dcomplex *x, int incx, dcomplex *y, int incy) { // NOLINT (this is a BLAS swap)
    CUBLAS_CHECK(cublasZswap, n, cucplx(x), incx, cucplx(y), incy);
  }
195
196} // namespace nda::blas::device
Provides a C++ interface for the GPU versions of various BLAS routines.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
std::complex<double> dcomplex
Alias for the std::complex<double> type.
Definition tools.hpp:28
Provides various traits and utilities for the BLAS interface.