32#include <cusolverDn.h> 
   36namespace nda::lapack::device {
 
   // Returns a reference to the cuSOLVER dense handle owned by a function-local static object.
   // handle_storage_t creates the handle (cusolverDnCreate) in its constructor and destroys it
   // (cusolverDnDestroy) in its destructor, so the handle lives from the first call until the
   // static object is destroyed at program exit.
   // NOTE(review): the status returned by cusolverDnCreate is ignored; on failure `handle` keeps
   // its value-initialized state ({}) and later calls would use an invalid handle — consider
   // checking it.
   // NOTE(review): the end of the struct, the return statement (presumably `return sto.handle;`)
   // and the closing brace are not visible in this extract — confirm.
   39  inline cusolverDnHandle_t &get_handle() {

   40    struct handle_storage_t { 

   41      handle_storage_t() { cusolverDnCreate(&handle); }

   42      ~handle_storage_t() { cusolverDnDestroy(handle); }

   43      cusolverDnHandle_t handle = {};

   45    static auto sto = handle_storage_t{}; // constructed once, on the first call
 
   // Body of the helper that CUSOLVER_CHECK calls as get_info_ptr(); its signature line is not
   // visible in this extract. One int is allocated on first use through
   // mem::mallocator<mem::Unified> (presumably CUDA unified/managed memory — confirm). The same
   // address is returned on every call.
   // The pointer is passed as the final argument of each wrapped cuSOLVER routine and is
   // dereferenced afterwards to fill the caller's `info`.
   // NOTE(review): all calls share this single int, so concurrent solver calls from several host
   // threads would race on it — confirm callers are single-threaded.
   51    static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);

   52    return info_u_handle.data();
 
   // Switch that presumably tells CUSOLVER_CHECK whether to call cudaDeviceSynchronize() after
   // each cuSOLVER routine. The line that tests it is not visible here — confirm.
   // Enabled by default; `static` keeps it private to this translation unit.
   static bool synchronize{true};
 
   // CUSOLVER_CHECK(X, info, ...): calls cuSOLVER routine X with the shared handle prepended and
   // the shared info pointer appended, i.e. X(get_handle(), <__VA_ARGS__>, get_info_ptr()).
   //  - If X does not return CUSOLVER_STATUS_SUCCESS, an error naming the routine and the numeric
   //    status is raised through NDA_RUNTIME_ERROR.
   //  - cudaDeviceSynchronize() is then called, and a CUDA failure is reported with
   //    cudaGetErrorName / cudaGetErrorString. This presumably happens only when `synchronize` is
   //    true; the enclosing `if` line and its closing braces are not visible in this extract
   //    (confirm).
   //  - Finally, the caller's `info` is assigned from *get_info_ptr().
   // NOTE(review): the macro declares `err` directly in the caller's scope and is not wrapped in
   // do { ... } while (0), so it can be expanded at most once per scope and cannot be used as a
   // single statement after an unbraced `if`.
   // NOTE(review): if synchronization is skipped, *get_info_ptr() may be read on the host before
   // the device routine has written it — confirm this is acceptable.
   59#define CUSOLVER_CHECK(X, info, ...)                                                                                                                 \ 
   60  auto err = X(get_handle(), __VA_ARGS__, get_info_ptr());                                                                                           \ 
   61  if (err != CUSOLVER_STATUS_SUCCESS) { NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(err); }                    \ 
   63    auto errsync = cudaDeviceSynchronize();                                                                                                          \ 
   64    if (errsync != cudaSuccess) {                                                                                                                    \ 
   65      NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) " \n "                                                    \ 
   66                        << " cudaGetErrorName: " << std::string(cudaGetErrorName(errsync)) << "\n"                                                   \ 
   67                        << " cudaGetErrorString: " << std::string(cudaGetErrorString(errsync)) << "\n";                                              \ 
   70  info = *get_info_ptr(); 
   // Real double-precision SVD wrapper around cusolverDnDgesvd.
   // JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK and RWORK are forwarded unchanged.
   // The shared handle and info pointer are added by CUSOLVER_CHECK, and INFO receives the status
   // value it reads back.
   // NOTE(review): the workspace size is queried with cusolverDnDgesvd_bufferSize, yet the visible
   // call passes the caller's WORK/LWORK. The lines in between, including the declaration of
   // bufferSize, are not visible here — confirm whether LWORK is checked against bufferSize.
   // NOTE(review): the cuSOLVER documentation states gesvd only supports M >= N — confirm callers
   // respect this.
   72  void gesvd(
char JOBU, 
char JOBVT, 
int M, 
int N, 
double *A, 
int LDA, 
double *S, 
double *U, 
int LDU, 
double *VT, 
int LDVT, 
double *WORK, 
int LWORK,

   73             double *RWORK, 
int &INFO) {

   77      cusolverDnDgesvd_bufferSize(get_handle(), M, N, &bufferSize); // required workspace size query

   80      CUSOLVER_CHECK(cusolverDnDgesvd, INFO, JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK);
 
   // Complex double-precision (dcomplex) SVD wrapper around cusolverDnZgesvd.
   // The dcomplex pointers A, U, VT and WORK go through cucplx before being forwarded (presumably
   // a cast to cuSOLVER's cuDoubleComplex — confirm). S and RWORK stay real (double *).
   // INFO receives the status value read back by CUSOLVER_CHECK.
   // NOTE(review): as in the real overload, bufferSize is queried but the visible call passes the
   // caller's WORK/LWORK — confirm whether LWORK is checked against it.
   // NOTE(review): this extract truncates the CUSOLVER_CHECK call after `LWORK,`. The remaining
   // argument(s), presumably RWORK, and the closing brace are not visible.
   83  void gesvd(
char JOBU, 
char JOBVT, 
int M, 
int N, 
dcomplex *A, 
int LDA, 
double *S, 
dcomplex *U, 
int LDU, 
dcomplex *VT, 
int LDVT, 
dcomplex *WORK,

   84             int LWORK, 
double *RWORK, 
int &INFO) {

   88      cusolverDnZgesvd_bufferSize(get_handle(), M, N, &bufferSize); // required workspace size query

   91      CUSOLVER_CHECK(cusolverDnZgesvd, INFO, JOBU, JOBVT, M, N, cucplx(A), LDA, S, cucplx(U), LDU, cucplx(VT), LDVT, cucplx(WORK), LWORK,
 
   // Real double-precision LU-factorization wrapper around cusolverDnDgetrf.
   // The required workspace size is queried with cusolverDnDgetrf_bufferSize. A buffer named
   // Workspace is then passed to the routine together with M, N, A, LDA and ipiv.
   // info receives the status value read back by CUSOLVER_CHECK.
   // NOTE(review): the declarations of bufferSize and Workspace are not visible in this extract;
   // presumably Workspace is a device vector of bufferSize elements allocated on each call.
   // The closing brace is not visible either.
   96  void getrf(
int M, 
int N, 
double *A, 
int LDA, 
int *ipiv, 
int &info) {

   98    cusolverDnDgetrf_bufferSize(get_handle(), M, N, A, LDA, &bufferSize);

  100    CUSOLVER_CHECK(cusolverDnDgetrf, info, M, N, A, LDA, Workspace.data(), ipiv);
 
   // Complex double-precision (dcomplex) LU-factorization wrapper around cusolverDnZgetrf.
   // A and the workspace pointer go through cucplx before being passed (presumably a cast to
   // cuDoubleComplex — confirm). M, N, LDA and ipiv are forwarded unchanged.
   // info receives the status value read back by CUSOLVER_CHECK.
   // NOTE(review): the declarations of bufferSize and Workspace, and the closing brace, are not
   // visible in this extract — confirm Workspace holds at least bufferSize complex elements.
  102  void getrf(
int M, 
int N, 
dcomplex *A, 
int LDA, 
int *ipiv, 
int &info) {

  104    cusolverDnZgetrf_bufferSize(get_handle(), M, N, cucplx(A), LDA, &bufferSize);

  106    CUSOLVER_CHECK(cusolverDnZgetrf, info, M, N, cucplx(A), LDA, cucplx(Workspace.data()), ipiv);
 
   // Real double-precision wrapper around cusolverDnDgetrs.
   // The operation character `op` is translated by get_cublas_op (defined elsewhere). N, NRHS, A,
   // LDA, ipiv, B and LDB are forwarded unchanged, and info receives the status value read back
   // by CUSOLVER_CHECK.
   // NOTE(review): A and ipiv are presumably the outputs of getrf, and B is presumably overwritten
   // in place — confirm against callers. The closing brace is not visible in this extract.
  109  void getrs(
char op, 
int N, 
int NRHS, 
double const *A, 
int LDA, 
int const *ipiv, 
double *B, 
int LDB, 
int &info) {

  110    CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), N, NRHS, A, LDA, ipiv, B, LDB);
 
   // Complex double-precision (dcomplex) wrapper around cusolverDnZgetrs.
   // `op` is translated by get_cublas_op (defined elsewhere). A and B go through cucplx
   // (presumably a cast to cuDoubleComplex — confirm). N, NRHS, LDA, ipiv and LDB are forwarded
   // unchanged, and info receives the status value read back by CUSOLVER_CHECK.
   // NOTE(review): the closing brace is not visible in this extract.
  112  void getrs(
char op, 
int N, 
int NRHS, 
dcomplex const *A, 
int LDA, 
int const *ipiv, 
dcomplex *B, 
int LDB, 
int &info) {

  113    CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), N, NRHS, cucplx(A), LDA, ipiv, cucplx(B), LDB);
 
Provides custom allocators for the nda library.
 
Provides the generic class for arrays.
 
Provides a C++ interface for the GPU versions of various LAPACK routines.
 
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
 
Provides GPU and non-GPU specific functionality.
 
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
 
basic_array< ValueType, 1, C_layout, 'V', heap< mem::Device > > cuvector
Similar to nda::vector except the memory is stored on the device.
 
std::complex< double > dcomplex
Alias for std::complex<double> type.
 
Provides various handles to take care of memory management for nda::basic_array and nda::basic_array_view.
 
Macros used in the nda library.