TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cusolver_interface.cpp
Go to the documentation of this file.
// Copyright (c) 2022-2024 Simons Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Authors: Thomas Hahn, Miguel Morales, Nils Wentzell
16
23#include "../../basic_array.hpp"
24#include "../../blas/tools.hpp"
26#include "../../device.hpp"
27#include "../../exceptions.hpp"
28#include "../../macros.hpp"
30#include "../../mem/handle.hpp"
31
32#include <cusolverDn.h>
33
34#include <string>
35
36namespace nda::lapack::device {
37
38 // Local function to get unique CuSolver handle.
39 inline cusolverDnHandle_t &get_handle() {
40 struct handle_storage_t { // RAII for handle
41 handle_storage_t() { cusolverDnCreate(&handle); }
42 ~handle_storage_t() { cusolverDnDestroy(handle); }
43 cusolverDnHandle_t handle = {};
44 };
45 static auto sto = handle_storage_t{};
46 return sto.handle;
47 }
48
49 // Get an integer pointer in unified memory to return info from lapack routines.
50 int *get_info_ptr() {
51 static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);
52 return info_u_handle.data();
53 }
54
// Whether CUSOLVER_CHECK calls cudaDeviceSynchronize after every cuSOLVER library call.
// The flag has internal linkage and is never modified in this file, so it is a compile-time constant;
// change the initializer to turn synchronization off.
// NOTE(review): CUSOLVER_CHECK reads the info value from unified memory right after the call; disabling the
// synchronization presumably lets that read race with the still-running device routine -- confirm before disabling.
static constexpr bool synchronize = true;
57
// Macro to call the cuSOLVER routine X and check its result.
//
// Expands to X(get_handle(), <args...>, get_info_ptr()): the shared handle is prepended and the unified-memory
// info pointer is appended to the given arguments. An error is raised via NDA_RUNTIME_ERROR if the call does
// not return CUSOLVER_STATUS_SUCCESS or, when `synchronize` is set, if cudaDeviceSynchronize fails.
// Finally the info value written by the routine is copied into `info`.
//
// The body is wrapped in do { ... } while (0) so the macro behaves like a single statement (safe inside an
// unbraced if/else, requires the trailing semicolon) and its locals do not leak into or clash with the caller.
#define CUSOLVER_CHECK(X, info, ...)                                                                                                           \
  do {                                                                                                                                         \
    auto err = X(get_handle(), __VA_ARGS__, get_info_ptr());                                                                                   \
    if (err != CUSOLVER_STATUS_SUCCESS) { NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(err); }          \
    if (synchronize) {                                                                                                                         \
      auto errsync = cudaDeviceSynchronize();                                                                                                  \
      if (errsync != cudaSuccess) {                                                                                                            \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) " \n "                                           \
                          << " cudaGetErrorName: " << std::string(cudaGetErrorName(errsync)) << "\n"                                         \
                          << " cudaGetErrorString: " << std::string(cudaGetErrorString(errsync)) << "\n";                                    \
      }                                                                                                                                        \
    }                                                                                                                                          \
    (info) = *get_info_ptr();                                                                                                                  \
  } while (0)
71
72 void gesvd(char JOBU, char JOBVT, int M, int N, double *A, int LDA, double *S, double *U, int LDU, double *VT, int LDVT, double *WORK, int LWORK,
73 double *RWORK, int &INFO) {
74 // Replicate behavior of Netlib gesvd
75 if (LWORK == -1) {
76 int bufferSize = 0;
77 cusolverDnDgesvd_bufferSize(get_handle(), M, N, &bufferSize);
78 *WORK = bufferSize;
79 } else {
80 CUSOLVER_CHECK(cusolverDnDgesvd, INFO, JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK);
81 }
82 }
83 void gesvd(char JOBU, char JOBVT, int M, int N, dcomplex *A, int LDA, double *S, dcomplex *U, int LDU, dcomplex *VT, int LDVT, dcomplex *WORK,
84 int LWORK, double *RWORK, int &INFO) {
85 // Replicate behavior of Netlib gesvd
86 if (LWORK == -1) {
87 int bufferSize = 0;
88 cusolverDnZgesvd_bufferSize(get_handle(), M, N, &bufferSize);
89 *WORK = bufferSize;
90 } else {
91 CUSOLVER_CHECK(cusolverDnZgesvd, INFO, JOBU, JOBVT, M, N, cucplx(A), LDA, S, cucplx(U), LDU, cucplx(VT), LDVT, cucplx(WORK), LWORK,
92 RWORK); // NOLINT
93 }
94 }
95
96 void getrf(int M, int N, double *A, int LDA, int *ipiv, int &info) {
97 int bufferSize = 0;
98 cusolverDnDgetrf_bufferSize(get_handle(), M, N, A, LDA, &bufferSize);
99 auto Workspace = nda::cuvector<double>(bufferSize);
100 CUSOLVER_CHECK(cusolverDnDgetrf, info, M, N, A, LDA, Workspace.data(), ipiv);
101 }
102 void getrf(int M, int N, dcomplex *A, int LDA, int *ipiv, int &info) {
103 int bufferSize = 0;
104 cusolverDnZgetrf_bufferSize(get_handle(), M, N, cucplx(A), LDA, &bufferSize);
105 auto Workspace = nda::cuvector<dcomplex>(bufferSize);
106 CUSOLVER_CHECK(cusolverDnZgetrf, info, M, N, cucplx(A), LDA, cucplx(Workspace.data()), ipiv);
107 }
108
109 void getrs(char op, int N, int NRHS, double const *A, int LDA, int const *ipiv, double *B, int LDB, int &info) {
110 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), N, NRHS, A, LDA, ipiv, B, LDB);
111 }
112 void getrs(char op, int N, int NRHS, dcomplex const *A, int LDA, int const *ipiv, dcomplex *B, int LDB, int &info) {
113 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), N, NRHS, cucplx(A), LDA, ipiv, cucplx(B), LDB);
114 }
115
116} // namespace nda::lapack::device
Provides custom allocators for the nda library.
Provides the generic class for arrays.
Provides a C++ interface for the GPU versions of various LAPACK routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
basic_array< ValueType, 1, C_layout, 'V', heap< mem::Device > > cuvector
Similar to nda::vector except the memory is stored on the device.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:39
int gesvd(A &&a, S &&s, U &&u, VT &&vt)
Interface to the LAPACK gesvd routine.
Definition gesvd.hpp:75
int getrs(A const &a, B &&b, IPIV const &ipiv)
Interface to the LAPACK getrs routine.
Definition getrs.hpp:64
int getrf(A &&a, IPIV &&ipiv)
Interface to the LAPACK getrf routine.
Definition getrf.hpp:62
Provides various handles to take care of memory management for nda::basic_array and nda::basic_array_view.
Macros used in the nda library.
Provides various traits and utilities for the BLAS interface.