18
19
20
22#include "./cusolver_interface.hpp"
23#include "../../basic_array.hpp"
24#include "../../blas/tools.hpp"
25#include "../../declarations.hpp"
26#include "../../device.hpp"
27#include "../../exceptions.hpp"
28#include "../../macros.hpp"
29#include "../../mem/allocators.hpp"
30#include "../../mem/handle.hpp"
32#include <cusolverDn.h>
36namespace nda::lapack::device {
39 inline cusolverDnHandle_t &get_handle() {
40 struct handle_storage_t {
41 handle_storage_t() { cusolverDnCreate(&handle); }
42 ~handle_storage_t() { cusolverDnDestroy(handle); }
43 cusolverDnHandle_t handle = {};
45 static auto sto = handle_storage_t{};
51 static auto info_u_handle = mem::handle_heap<
int, mem::mallocator<mem::Unified>>(1);
52 return info_u_handle.data();
// Global option (internal linkage, so only this translation unit can flip it) that controls whether a
// cudaDeviceSynchronize() follows each cuSOLVER call made through CUSOLVER_CHECK.
// NOTE(review): confirm that CUSOLVER_CHECK actually reads this flag.
static bool synchronize{true}; // NOLINT (mutable global option on purpose)
// Call the cuSOLVER dense routine X as X(get_handle(), <args...>, get_info_ptr()) and check the outcome:
//  - throws (via NDA_RUNTIME_ERROR) if X does not return CUSOLVER_STATUS_SUCCESS;
//  - if `synchronize` is true, calls cudaDeviceSynchronize() and throws if it reports an error;
//  - finally stores the int behind get_info_ptr() into `info`.
//
// The expansion is a single `do { ... } while (0)` statement. It can therefore appear several times in one scope and
// as the body of an unbraced if/else, and its locals cannot clash with the caller's names.
// get_handle(), get_info_ptr() and `synchronize` must be visible at the point of expansion.
//
// NOTE(review): the info value is written by the library call; presumably it is only reliable once the device work
// has completed, i.e. when `synchronize` is true.
#define CUSOLVER_CHECK(X, info, ...)                                                                                                  \
  do {                                                                                                                                \
    auto const nda_cusolver_status_ = X(get_handle(), __VA_ARGS__, get_info_ptr());                                                   \
    if (nda_cusolver_status_ != CUSOLVER_STATUS_SUCCESS) {                                                                            \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(nda_cusolver_status_);                        \
    }                                                                                                                                 \
    if (synchronize) {                                                                                                                \
      auto const nda_cusolver_sync_status_ = cudaDeviceSynchronize();                                                                 \
      if (nda_cusolver_sync_status_ != cudaSuccess) {                                                                                 \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) << "\n"                                  \
                          << " cudaGetErrorName: " << std::string(cudaGetErrorName(nda_cusolver_sync_status_)) << "\n"                \
                          << " cudaGetErrorString: " << std::string(cudaGetErrorString(nda_cusolver_sync_status_)) << "\n";           \
      }                                                                                                                               \
    }                                                                                                                                 \
    info = *get_info_ptr();                                                                                                           \
  } while (0)
72 void gesvd(
char JOBU,
char JOBVT,
int M,
int N,
double *A,
int LDA,
double *S,
double *U,
int LDU,
double *VT,
int LDVT,
double *WORK,
int LWORK,
73 double *RWORK,
int &INFO) {
77 cusolverDnDgesvd_bufferSize(get_handle(), M, N, &bufferSize);
80 CUSOLVER_CHECK(cusolverDnDgesvd, INFO, JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK);
83 void gesvd(
char JOBU,
char JOBVT,
int M,
int N, dcomplex *A,
int LDA,
double *S, dcomplex *U,
int LDU, dcomplex *VT,
int LDVT, dcomplex *WORK,
84 int LWORK,
double *RWORK,
int &INFO) {
88 cusolverDnZgesvd_bufferSize(get_handle(), M, N, &bufferSize);
91 CUSOLVER_CHECK(cusolverDnZgesvd, INFO, JOBU, JOBVT, M, N, cucplx(A), LDA, S, cucplx(U), LDU, cucplx(VT), LDVT, cucplx(WORK), LWORK,
96 void getrf(
int M,
int N,
double *A,
int LDA,
int *ipiv,
int &info) {
98 cusolverDnDgetrf_bufferSize(get_handle(), M, N, A, LDA, &bufferSize);
99 auto Workspace = nda::cuvector<
double>(bufferSize);
100 CUSOLVER_CHECK(cusolverDnDgetrf, info, M, N, A, LDA, Workspace.data(), ipiv);
102 void getrf(
int M,
int N, dcomplex *A,
int LDA,
int *ipiv,
int &info) {
104 cusolverDnZgetrf_bufferSize(get_handle(), M, N, cucplx(A), LDA, &bufferSize);
105 auto Workspace = nda::cuvector<dcomplex>(bufferSize);
106 CUSOLVER_CHECK(cusolverDnZgetrf, info, M, N, cucplx(A), LDA, cucplx(Workspace.data()), ipiv);
109 void getrs(
char op,
int N,
int NRHS,
double const *A,
int LDA,
int const *ipiv,
double *B,
int LDB,
int &info) {
110 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), N, NRHS, A, LDA, ipiv, B, LDB);
112 void getrs(
char op,
int N,
int NRHS, dcomplex
const *A,
int LDA,
int const *ipiv, dcomplex *B,
int LDB,
int &info) {
113 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), N, NRHS, cucplx(A), LDA, ipiv, cucplx(B), LDB);
// CUSOLVER_CHECK is a local helper for the cuSOLVER wrappers in this file; remove it so it does not leak further.
// (The former empty re-definitions of CUSOLVER_CHECK and NDA_RUNTIME_ERROR here would have clashed with the real
// definitions and turned any later use into a silent no-op.)
#undef CUSOLVER_CHECK