32#include <cusolverDn.h>
36namespace nda::lapack::device {
// Returns the cuSOLVER dense handle that every wrapper in this file passes to cuSOLVER
// (CUSOLVER_CHECK and the *_bufferSize queries all call get_handle()).
// The handle is owned by handle_storage_t, held in the function-local static `sto`: it is created
// with cusolverDnCreate on the first call and released with cusolverDnDestroy when `sto` is destroyed
// at program exit.
// NOTE(review): this copy is a source-listing extraction (leading numbers are listing line numbers);
// listing lines 44 and 46-47 — presumably the end of handle_storage_t, `return sto.handle;` and the
// closing brace — are absent here. Confirm against the upstream file.
// NOTE(review): the cusolverStatus_t returned by cusolverDnCreate is discarded; if creation fails,
// `handle` stays value-initialized and the failure presumably only shows up as an error status from a
// later solver call.
39 inline cusolverDnHandle_t &get_handle() {
40 struct handle_storage_t {
41 handle_storage_t() { cusolverDnCreate(&handle); }
42 ~handle_storage_t() { cusolverDnDestroy(handle); }
43 cusolverDnHandle_t handle = {};
45 static auto sto = handle_storage_t{};
// NOTE(review): the enclosing signature (listing line 50) and closing brace (listing line 53) are
// missing from this copy; these lines are presumably the body of get_info_ptr(), which CUSOLVER_CHECK calls.
// Allocates, once, a single int through mem::mallocator<mem::Unified> (a function-local static
// handle_heap) and returns a pointer to it. CUSOLVER_CHECK passes this pointer as the last argument of
// every cuSOLVER call and afterwards reads the reported info value back from it on the host.
// NOTE(review): the same int is shared by all calls, so solver calls issued concurrently from
// different threads would overwrite each other's info value.
51 static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);
52 return info_u_handle.data();
// Flag that presumably controls whether CUSOLVER_CHECK runs cudaDeviceSynchronize() after each
// cuSOLVER call (defaults to true). The line that tests it is not visible in this copy — confirm.
// Declared `static`; assuming this is at namespace scope, it has internal linkage, so only code in
// this translation unit can change it.
56 static bool synchronize =
true;
// CUSOLVER_CHECK(X, info, ...) runs these steps in order:
//   1. Calls the cuSOLVER routine X as X(get_handle(), __VA_ARGS__, get_info_ptr()).
//   2. Reports an error via NDA_RUNTIME_ERROR (including the stringized routine name and the status
//      code) if the status is not CUSOLVER_STATUS_SUCCESS.
//   3. Calls cudaDeviceSynchronize() and reports its error name/string if it fails.
//   4. Copies the info value written by cuSOLVER into `info`.
// NOTE(review): the expansion declares `err` and `errsync` directly in the caller's scope and is not
// wrapped in `do { ... } while (0)`. It can therefore be expanded at most once per scope, and it must
// not be used as the sole statement of an unbraced if/else.
// NOTE(review): listing lines 62 and 68-69 are missing from this copy. If the synchronize step is
// conditional (e.g. on `synchronize`), reading *get_info_ptr() when it is skipped may observe a value
// the device has not written yet — confirm against the upstream file.
59#define CUSOLVER_CHECK(X, info, ...) \
60 auto err = X(get_handle(), __VA_ARGS__, get_info_ptr()); \
61 if (err != CUSOLVER_STATUS_SUCCESS) { NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(err); } \
63 auto errsync = cudaDeviceSynchronize(); \
64 if (errsync != cudaSuccess) { \
65 NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) " \n " \
66 << " cudaGetErrorName: " << std::string(cudaGetErrorName(errsync)) << "\n" \
67 << " cudaGetErrorString: " << std::string(cudaGetErrorString(errsync)) << "\n"; \
70 info = *get_info_ptr();
// Singular value decomposition of the real M x N matrix A via cusolverDnDgesvd.
// The parameters mirror the LAPACK dgesvd argument list:
//   - JOBU/JOBVT select which singular vectors are computed.
//   - A has leading dimension LDA; S receives the singular values.
//   - U/VT (with leading dimensions LDU/LDVT) receive the singular vectors.
//   - WORK/LWORK is the workspace, plus cuSOLVER's extra RWORK buffer.
// The info value reported by cuSOLVER is stored in INFO.
// NOTE(review): listing lines 74-76, 78-79 and 81-82 are absent from this copy. The declaration of
// `bufferSize` and how the queried size relates to WORK/LWORK are therefore not visible — confirm upstream.
72 void gesvd(
char JOBU,
char JOBVT,
int M,
int N,
double *A,
int LDA,
double *S,
double *U,
int LDU,
double *VT,
int LDVT,
double *WORK,
int LWORK,
73 double *RWORK,
int &INFO) {
// Ask cuSOLVER for the workspace size an M x N dgesvd needs.
// NOTE(review): the cusolverStatus_t returned by this query is ignored.
77 cusolverDnDgesvd_bufferSize(get_handle(), M, N, &bufferSize);
// Run the SVD with the caller-provided WORK/LWORK buffer; status, sync and info are handled by CUSOLVER_CHECK.
80 CUSOLVER_CHECK(cusolverDnDgesvd, INFO, JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK);
// Singular value decomposition of the complex (dcomplex) M x N matrix A via cusolverDnZgesvd.
// The parameters have the same roles as in the double overload. S and RWORK are real (double);
// A, U, VT and WORK are complex.
// NOTE(review): listing lines 85-87, 89-90 and 92 onward are absent from this copy. This includes the
// declaration of `bufferSize` and the end of the CUSOLVER_CHECK argument list below — confirm upstream.
83 void gesvd(
char JOBU,
char JOBVT,
int M,
int N,
dcomplex *A,
int LDA,
double *S,
dcomplex *U,
int LDU,
dcomplex *VT,
int LDVT,
dcomplex *WORK,
84 int LWORK,
double *RWORK,
int &INFO) {
// Ask cuSOLVER for the workspace size an M x N zgesvd needs.
// NOTE(review): the cusolverStatus_t returned by this query is ignored.
88 cusolverDnZgesvd_bufferSize(get_handle(), M, N, &bufferSize);
// Complex pointers go through cucplx (defined elsewhere), presumably to reinterpret dcomplex* as the
// cuSOLVER complex type.
91 CUSOLVER_CHECK(cusolverDnZgesvd, INFO, JOBU, JOBVT, M, N, cucplx(A), LDA, S, cucplx(U), LDU, cucplx(VT), LDVT, cucplx(WORK), LWORK,
// LU factorization with partial pivoting of the real M x N matrix A (leading dimension LDA) via
// cusolverDnDgetrf. A is overwritten with the factors, pivot indices are written to ipiv, and the info
// value reported by cuSOLVER is stored in info.
// NOTE(review): listing lines 97 and 99 are missing from this copy, as is the closing brace (listing 101).
// Lines 97 and 99 presumably declare `bufferSize` and allocate `Workspace` with at least bufferSize
// elements — confirm upstream.
96 void getrf(
int M,
int N,
double *A,
int LDA,
int *ipiv,
int &info) {
// Ask cuSOLVER for the workspace size this factorization needs.
// NOTE(review): the cusolverStatus_t returned by this query is ignored.
98 cusolverDnDgetrf_bufferSize(get_handle(), M, N, A, LDA, &bufferSize);
100 CUSOLVER_CHECK(cusolverDnDgetrf, info, M, N, A, LDA, Workspace.data(), ipiv);
// LU factorization with partial pivoting of the complex (dcomplex) M x N matrix A via cusolverDnZgetrf.
// Complex pointers are passed through cucplx (defined elsewhere). Pivots are written to ipiv and the
// reported info value to info.
// NOTE(review): listing lines 103 and 105 (presumably the `bufferSize` declaration and the `Workspace`
// allocation) and the closing brace (listing 107) are missing from this copy — confirm upstream.
102 void getrf(
int M,
int N,
dcomplex *A,
int LDA,
int *ipiv,
int &info) {
// Ask cuSOLVER for the workspace size this factorization needs.
// NOTE(review): the cusolverStatus_t returned by this query is ignored.
104 cusolverDnZgetrf_bufferSize(get_handle(), M, N, cucplx(A), LDA, &bufferSize);
106 CUSOLVER_CHECK(cusolverDnZgetrf, info, M, N, cucplx(A), LDA, cucplx(Workspace.data()), ipiv);
// Solves op(A) X = B for NRHS right-hand sides via cusolverDnDgetrs.
// It uses the LU factors in the N x N matrix A (leading dimension LDA) and the pivots ipiv, as produced
// by getrf. B (leading dimension LDB) is overwritten with the solution.
// The character `op` is converted to a cuBLAS operation by get_cublas_op (defined elsewhere).
// The info value reported by cuSOLVER is stored in info.
// NOTE(review): the closing brace (listing line 111) is absent from this copy.
109 void getrs(
char op,
int N,
int NRHS,
double const *A,
int LDA,
int const *ipiv,
double *B,
int LDB,
int &info) {
110 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), N, NRHS, A, LDA, ipiv, B, LDB);
// Complex (dcomplex) counterpart of getrs: solves op(A) X = B via cusolverDnZgetrs using the LU factors
// in A and the pivots ipiv. B is overwritten with the solution.
// Complex pointers are passed through cucplx (defined elsewhere).
// NOTE(review): the closing brace (listing line 114) is absent from this copy.
112 void getrs(
char op,
int N,
int NRHS,
dcomplex const *A,
int LDA,
int const *ipiv,
dcomplex *B,
int LDB,
int &info) {
113 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), N, NRHS, cucplx(A), LDA, ipiv, cucplx(B), LDB);
Provides custom allocators for the nda library.
Provides the generic class for arrays.
Provides a C++ interface for the GPU versions of various LAPACK routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
basic_array< ValueType, 1, C_layout, 'V', heap< mem::Device > > cuvector
Similar to nda::vector except the memory is stored on the device.
std::complex< double > dcomplex
Alias for std::complex<double> type.
int gesvd(A &&a, S &&s, U &&u, VT &&vt)
Interface to the LAPACK gesvd routine.
int getrs(A const &a, B &&b, IPIV const &ipiv)
Interface to the LAPACK getrs routine.
int getrf(A &&a, IPIV &&ipiv)
Interface to the LAPACK getrf routine.
Provides various handles that take care of memory management for nda::basic_array and nda::basic_array_view.
Macros used in the nda library.