#include <cusolverDn.h>

#include <string>
25namespace nda::lapack::device {
28 inline cusolverDnHandle_t &get_handle() {
29 struct handle_storage_t {
30 handle_storage_t() { cusolverDnCreate(&handle); }
31 ~handle_storage_t() { cusolverDnDestroy(handle); }
32 cusolverDnHandle_t handle = {};
34 static auto sto = handle_storage_t{};
40 static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);
41 return info_u_handle.data();
// Global switch meant to enable (true) or disable (false) the cudaDeviceSynchronize() that CUSOLVER_CHECK performs
// after each cuSOLVER call. Synchronizing makes asynchronous device errors show up at the call that caused them,
// at the cost of serializing host and device. Enabled by default.
static bool synchronize = true;
// Calls the cuSOLVER routine X as X(get_handle(), <remaining args>, get_info_ptr()) and copies the status that the
// routine wrote through the info pointer into `info`.
//
// Throws via NDA_RUNTIME_ERROR if X does not return CUSOLVER_STATUS_SUCCESS. When `synchronize` is true it also
// calls cudaDeviceSynchronize() and throws if that reports an error, naming X in the message.
//
// The body is wrapped in do { ... } while (0) so that `CUSOLVER_CHECK(...);` behaves as one statement and its
// temporaries (err, errsync) neither leak into nor collide with names in the calling scope.
//
// NOTE(review): when `synchronize` is false, `info` is read from unified memory without waiting for the device;
// cuSOLVER calls are asynchronous, so the value may not be written yet — confirm this is acceptable.
#define CUSOLVER_CHECK(X, info, ...) \
  do { \
    auto err = X(get_handle(), __VA_ARGS__, get_info_ptr()); \
    if (err != CUSOLVER_STATUS_SUCCESS) { NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(err); } \
    if (synchronize) { \
      auto errsync = cudaDeviceSynchronize(); \
      if (errsync != cudaSuccess) { \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) " \n " \
                          << " cudaGetErrorName: " << std::string(cudaGetErrorName(errsync)) << "\n" \
                          << " cudaGetErrorString: " << std::string(cudaGetErrorString(errsync)) << "\n"; \
      } \
    } \
    info = *get_info_ptr(); \
  } while (0)
61 void gesvd(
char JOBU,
char JOBVT,
int M,
int N,
double *A,
int LDA,
double *S,
double *U,
int LDU,
double *VT,
int LDVT,
double *WORK,
int LWORK,
62 double *RWORK,
int &INFO) {
66 cusolverDnDgesvd_bufferSize(get_handle(), M, N, &bufferSize);
69 CUSOLVER_CHECK(cusolverDnDgesvd, INFO, JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK);
72 void gesvd(
char JOBU,
char JOBVT,
int M,
int N,
dcomplex *A,
int LDA,
double *S,
dcomplex *U,
int LDU,
dcomplex *VT,
int LDVT,
dcomplex *WORK,
73 int LWORK,
double *RWORK,
int &INFO) {
77 cusolverDnZgesvd_bufferSize(get_handle(), M, N, &bufferSize);
80 CUSOLVER_CHECK(cusolverDnZgesvd, INFO, JOBU, JOBVT, M, N, cucplx(A), LDA, S, cucplx(U), LDU, cucplx(VT), LDVT, cucplx(WORK), LWORK,
85 void getrf(
int M,
int N,
double *A,
int LDA,
int *ipiv,
int &info) {
87 cusolverDnDgetrf_bufferSize(get_handle(), M, N, A, LDA, &bufferSize);
89 CUSOLVER_CHECK(cusolverDnDgetrf, info, M, N, A, LDA, Workspace.data(), ipiv);
91 void getrf(
int M,
int N,
dcomplex *A,
int LDA,
int *ipiv,
int &info) {
93 cusolverDnZgetrf_bufferSize(get_handle(), M, N, cucplx(A), LDA, &bufferSize);
95 CUSOLVER_CHECK(cusolverDnZgetrf, info, M, N, cucplx(A), LDA, cucplx(Workspace.data()), ipiv);
98 void getrs(
char op,
int N,
int NRHS,
double const *A,
int LDA,
int const *ipiv,
double *B,
int LDB,
int &info) {
99 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), N, NRHS, A, LDA, ipiv, B, LDB);
101 void getrs(
char op,
int N,
int NRHS,
dcomplex const *A,
int LDA,
int const *ipiv,
dcomplex *B,
int LDB,
int &info) {
102 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), N, NRHS, cucplx(A), LDA, ipiv, cucplx(B), LDB);
Provides custom allocators for the nda library.
Provides the generic class for arrays.
Provides a C++ interface for the GPU versions of various LAPACK routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
basic_array< ValueType, 1, C_layout, 'V', heap< mem::Device > > cuvector
Similar to nda::vector except the memory is stored on the device.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Provides various handles to take care of memory management for nda::basic_array and nda::basic_array_view.
Macros used in the nda library.