21#include <cusolverDn.h>
25namespace nda::lapack::device {
// Returns the cuSOLVER dense handle shared by every routine in this file.
//
// The handle is owned by a function-local static `handle_storage_t`. Its constructor calls
// cusolverDnCreate on first use, and its destructor calls cusolverDnDestroy when static objects
// are destroyed at program exit.
// NOTE(review): the lines that close `handle_storage_t` and return the handle are not part of this
// excerpt; presumably the function returns `sto.handle` — confirm.
// NOTE(review): the cusolverStatus_t returned by cusolverDnCreate is discarded. A failed creation
// therefore leaves `handle` value-initialized, and the failure only surfaces on a later cuSOLVER call.
28 inline cusolverDnHandle_t &get_handle() {
29 struct handle_storage_t {
// RAII holder: the handle is created in the constructor and destroyed in the destructor.
30 handle_storage_t() { cusolverDnCreate(&handle); }
31 ~handle_storage_t() { cusolverDnDestroy(handle); }
32 cusolverDnHandle_t handle = {};
34 static auto sto = handle_storage_t{};
// Body of the helper that provides the `info` output pointer. CUSOLVER_CHECK passes this pointer
// to every cuSOLVER call and reads it back afterwards. The signature is not part of this excerpt
// (presumably `inline int *get_info_ptr()` — confirm).
// On first use, one int is allocated in a function-local static mem::handle_heap backed by
// mem::mallocator<mem::Unified>. This is presumably CUDA unified memory, so that both device and
// host can access it. The same address is returned on every call.
// NOTE(review): every call shares this single int, so concurrent solver calls from several host
// threads would presumably race on it.
40 static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);
41 return info_u_handle.data();
// Mutable global switch, enabled by default. Judging by its name, it presumably toggles whether a
// cudaDeviceSynchronize() follows each cuSOLVER call. Nothing in this excerpt reads it — confirm
// against CUSOLVER_CHECK in the full source.
// Note the internal linkage (`static`): each translation unit holds its own copy.
static bool synchronize{true}; // NOLINT (global option is mutable on purpose)
// CUSOLVER_CHECK(X, info, ...): calls the cuSOLVER routine X and checks the outcome.
//
// The macro expands to statements, not an expression. The call is
// `X(get_handle(), __VA_ARGS__, get_info_ptr())`: the shared handle is prepended and the shared
// info pointer is appended to the given arguments.
// - If X returns a status other than CUSOLVER_STATUS_SUCCESS, the failure is reported through
//   NDA_RUNTIME_ERROR with the routine name and error code (presumably this throws).
// - cudaDeviceSynchronize() is then called. A CUDA error from it is reported with the texts of
//   cudaGetErrorName and cudaGetErrorString.
// - Finally, the value behind get_info_ptr() is copied into `info`.
// NOTE(review): some lines of this macro are not part of this excerpt. At least the braces that
// close the cudaDeviceSynchronize check are missing, and presumably an `if (synchronize)` guard
// around it. Check the full source before editing, and never place `//` comments inside the
// backslash-continued lines.
// NOTE(review): the expansion declares `auto err` in the caller's scope, so the macro can be used
// at most once per block scope.
// NOTE(review): cuSOLVER routines may run asynchronously with respect to the host. If the
// synchronization step is skipped, reading *get_info_ptr() right afterwards may observe a value
// the device has not written yet.
48#define CUSOLVER_CHECK(X, info, ...) \
49 auto err = X(get_handle(), __VA_ARGS__, get_info_ptr()); \
50 if (err != CUSOLVER_STATUS_SUCCESS) { NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(err); } \
52 auto errsync = cudaDeviceSynchronize(); \
53 if (errsync != cudaSuccess) { \
54 NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) " \n " \
55 << " cudaGetErrorName: " << std::string(cudaGetErrorName(errsync)) << "\n" \
56 << " cudaGetErrorString: " << std::string(cudaGetErrorString(errsync)) << "\n"; \
59 info = *get_info_ptr();
// Computes the SVD of the m x n double matrix `a` on the device using cusolverDnDgesvd.
//
// The arguments mirror the LAPACK dgesvd parameters (jobu, jobvt, m, n, a, lda, s, u, ldu, vt,
// ldvt, work, lwork). `rwork` is forwarded as cuSOLVER's additional real workspace argument.
// `info` receives the info value written by cuSOLVER, which CUSOLVER_CHECK reads back.
// NOTE(review): some lines of this function are not part of this excerpt, among them the
// declaration of `bufferSize` and the closing brace.
// NOTE(review): the cuSOLVER documentation states that gesvd only supports m >= n; no such check
// is visible here.
61 void gesvd(
char jobu,
char jobvt,
int m,
int n,
double *a,
int lda,
double *s,
double *u,
int ldu,
double *vt,
int ldvt,
double *work,
int lwork,
62 double *rwork,
int &info) {
66 cusolverDnDgesvd_bufferSize(get_handle(), m, n, &bufferSize);
// NOTE(review): the workspace size just queried into `bufferSize` is not passed below; the
// caller's `work`/`lwork` is used instead. Confirm that lwork >= bufferSize is guaranteed,
// possibly in lines not shown here.
69 CUSOLVER_CHECK(cusolverDnDgesvd, info, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork);
// Computes the SVD of the m x n double-complex matrix `a` on the device using cusolverDnZgesvd.
//
// The arguments mirror the LAPACK zgesvd parameters. Complex pointers are passed through `cucplx`,
// which presumably reinterprets dcomplex* as cuDoubleComplex* — confirm. `rwork` is a real
// (double) workspace. `info` receives the info value written by cuSOLVER, which CUSOLVER_CHECK
// reads back.
// NOTE(review): some lines of this function are not part of this excerpt: the `bufferSize`
// declaration, the rest of the CUSOLVER_CHECK argument list, and the closing brace.
// NOTE(review): the cuSOLVER documentation states that gesvd only supports m >= n; no such check
// is visible here.
72 void gesvd(
char jobu,
char jobvt,
int m,
int n,
dcomplex *a,
int lda,
double *s,
dcomplex *u,
int ldu,
dcomplex *vt,
int ldvt,
dcomplex *work,
73 int lwork,
double *rwork,
int &info) {
77 cusolverDnZgesvd_bufferSize(get_handle(), m, n, &bufferSize);
// NOTE(review): as in the double version, the queried `bufferSize` is not what is passed to the
// solver below (the caller's `work`/`lwork` is used). Confirm lwork >= bufferSize.
80 CUSOLVER_CHECK(cusolverDnZgesvd, info, jobu, jobvt, m, n, cucplx(a), lda, s, cucplx(u), ldu, cucplx(vt), ldvt, cucplx(work), lwork,
// LU-factorizes the m x n double matrix `a` in place on the device using cusolverDnDgetrf.
//
// The required workspace size is queried with cusolverDnDgetrf_bufferSize, and `Workspace.data()`
// is passed as the workspace. `Workspace` is declared on a line not part of this excerpt; it is
// presumably a device vector of `bufferSize` elements — confirm. Pivot indices are written to
// `ipiv`, and `info` receives the info value written by cuSOLVER (see CUSOLVER_CHECK).
// NOTE(review): the closing brace of this function is not part of this excerpt.
85 void getrf(
int m,
int n,
double *a,
int lda,
int *ipiv,
int &info) {
87 cusolverDnDgetrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
89 CUSOLVER_CHECK(cusolverDnDgetrf, info, m, n, a, lda, Workspace.data(), ipiv);
// LU-factorizes the m x n double-complex matrix `a` in place on the device using cusolverDnZgetrf.
//
// The required workspace size is queried with cusolverDnZgetrf_bufferSize. `Workspace.data()` is
// passed through `cucplx`, which presumably converts it to cuDoubleComplex* — confirm. `Workspace`
// is declared on a line not part of this excerpt. Pivot indices are written to `ipiv`, and `info`
// receives the info value written by cuSOLVER (see CUSOLVER_CHECK).
// NOTE(review): the closing brace of this function is not part of this excerpt.
91 void getrf(
int m,
int n,
dcomplex *a,
int lda,
int *ipiv,
int &info) {
93 cusolverDnZgetrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
95 CUSOLVER_CHECK(cusolverDnZgetrf, info, m, n, cucplx(a), lda, cucplx(Workspace.data()), ipiv);
// Solves a double linear system on the device with cusolverDnDgetrs, overwriting `b`.
//
// Uses an LU factorization in `a` and pivots in `ipiv` (presumably as produced by getrf). The
// character `op` is mapped with get_cublas_op, presumably 'N'/'T'/'C' to the matching
// cublasOperation_t — confirm. `b` holds `nrhs` right-hand sides with leading dimension `ldb`.
// `info` receives the info value written by cuSOLVER (see CUSOLVER_CHECK).
// NOTE(review): the closing brace of this function is not part of this excerpt.
98 void getrs(
char op,
int n,
int nrhs,
double const *a,
int lda,
int const *ipiv,
double *b,
int ldb,
int &info) {
99 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), n, nrhs, a, lda, ipiv, b, ldb);
// Solves a double-complex linear system on the device with cusolverDnZgetrs, overwriting `b`.
//
// Uses an LU factorization in `a` and pivots in `ipiv` (presumably as produced by getrf). Complex
// pointers are passed through `cucplx`, presumably a conversion to cuDoubleComplex* — confirm.
// The character `op` is mapped with get_cublas_op. `info` receives the info value written by
// cuSOLVER (see CUSOLVER_CHECK).
// NOTE(review): the closing brace of this function is not part of this excerpt.
101 void getrs(
char op,
int n,
int nrhs,
dcomplex const *a,
int lda,
int const *ipiv,
dcomplex *b,
int ldb,
int &info) {
102 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), n, nrhs, cucplx(a), lda, ipiv, cucplx(b), ldb);
Provides custom allocators for the nda library.
Provides the generic class for arrays.
Provides a C++ interface for the GPU versions of various LAPACK routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
basic_array< ValueType, 1, C_layout, 'V', heap< mem::Device > > cuvector
Similar to nda::vector except the memory is stored on the device.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Provides various handles to take care of memory management for nda::basic_array and nda::basic_array_view.
Macros used in the nda library.