TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cusolver_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
12#include "../../basic_array.hpp"
13#include "../../blas/tools.hpp"
15#include "../../device.hpp"
16#include "../../exceptions.hpp"
17#include "../../macros.hpp"
19#include "../../mem/fill.hpp"
20#include "../../mem/handle.hpp"
21#include "../../mem/memcpy.hpp"
22#include "../../traits.hpp"
23
24#include <cusolverDn.h>
25
26#include <string>
27#include <type_traits>
28
29namespace nda::lapack::device {
30
31 // Local function to get unique CuSolver handle.
32 inline cusolverDnHandle_t &get_handle() {
33 struct handle_storage_t { // RAII for handle
34 handle_storage_t() { cusolverDnCreate(&handle); }
35 ~handle_storage_t() { cusolverDnDestroy(handle); }
36 cusolverDnHandle_t handle = {};
37 };
38 static auto sto = handle_storage_t{};
39 return sto.handle;
40 }
41
42 // Get an integer pointer in unified memory to return info from lapack routines.
43 int *get_info_ptr() {
44 static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);
45 return info_u_handle.data();
46 }
47
// Flag deciding whether a device synchronization (cudaDeviceSynchronize) follows each cuSOLVER library call.
// It defaults to true and is thread_local, so every thread controls it independently.
thread_local bool synchronize = true; // NOLINT (per-thread option is on purpose)

// Turn the synchronization after cuSOLVER calls on or off for the calling thread.
void set_synchronization(bool do_sync) noexcept {
  synchronize = do_sync;
}

// Report whether the synchronization after cuSOLVER calls is enabled for the calling thread.
bool get_synchronization() noexcept {
  bool const enabled = synchronize;
  return enabled;
}
52
// Macro to call the cuSOLVER routine X and check its result.
//
// X is invoked as X(get_handle(), <remaining arguments>..., get_info_ptr()). Steps:
// - A returned status other than CUSOLVER_STATUS_SUCCESS throws an nda::runtime_error.
// - cuda_device_sync is then called with the per-thread `synchronize` flag.
// - Finally, the value stored at the unified-memory info pointer is copied into `info`.
//
// The body is wrapped in do { ... } while (0). This keeps the local `err` from leaking into (or clashing with) the
// caller's scope, and makes the macro behave like a single statement, e.g. inside an unbraced if/else.
//
// NOTE(review): if `synchronize` is false (and cuda_device_sync then presumably skips the sync), `info` is read on
// the host without waiting for the routine to finish. It may therefore not reflect this call; confirm that callers
// only rely on `info` with synchronization enabled.
#define CUSOLVER_CHECK(X, info, ...)                                                                                   \
  do {                                                                                                                 \
    auto const err = X(get_handle(), __VA_ARGS__, get_info_ptr());                                                     \
    if (err != CUSOLVER_STATUS_SUCCESS) { NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(err); } \
    cuda_device_sync(synchronize, AS_STRING(X));                                                                       \
    (info) = *get_info_ptr();                                                                                          \
  } while (0)
