19#include "../../mem/fill.hpp"
24#include <cusolverDn.h>
29namespace nda::lapack::device {
// Accessor for the cuSOLVER dense handle used by the wrappers in this file.
// The handle lives in a function-local static object, so it is created on first use.
// NOTE(review): the closing of the struct and the return statement are not visible in this
// listing; presumably the function returns the `handle` member of `sto` — confirm.
32 inline cusolverDnHandle_t &get_handle() {
// Storage type whose constructor calls cusolverDnCreate and whose destructor calls cusolverDnDestroy.
33 struct handle_storage_t {
// The status returned by cusolverDnCreate is discarded here.
34 handle_storage_t() { cusolverDnCreate(&handle); }
// The status returned by cusolverDnDestroy is discarded here.
35 ~handle_storage_t() { cusolverDnDestroy(handle); }
36 cusolverDnHandle_t handle = {};
// Single static instance of the storage type.
38 static auto sto = handle_storage_t{};
// NOTE(review): the header of the enclosing function is not visible in this listing; presumably
// this is the body of get_info_ptr(), which CUSOLVER_CHECK calls — confirm.
// Function-local static heap handle for int, built with mem::mallocator<mem::Unified> and
// constructed with the argument 1 (presumably the element count — confirm).
44 static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);
// Return whatever pointer the handle's data() yields.
45 return info_u_handle.data();
// Thread-local switch, initially true. CUSOLVER_CHECK passes it as the first argument of
// cuda_device_sync after every cuSOLVER call.
thread_local bool synchronize{true};

// Overwrite the calling thread's synchronization switch with do_sync.
void set_synchronization(bool do_sync) noexcept {
  synchronize = do_sync;
}

// Read back the calling thread's synchronization switch.
bool get_synchronization() noexcept {
  bool const current = synchronize;
  return current;
}
// Call the cuSOLVER routine X as X(get_handle(), <forwarded arguments>, get_info_ptr()).
//
//  - If the returned status differs from CUSOLVER_STATUS_SUCCESS, the routine name and the numeric
//    status are streamed into NDA_RUNTIME_ERROR.
//  - cuda_device_sync(synchronize, "<X>") is invoked afterwards with the thread-local flag.
//  - Finally the int behind get_info_ptr() is copied into `info`.
//
// The expansion is a single do { ... } while (0) statement, so the macro can be used several times
// in one scope and as the body of an unbraced if/else; call sites must end it with a semicolon.
// The status variable has a macro-specific name so it cannot shadow an argument named `err`.
//
// NOTE(review): `info` is read right after cuda_device_sync even when `synchronize` is false —
// confirm that the info value is valid at that point in the non-synchronizing mode.
#define CUSOLVER_CHECK(X, info, ...)                                                                                   \
  do {                                                                                                                 \
    auto cusolver_check_status_ = X(get_handle(), __VA_ARGS__, get_info_ptr());                                        \
    if (cusolver_check_status_ != CUSOLVER_STATUS_SUCCESS) {                                                           \
      NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(cusolver_check_status_);       \
    }                                                                                                                  \
    cuda_device_sync(synchronize, AS_STRING(X));                                                                       \
    info = *get_info_ptr();                                                                                            \
  } while (0)
// Workspace-size query for gesvd: selects the cusolverDn{S,D,C,Z}gesvd_bufferSize routine that
// matches T at compile time and lets it write the size into bufferSize.
// The status returned by each *_bufferSize call is discarded.
// NOTE(review): the template header, the declaration of bufferSize, the end of the last branch
// and the return statement are not visible in this listing.
65 int gesvd_buffer_size_impl(
int m,
int n) {
// T == float
67 if constexpr (std::is_same_v<T, float>) {
68 cusolverDnSgesvd_bufferSize(get_handle(), m, n, &bufferSize);
69 }
// T == double
else if constexpr (std::is_same_v<T, double>) {
70 cusolverDnDgesvd_bufferSize(get_handle(), m, n, &bufferSize);
71 }
// T == std::complex<float>
else if constexpr (std::is_same_v<T, std::complex<float>>) {
72 cusolverDnCgesvd_bufferSize(get_handle(), m, n, &bufferSize);
73 }
// T == std::complex<double>
else if constexpr (std::is_same_v<T, std::complex<double>>) {
74 cusolverDnZgesvd_bufferSize(get_handle(), m, n, &bufferSize);
// Workspace-size query for getrf: selects the cusolverDn{S,D,C,Z}getrf_bufferSize routine that
// matches T at compile time and lets it write the size into bufferSize.
// Complex pointers are passed through cucplx(...) before being handed to cuSOLVER.
// The status returned by each *_bufferSize call is discarded.
// NOTE(review): the template header, the declaration of bufferSize, the end of the last branch
// and the return statement are not visible in this listing.
81 int getrf_buffer_size_impl(
int m,
int n, T *a,
int lda) {
// T == float
83 if constexpr (std::is_same_v<T, float>) {
84 cusolverDnSgetrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
85 }
// T == double
else if constexpr (std::is_same_v<T, double>) {
86 cusolverDnDgetrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
87 }
// T == std::complex<float>
else if constexpr (std::is_same_v<T, std::complex<float>>) {
88 cusolverDnCgetrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
89 }
// T == std::complex<double>
else if constexpr (std::is_same_v<T, std::complex<double>>) {
90 cusolverDnZgetrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
// Workspace-size query for geqrf: selects the cusolverDn{S,D,C,Z}geqrf_bufferSize routine that
// matches T at compile time and lets it write the size into bufferSize.
// Complex pointers are passed through cucplx(...) before being handed to cuSOLVER.
// The status returned by each *_bufferSize call is discarded.
// NOTE(review): the template header, the declaration of bufferSize, the end of the last branch
// and the return statement are not visible in this listing.
97 int geqrf_buffer_size_impl(
int m,
int n, T *a,
int lda) {
// T == float
99 if constexpr (std::is_same_v<T, float>) {
100 cusolverDnSgeqrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
101 }
// T == double
else if constexpr (std::is_same_v<T, double>) {
102 cusolverDnDgeqrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
103 }
// T == std::complex<float>
else if constexpr (std::is_same_v<T, std::complex<float>>) {
104 cusolverDnCgeqrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
105 }
// T == std::complex<double>
else if constexpr (std::is_same_v<T, std::complex<double>>) {
106 cusolverDnZgeqrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
// Shared workspace-size query behind orgqr_buffer_size and ungqr_buffer_size: real T dispatches to
// cusolverDn{S,D}orgqr_bufferSize, complex T to cusolverDn{C,Z}ungqr_bufferSize.
// Complex pointers (a and tau) are passed through cucplx(...) before being handed to cuSOLVER.
// The status returned by each *_bufferSize call is discarded.
// NOTE(review): the declaration of bufferSize, the end of the last branch and the return statement
// are not visible in this listing.
112 template <
typename T>
113 int xxgqr_buffer_size_impl(
int m,
int n,
int k, T
const *a,
int lda, T
const *tau) {
// T == float
115 if constexpr (std::is_same_v<T, float>) {
116 cusolverDnSorgqr_bufferSize(get_handle(), m, n, k, a, lda, tau, &bufferSize);
117 }
// T == double
else if constexpr (std::is_same_v<T, double>) {
118 cusolverDnDorgqr_bufferSize(get_handle(), m, n, k, a, lda, tau, &bufferSize);
119 }
// T == std::complex<float>
else if constexpr (std::is_same_v<T, std::complex<float>>) {
120 cusolverDnCungqr_bufferSize(get_handle(), m, n, k, cucplx(a), lda, cucplx(tau), &bufferSize);
121 }
// T == std::complex<double>
else if constexpr (std::is_same_v<T, std::complex<double>>) {
122 cusolverDnZungqr_bufferSize(get_handle(), m, n, k, cucplx(a), lda, cucplx(tau), &bufferSize);
// Shared implementation behind the getri overloads. The visible part fills a buffer through
// mem::fill2D_n and then calls getrs with op 'N', n right-hand sides and leading dimension n.
// NOTE(review): several lines of this definition (the rest of the lambda, the declaration of
// `tmp`, the else branch and the closing braces) are not visible in this listing.
129 template <
typename T>
130 void getri_impl(
int n, T *a,
int lda,
int const *ipiv, T *work,
int lwork,
int &info) {
// Helper taking the buffer that serves as right-hand side for getrs; captures everything by reference.
131 auto solve_and_copy_back = [&](T *b_ptr) {
// Writes T(1) with arguments (n + 1, 1, n) — presumably pitch n + 1, width 1, height n, i.e. the
// diagonal of an n x n column-major buffer; TODO confirm against mem::fill2D_n.
134 mem::fill2D_n<mem::Device>(b_ptr,
static_cast<size_t>(n) + 1, 1, n, T(1));
// Solve against the factorization in a / ipiv; the result lands in b_ptr and the status in info.
135 getrs(
'N', n, n, a, lda, ipiv, b_ptr, n, info);
// Use the caller's workspace when it holds at least n * n elements.
// NOTE(review): n * n is evaluated in int and can overflow for large n.
138 if (lwork >= n * n) {
139 solve_and_copy_back(work);
// Otherwise fall back to a separate buffer `tmp` (its declaration is not visible here).
142 solve_and_copy_back(tmp.data());
149 int gesvd_buffer_size(
int m,
int n,
float *) {
return gesvd_buffer_size_impl<float>(m, n); }
150 int gesvd_buffer_size(
int m,
int n, std::complex<float> *) {
return gesvd_buffer_size_impl<std::complex<float>>(m, n); }
151 int gesvd_buffer_size(
int m,
int n,
double *) {
return gesvd_buffer_size_impl<double>(m, n); }
152 int gesvd_buffer_size(
int m,
int n, std::complex<double> *) {
return gesvd_buffer_size_impl<std::complex<double>>(m, n); }
// gesvd overloads: each one forwards its arguments unchanged (complex pointers via cucplx) to the
// matching cusolverDn{S,C,D,Z}gesvd routine through CUSOLVER_CHECK, which stores the value read
// from the shared info pointer into `info`.
// NOTE(review): the closing brace of each overload is not visible in this listing.

// float variant -> cusolverDnSgesvd.
155 void gesvd(
char jobu,
char jobvt,
int m,
int n,
float *a,
int lda,
float *s,
float *u,
int ldu,
float *vt,
int ldvt,
float *work,
int lwork,
156 float *rwork,
int &info) {
157 CUSOLVER_CHECK(cusolverDnSgesvd, info, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork);
// std::complex<float> variant -> cusolverDnCgesvd; s and rwork stay real (float).
159 void gesvd(
char jobu,
char jobvt,
int m,
int n, std::complex<float> *a,
int lda,
float *s, std::complex<float> *u,
int ldu, std::complex<float> *vt,
160 int ldvt, std::complex<float> *work,
int lwork,
float *rwork,
int &info) {
161 CUSOLVER_CHECK(cusolverDnCgesvd, info, jobu, jobvt, m, n, cucplx(a), lda, s, cucplx(u), ldu, cucplx(vt), ldvt, cucplx(work), lwork, rwork);
// double variant -> cusolverDnDgesvd.
163 void gesvd(
char jobu,
char jobvt,
int m,
int n,
double *a,
int lda,
double *s,
double *u,
int ldu,
double *vt,
int ldvt,
double *work,
int lwork,
164 double *rwork,
int &info) {
165 CUSOLVER_CHECK(cusolverDnDgesvd, info, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork);
// std::complex<double> variant -> cusolverDnZgesvd; s and rwork stay real (double).
167 void gesvd(
char jobu,
char jobvt,
int m,
int n, std::complex<double> *a,
int lda,
double *s, std::complex<double> *u,
int ldu,
168 std::complex<double> *vt,
int ldvt, std::complex<double> *work,
int lwork,
double *rwork,
int &info) {
169 CUSOLVER_CHECK(cusolverDnZgesvd, info, jobu, jobvt, m, n, cucplx(a), lda, s, cucplx(u), ldu, cucplx(vt), ldvt, cucplx(work), lwork, rwork);
173 int getrf_buffer_size(
int m,
int n,
float *a,
int lda) {
return getrf_buffer_size_impl<float>(m, n, a, lda); }
174 int getrf_buffer_size(
int m,
int n, std::complex<float> *a,
int lda) {
return getrf_buffer_size_impl<std::complex<float>>(m, n, a, lda); }
175 int getrf_buffer_size(
int m,
int n,
double *a,
int lda) {
return getrf_buffer_size_impl<double>(m, n, a, lda); }
176 int getrf_buffer_size(
int m,
int n, std::complex<double> *a,
int lda) {
return getrf_buffer_size_impl<std::complex<double>>(m, n, a, lda); }
179 void getrf(
int m,
int n,
float *a,
int lda,
float *work,
int *ipiv,
int &info) { CUSOLVER_CHECK(cusolverDnSgetrf, info, m, n, a, lda, work, ipiv); }
// Remaining getrf overloads: each forwards to the matching cusolverDn{C,D,Z}getrf routine through
// CUSOLVER_CHECK (complex pointers via cucplx); `info` receives the value read from the shared
// info pointer.
// NOTE(review): the closing brace of each overload is not visible in this listing.

// std::complex<float> variant -> cusolverDnCgetrf.
180 void getrf(
int m,
int n, std::complex<float> *a,
int lda, std::complex<float> *work,
int *ipiv,
int &info) {
181 CUSOLVER_CHECK(cusolverDnCgetrf, info, m, n, cucplx(a), lda, cucplx(work), ipiv);
// double variant -> cusolverDnDgetrf.
183 void getrf(
int m,
int n,
double *a,
int lda,
double *work,
int *ipiv,
int &info) {
184 CUSOLVER_CHECK(cusolverDnDgetrf, info, m, n, a, lda, work, ipiv);
// std::complex<double> variant -> cusolverDnZgetrf.
186 void getrf(
int m,
int n, std::complex<double> *a,
int lda, std::complex<double> *work,
int *ipiv,
int &info) {
187 CUSOLVER_CHECK(cusolverDnZgetrf, info, m, n, cucplx(a), lda, cucplx(work), ipiv);
// getrs overloads: the character `op` is translated with get_cublas_op(op) and everything is
// forwarded to the matching cusolverDn{S,C,D,Z}getrs routine through CUSOLVER_CHECK (complex
// pointers via cucplx); `info` receives the value read from the shared info pointer.
// NOTE(review): the closing brace of each overload is not visible in this listing.

// float variant -> cusolverDnSgetrs.
191 void getrs(
char op,
int n,
int nrhs,
float const *a,
int lda,
int const *ipiv,
float *b,
int ldb,
int &info) {
192 CUSOLVER_CHECK(cusolverDnSgetrs, info, get_cublas_op(op), n, nrhs, a, lda, ipiv, b, ldb);
// std::complex<float> variant -> cusolverDnCgetrs.
194 void getrs(
char op,
int n,
int nrhs, std::complex<float>
const *a,
int lda,
int const *ipiv, std::complex<float> *b,
int ldb,
int &info) {
195 CUSOLVER_CHECK(cusolverDnCgetrs, info, get_cublas_op(op), n, nrhs, cucplx(a), lda, ipiv, cucplx(b), ldb);
// double variant -> cusolverDnDgetrs.
197 void getrs(
char op,
int n,
int nrhs,
double const *a,
int lda,
int const *ipiv,
double *b,
int ldb,
int &info) {
198 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), n, nrhs, a, lda, ipiv, b, ldb);
// std::complex<double> variant -> cusolverDnZgetrs.
200 void getrs(
char op,
int n,
int nrhs, std::complex<double>
const *a,
int lda,
int const *ipiv, std::complex<double> *b,
int ldb,
int &info) {
201 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), n, nrhs, cucplx(a), lda, ipiv, cucplx(b), ldb);
205 void getri(
int n,
float *a,
int lda,
int const *ipiv,
float *work,
int lwork,
int &info) { getri_impl<float>(n, a, lda, ipiv, work, lwork, info); }
// Remaining getri overloads: each passes its arguments unchanged to the getri_impl instantiation
// for its element type.
// NOTE(review): the closing brace of each overload is not visible in this listing.

// std::complex<float> variant.
206 void getri(
int n, std::complex<float> *a,
int lda,
int const *ipiv, std::complex<float> *work,
int lwork,
int &info) {
207 getri_impl<std::complex<float>>(n, a, lda, ipiv, work, lwork, info);
// double variant.
209 void getri(
int n,
double *a,
int lda,
int const *ipiv,
double *work,
int lwork,
int &info) {
210 getri_impl<double>(n, a, lda, ipiv, work, lwork, info);
// std::complex<double> variant.
212 void getri(
int n, std::complex<double> *a,
int lda,
int const *ipiv, std::complex<double> *work,
int lwork,
int &info) {
213 getri_impl<std::complex<double>>(n, a, lda, ipiv, work, lwork, info);
217 int geqrf_buffer_size(
int m,
int n,
float *a,
int lda) {
return geqrf_buffer_size_impl<float>(m, n, a, lda); }
218 int geqrf_buffer_size(
int m,
int n, std::complex<float> *a,
int lda) {
return geqrf_buffer_size_impl<std::complex<float>>(m, n, a, lda); }
219 int geqrf_buffer_size(
int m,
int n,
double *a,
int lda) {
return geqrf_buffer_size_impl<double>(m, n, a, lda); }
220 int geqrf_buffer_size(
int m,
int n, std::complex<double> *a,
int lda) {
return geqrf_buffer_size_impl<std::complex<double>>(m, n, a, lda); }
// geqrf overloads: each forwards to the matching cusolverDn{S,C,D,Z}geqrf routine through
// CUSOLVER_CHECK (complex pointers via cucplx); `info` receives the value read from the shared
// info pointer.
// NOTE(review): the closing brace of each overload is not visible in this listing.

// float variant -> cusolverDnSgeqrf.
223 void geqrf(
int m,
int n,
float *a,
int lda,
float *tau,
float *work,
int lwork,
int &info) {
224 CUSOLVER_CHECK(cusolverDnSgeqrf, info, m, n, a, lda, tau, work, lwork);
// std::complex<float> variant -> cusolverDnCgeqrf.
226 void geqrf(
int m,
int n, std::complex<float> *a,
int lda, std::complex<float> *tau, std::complex<float> *work,
int lwork,
int &info) {
227 CUSOLVER_CHECK(cusolverDnCgeqrf, info, m, n, cucplx(a), lda, cucplx(tau), cucplx(work), lwork);
// double variant -> cusolverDnDgeqrf.
229 void geqrf(
int m,
int n,
double *a,
int lda,
double *tau,
double *work,
int lwork,
int &info) {
230 CUSOLVER_CHECK(cusolverDnDgeqrf, info, m, n, a, lda, tau, work, lwork);
// std::complex<double> variant -> cusolverDnZgeqrf.
232 void geqrf(
int m,
int n, std::complex<double> *a,
int lda, std::complex<double> *tau, std::complex<double> *work,
int lwork,
int &info) {
233 CUSOLVER_CHECK(cusolverDnZgeqrf, info, m, n, cucplx(a), lda, cucplx(tau), cucplx(work), lwork);
237 int orgqr_buffer_size(
int m,
int n,
int k,
float const *a,
int lda,
float const *tau) {
return xxgqr_buffer_size_impl(m, n, k, a, lda, tau); }
238 int orgqr_buffer_size(
int m,
int n,
int k,
double const *a,
int lda,
double const *tau) {
return xxgqr_buffer_size_impl(m, n, k, a, lda, tau); }
// orgqr overloads for the real element types: each forwards its arguments unchanged to
// cusolverDn{S,D}orgqr through CUSOLVER_CHECK; `info` receives the value read from the shared
// info pointer.
// NOTE(review): the closing brace of each overload is not visible in this listing.

// float variant -> cusolverDnSorgqr.
241 void orgqr(
int m,
int n,
int k,
float *a,
int lda,
float const *tau,
float *work,
int lwork,
int &info) {
242 CUSOLVER_CHECK(cusolverDnSorgqr, info, m, n, k, a, lda, tau, work, lwork);
// double variant -> cusolverDnDorgqr.
244 void orgqr(
int m,
int n,
int k,
double *a,
int lda,
double const *tau,
double *work,
int lwork,
int &info) {
245 CUSOLVER_CHECK(cusolverDnDorgqr, info, m, n, k, a, lda, tau, work, lwork);
// ungqr workspace-size queries for the complex element types; both delegate to
// xxgqr_buffer_size_impl, with T deduced from the pointer arguments.
// NOTE(review): the closing brace of each overload is not visible in this listing.

// std::complex<float> variant.
249 int ungqr_buffer_size(
int m,
int n,
int k, std::complex<float>
const *a,
int lda, std::complex<float>
const *tau) {
250 return xxgqr_buffer_size_impl(m, n, k, a, lda, tau);
// std::complex<double> variant.
252 int ungqr_buffer_size(
int m,
int n,
int k, std::complex<double>
const *a,
int lda, std::complex<double>
const *tau) {
253 return xxgqr_buffer_size_impl(m, n, k, a, lda, tau);
// ungqr overloads for the complex element types: each forwards to cusolverDn{C,Z}ungqr through
// CUSOLVER_CHECK, passing a, tau and work via cucplx; `info` receives the value read from the
// shared info pointer.
// NOTE(review): closing braces are not visible in this listing, and the second overload is also
// missing the end of its parameter list (the `info` parameter it uses is not shown).

// std::complex<float> variant -> cusolverDnCungqr.
257 void ungqr(
int m,
int n,
int k, std::complex<float> *a,
int lda, std::complex<float>
const *tau, std::complex<float> *work,
int lwork,
int &info) {
258 CUSOLVER_CHECK(cusolverDnCungqr, info, m, n, k, cucplx(a), lda, cucplx(tau), cucplx(work), lwork);
// std::complex<double> variant -> cusolverDnZungqr.
260 void ungqr(
int m,
int n,
int k, std::complex<double> *a,
int lda, std::complex<double>
const *tau, std::complex<double> *work,
int lwork,
262 CUSOLVER_CHECK(cusolverDnZungqr, info, m, n, k, cucplx(a), lda, cucplx(tau), cucplx(work), lwork);
Provides custom allocators for the nda library.
Provides the generic class for arrays.
Provides a C++ interface for the GPU versions of various LAPACK routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
basic_array< ValueType, 1, C_layout, 'V', heap< mem::Device > > cuvector
Similar to nda::vector except the memory is stored on the device.
basic_array_view< ValueType, Rank, Layout, 'A', default_accessor, borrowed< mem::Device > > cuarray_view
Similar to nda::array_view except the memory is stored on the device.
void memcpy2D(void *dest, size_t dpitch, const void *src, size_t spitch, size_t width, size_t height)
Call CUDA's cudaMemcpy2D function or simulate its behavior on the Host based on the given address spaces.
Provides various handles to take care of memory management for nda::basic_array and nda::basic_array_view.
Macros used in the nda library.
Provides a generic memcpy and memcpy2D function for different address spaces.
Provides type traits for the nda library.