59
60 // Anonymous namespace for some file local helper functions.
61 namespace {
62
63 // Get the buffer size for gesvd.
64 template <typename T>
65 int gesvd_buffer_size_impl(int m, int n) {
66 int bufferSize = 0;
67 if constexpr (std::is_same_v<T, float>) {
68 cusolverDnSgesvd_bufferSize(get_handle(), m, n, &bufferSize);
69 } else if constexpr (std::is_same_v<T, double>) {
70 cusolverDnDgesvd_bufferSize(get_handle(), m, n, &bufferSize);
71 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
72 cusolverDnCgesvd_bufferSize(get_handle(), m, n, &bufferSize);
73 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
74 cusolverDnZgesvd_bufferSize(get_handle(), m, n, &bufferSize);
75 }
76 return bufferSize;
77 }
78
79 // Get the buffer size for getrf.
80 template <typename T>
81 int getrf_buffer_size_impl(int m, int n, T *a, int lda) {
82 int bufferSize = 0;
83 if constexpr (std::is_same_v<T, float>) {
84 cusolverDnSgetrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
85 } else if constexpr (std::is_same_v<T, double>) {
86 cusolverDnDgetrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
87 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
88 cusolverDnCgetrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
89 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
90 cusolverDnZgetrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
91 }
92 return bufferSize;
93 }
94
95 // Get the buffer size for geqrf.
96 template <typename T>
97 int geqrf_buffer_size_impl(int m, int n, T *a, int lda) {
98 int bufferSize = 0;
99 if constexpr (std::is_same_v<T, float>) {
100 cusolverDnSgeqrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
101 } else if constexpr (std::is_same_v<T, double>) {
102 cusolverDnDgeqrf_bufferSize(get_handle(), m, n, a, lda, &bufferSize);
103 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
104 cusolverDnCgeqrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
105 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
106 cusolverDnZgeqrf_bufferSize(get_handle(), m, n, cucplx(a), lda, &bufferSize);
107 }
108 return bufferSize;
109 }
110
111 // Get the buffer size for orgqr/ungqr.
112 template <typename T>
113 int xxgqr_buffer_size_impl(int m, int n, int k, T const *a, int lda, T const *tau) {
114 int bufferSize = 0;
115 if constexpr (std::is_same_v<T, float>) {
116 cusolverDnSorgqr_bufferSize(get_handle(), m, n, k, a, lda, tau, &bufferSize);
117 } else if constexpr (std::is_same_v<T, double>) {
118 cusolverDnDorgqr_bufferSize(get_handle(), m, n, k, a, lda, tau, &bufferSize);
119 } else if constexpr (std::is_same_v<T, std::complex<float>>) {
120 cusolverDnCungqr_bufferSize(get_handle(), m, n, k, cucplx(a), lda, cucplx(tau), &bufferSize);
121 } else if constexpr (std::is_same_v<T, std::complex<double>>) {
122 cusolverDnZungqr_bufferSize(get_handle(), m, n, k, cucplx(a), lda, cucplx(tau), &bufferSize);
123 }
124 return bufferSize;
125 }
126
127 // Custom getri implementation: build I in workspace (or a fallback buffer), call getrs to solve A * X = I, then
128 // copy the resulting inverse back over A.
129 template <typename T>
130 void getri_impl(int n, T *a, int lda, int const *ipiv, T *work, int lwork, int &info) {
131 auto solve_and_copy_back = [&](T *b_ptr) {
132 auto B = nda::cuarray_view<T, 2>{std::array<long, 2>{n, n}, b_ptr};
133 B() = T(0);
134 mem::fill2D_n<mem::Device>(b_ptr, static_cast<size_t>(n) + 1, 1, n, T(1));
135 getrs('N', n, n, a, lda, ipiv, b_ptr, n, info);
136 mem::memcpy2D<mem::Device, mem::Device>(a, lda * sizeof(T), b_ptr, n * sizeof(T), n * sizeof(T), n);
137 };
138 if (lwork >= n * n) {
139 solve_and_copy_back(work);
140 } else {
141 auto tmp = nda::cuvector<T>(n * n);
142 solve_and_copy_back(tmp.data());
143 }
144 }
145
146 } // namespace
147
148 // gesvd buffer size
149 int gesvd_buffer_size(int m, int n, float *) { return gesvd_buffer_size_impl<float>(m, n); }
150 int gesvd_buffer_size(int m, int n, std::complex<float> *) { return gesvd_buffer_size_impl<std::complex<float>>(m, n); }
151 int gesvd_buffer_size(int m, int n, double *) { return gesvd_buffer_size_impl<double>(m, n); }
152 int gesvd_buffer_size(int m, int n, std::complex<double> *) { return gesvd_buffer_size_impl<std::complex<double>>(m, n); }
153
154 // gesvd
155 void gesvd(char jobu, char jobvt, int m, int n, float *a, int lda, float *s, float *u, int ldu, float *vt, int ldvt, float *work, int lwork,
156 float *rwork, int &info) {
157 CUSOLVER_CHECK(cusolverDnSgesvd, info, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork);
158 }
159 void gesvd(char jobu, char jobvt, int m, int n, std::complex<float> *a, int lda, float *s, std::complex<float> *u, int ldu, std::complex<float> *vt,
160 int ldvt, std::complex<float> *work, int lwork, float *rwork, int &info) {
161 CUSOLVER_CHECK(cusolverDnCgesvd, info, jobu, jobvt, m, n, cucplx(a), lda, s, cucplx(u), ldu, cucplx(vt), ldvt, cucplx(work), lwork, rwork);
162 }
163 void gesvd(char jobu, char jobvt, int m, int n, double *a, int lda, double *s, double *u, int ldu, double *vt, int ldvt, double *work, int lwork,
164 double *rwork, int &info) {
165 CUSOLVER_CHECK(cusolverDnDgesvd, info, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork);
166 }
167 void gesvd(char jobu, char jobvt, int m, int n, std::complex<double> *a, int lda, double *s, std::complex<double> *u, int ldu,
168 std::complex<double> *vt, int ldvt, std::complex<double> *work, int lwork, double *rwork, int &info) {
169 CUSOLVER_CHECK(cusolverDnZgesvd, info, jobu, jobvt, m, n, cucplx(a), lda, s, cucplx(u), ldu, cucplx(vt), ldvt, cucplx(work), lwork, rwork);
170 }
171
172 // getrf buffer size
173 int getrf_buffer_size(int m, int n, float *a, int lda) { return getrf_buffer_size_impl<float>(m, n, a, lda); }
174 int getrf_buffer_size(int m, int n, std::complex<float> *a, int lda) { return getrf_buffer_size_impl<std::complex<float>>(m, n, a, lda); }
175 int getrf_buffer_size(int m, int n, double *a, int lda) { return getrf_buffer_size_impl<double>(m, n, a, lda); }
176 int getrf_buffer_size(int m, int n, std::complex<double> *a, int lda) { return getrf_buffer_size_impl<std::complex<double>>(m, n, a, lda); }
177
178 // getrf
179 void getrf(int m, int n, float *a, int lda, float *work, int *ipiv, int &info) { CUSOLVER_CHECK(cusolverDnSgetrf, info, m, n, a, lda, work, ipiv); }
180 void getrf(int m, int n, std::complex<float> *a, int lda, std::complex<float> *work, int *ipiv, int &info) {
181 CUSOLVER_CHECK(cusolverDnCgetrf, info, m, n, cucplx(a), lda, cucplx(work), ipiv);
182 }
183 void getrf(int m, int n, double *a, int lda, double *work, int *ipiv, int &info) {
184 CUSOLVER_CHECK(cusolverDnDgetrf, info, m, n, a, lda, work, ipiv);
185 }
186 void getrf(int m, int n, std::complex<double> *a, int lda, std::complex<double> *work, int *ipiv, int &info) {
187 CUSOLVER_CHECK(cusolverDnZgetrf, info, m, n, cucplx(a), lda, cucplx(work), ipiv);
188 }
189
190 // getrs
191 void getrs(char op, int n, int nrhs, float const *a, int lda, int const *ipiv, float *b, int ldb, int &info) {
192 CUSOLVER_CHECK(cusolverDnSgetrs, info, get_cublas_op(op), n, nrhs, a, lda, ipiv, b, ldb);
193 }
194 void getrs(char op, int n, int nrhs, std::complex<float> const *a, int lda, int const *ipiv, std::complex<float> *b, int ldb, int &info) {
195 CUSOLVER_CHECK(cusolverDnCgetrs, info, get_cublas_op(op), n, nrhs, cucplx(a), lda, ipiv, cucplx(b), ldb);
196 }
197 void getrs(char op, int n, int nrhs, double const *a, int lda, int const *ipiv, double *b, int ldb, int &info) {
198 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), n, nrhs, a, lda, ipiv, b, ldb);
199 }
200 void getrs(char op, int n, int nrhs, std::complex<double> const *a, int lda, int const *ipiv, std::complex<double> *b, int ldb, int &info) {
201 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), n, nrhs, cucplx(a), lda, ipiv, cucplx(b), ldb);
202 }
203
204 // getri
205 void getri(int n, float *a, int lda, int const *ipiv, float *work, int lwork, int &info) { getri_impl<float>(n, a, lda, ipiv, work, lwork, info); }
206 void getri(int n, std::complex<float> *a, int lda, int const *ipiv, std::complex<float> *work, int lwork, int &info) {
207 getri_impl<std::complex<float>>(n, a, lda, ipiv, work, lwork, info);
208 }
209 void getri(int n, double *a, int lda, int const *ipiv, double *work, int lwork, int &info) {
210 getri_impl<double>(n, a, lda, ipiv, work, lwork, info);
211 }
212 void getri(int n, std::complex<double> *a, int lda, int const *ipiv, std::complex<double> *work, int lwork, int &info) {
213 getri_impl<std::complex<double>>(n, a, lda, ipiv, work, lwork, info);
214 }
215
216 // geqrf buffer size
217 int geqrf_buffer_size(int m, int n, float *a, int lda) { return geqrf_buffer_size_impl<float>(m, n, a, lda); }
218 int geqrf_buffer_size(int m, int n, std::complex<float> *a, int lda) { return geqrf_buffer_size_impl<std::complex<float>>(m, n, a, lda); }
219 int geqrf_buffer_size(int m, int n, double *a, int lda) { return geqrf_buffer_size_impl<double>(m, n, a, lda); }
220 int geqrf_buffer_size(int m, int n, std::complex<double> *a, int lda) { return geqrf_buffer_size_impl<std::complex<double>>(m, n, a, lda); }
221
222 // geqrf
223 void geqrf(int m, int n, float *a, int lda, float *tau, float *work, int lwork, int &info) {
224 CUSOLVER_CHECK(cusolverDnSgeqrf, info, m, n, a, lda, tau, work, lwork);
225 }
226 void geqrf(int m, int n, std::complex<float> *a, int lda, std::complex<float> *tau, std::complex<float> *work, int lwork, int &info) {
227 CUSOLVER_CHECK(cusolverDnCgeqrf, info, m, n, cucplx(a), lda, cucplx(tau), cucplx(work), lwork);
228 }
229 void geqrf(int m, int n, double *a, int lda, double *tau, double *work, int lwork, int &info) {
230 CUSOLVER_CHECK(cusolverDnDgeqrf, info, m, n, a, lda, tau, work, lwork);
231 }
232 void geqrf(int m, int n, std::complex<double> *a, int lda, std::complex<double> *tau, std::complex<double> *work, int lwork, int &info) {
233 CUSOLVER_CHECK(cusolverDnZgeqrf, info, m, n, cucplx(a), lda, cucplx(tau), cucplx(work), lwork);
234 }
235
236 // orgqr buffer size
237 int orgqr_buffer_size(int m, int n, int k, float const *a, int lda, float const *tau) { return xxgqr_buffer_size_impl(m, n, k, a, lda, tau); }
238 int orgqr_buffer_size(int m, int n, int k, double const *a, int lda, double const *tau) { return xxgqr_buffer_size_impl(m, n, k, a, lda, tau); }
239
240 // orgqr
241 void orgqr(int m, int n, int k, float *a, int lda, float const *tau, float *work, int lwork, int &info) {
242 CUSOLVER_CHECK(cusolverDnSorgqr, info, m, n, k, a, lda, tau, work, lwork);
243 }
244 void orgqr(int m, int n, int k, double *a, int lda, double const *tau, double *work, int lwork, int &info) {
245 CUSOLVER_CHECK(cusolverDnDorgqr, info, m, n, k, a, lda, tau, work, lwork);
246 }
247
248 // ungqr buffer size
249 int ungqr_buffer_size(int m, int n, int k, std::complex<float> const *a, int lda, std::complex<float> const *tau) {
250 return xxgqr_buffer_size_impl(m, n, k, a, lda, tau);
251 }
252 int ungqr_buffer_size(int m, int n, int k, std::complex<double> const *a, int lda, std::complex<double> const *tau) {
253 return xxgqr_buffer_size_impl(m, n, k, a, lda, tau);
254 }
255
256 // ungqr
257 void ungqr(int m, int n, int k, std::complex<float> *a, int lda, std::complex<float> const *tau, std::complex<float> *work, int lwork, int &info) {
258 CUSOLVER_CHECK(cusolverDnCungqr, info, m, n, k, cucplx(a), lda, cucplx(tau), cucplx(work), lwork);
259 }
260 void ungqr(int m, int n, int k, std::complex<double> *a, int lda, std::complex<double> const *tau, std::complex<double> *work, int lwork,
261 int &info) {
262 CUSOLVER_CHECK(cusolverDnZungqr, info, m, n, k, cucplx(a), lda, cucplx(tau), cucplx(work), lwork);
263 }
264
265} // namespace nda::lapack::device
Provides custom allocators for the nda library.
Provides the generic class for arrays.
Provides various traits and utilities for the BLAS interface.
Provides a C++ interface for the GPU versions of various LAPACK routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
basic_array< ValueType, 1, C_layout, 'V', heap< mem::Device > > cuvector
Similar to nda::vector except the memory is stored on the device.
basic_array_view< ValueType, Rank, Layout, 'A', default_accessor, borrowed< mem::Device > > cuarray_view
Similar to nda::array_view except the memory is stored on the device.
void memcpy2D(void *dest, size_t dpitch, const void *src, size_t spitch, size_t width, size_t height)
Call CUDA's cudaMemcpy2D function or simulate its behavior on the Host based on the given address spaces.
Definition memcpy.hpp:73
Provides various handles to take care of memory management for nda::basic_array and nda::basic_array_view.
Macros used in the nda library.
Provides a generic memcpy and memcpy2D function for different address spaces.
Provides type traits for the nda library.