TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cusolver_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2022-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Miguel Morales, Nils Wentzell
16
17/**
18 * @file
19 * @brief Implementation details for lapack/interface/cusolver_interface.hpp.
20 */
21
22#include "./cusolver_interface.hpp"
23#include "../../basic_array.hpp"
24#include "../../blas/tools.hpp"
25#include "../../declarations.hpp"
26#include "../../device.hpp"
27#include "../../exceptions.hpp"
28#include "../../macros.hpp"
29#include "../../mem/allocators.hpp"
30#include "../../mem/handle.hpp"
31
32#include <cusolverDn.h>
33
34#include <string>
35
36namespace nda::lapack::device {
37
38 // Local function to get unique CuSolver handle.
39 inline cusolverDnHandle_t &get_handle() {
40 struct handle_storage_t { // RAII for handle
41 handle_storage_t() { cusolverDnCreate(&handle); }
42 ~handle_storage_t() { cusolverDnDestroy(handle); }
43 cusolverDnHandle_t handle = {};
44 };
45 static auto sto = handle_storage_t{};
46 return sto.handle;
47 }
48
49 // Get an integer pointer in unified memory to return info from lapack routines.
50 int *get_info_ptr() {
51 static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);
52 return info_u_handle.data();
53 }
54
55 // Global option to turn on/off the cudaDeviceSynchronize after cusolver library calls.
// When true, CUSOLVER_CHECK calls cudaDeviceSynchronize() right after the cuSOLVER routine and
// raises an error if that synchronization fails; when false, that step is skipped.
// NOTE(review): the variable has internal linkage and nothing in the visible part of this file
// assigns to it, so it is effectively always true here -- confirm whether a setter is intended.
56 static bool synchronize = true; // NOLINT (global option is on purpose)
57
// Macro to check cusolver calls.
//
// Usage: CUSOLVER_CHECK(routine, info, args...);
//
//  - X    : the cuSOLVER routine to call. It is invoked as
//           X(get_handle(), args..., get_info_ptr()).
//  - info : an assignable int; after the call it receives the value stored at get_info_ptr().
//  - ...  : the routine's arguments between the handle and the trailing info pointer.
//
// Throws (via NDA_RUNTIME_ERROR) if the routine does not return CUSOLVER_STATUS_SUCCESS or, when
// `synchronize` is true, if the following cudaDeviceSynchronize() fails.
//
// The body is wrapped in do { ... } while (false) so that the macro expands to exactly one
// statement and its local variables do not leak into (or clash with) the calling scope.
//
// NOTE(review): when `synchronize` is false, `info` is read without a preceding
// cudaDeviceSynchronize() in this macro -- confirm that callers synchronize before relying on it.
#define CUSOLVER_CHECK(X, info, ...) \
  do { \
    auto err = X(get_handle(), __VA_ARGS__, get_info_ptr()); \
    if (err != CUSOLVER_STATUS_SUCCESS) { NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(err); } \
    if (synchronize) { \
      auto errsync = cudaDeviceSynchronize(); \
      if (errsync != cudaSuccess) { \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) << " \n " \
                          << " cudaGetErrorName: " << std::string(cudaGetErrorName(errsync)) << "\n" \
                          << " cudaGetErrorString: " << std::string(cudaGetErrorString(errsync)) << "\n"; \
      } \
    } \
    info = *get_info_ptr(); \
  } while (false)
71
72 void gesvd(char JOBU, char JOBVT, int M, int N, double *A, int LDA, double *S, double *U, int LDU, double *VT, int LDVT, double *WORK, int LWORK,
73 double *RWORK, int &INFO) {
74 // Replicate behavior of Netlib gesvd
75 if (LWORK == -1) {
76 int bufferSize = 0;
77 cusolverDnDgesvd_bufferSize(get_handle(), M, N, &bufferSize);
78 *WORK = bufferSize;
79 } else {
80 CUSOLVER_CHECK(cusolverDnDgesvd, INFO, JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK);
81 }
82 }
83 void gesvd(char JOBU, char JOBVT, int M, int N, dcomplex *A, int LDA, double *S, dcomplex *U, int LDU, dcomplex *VT, int LDVT, dcomplex *WORK,
84 int LWORK, double *RWORK, int &INFO) {
85 // Replicate behavior of Netlib gesvd
86 if (LWORK == -1) {
87 int bufferSize = 0;
88 cusolverDnZgesvd_bufferSize(get_handle(), M, N, &bufferSize);
89 *WORK = bufferSize;
90 } else {
91 CUSOLVER_CHECK(cusolverDnZgesvd, INFO, JOBU, JOBVT, M, N, cucplx(A), LDA, S, cucplx(U), LDU, cucplx(VT), LDVT, cucplx(WORK), LWORK,
92 RWORK); // NOLINT
93 }
94 }
95
96 void getrf(int M, int N, double *A, int LDA, int *ipiv, int &info) {
97 int bufferSize = 0;
98 cusolverDnDgetrf_bufferSize(get_handle(), M, N, A, LDA, &bufferSize);
99 auto Workspace = nda::cuvector<double>(bufferSize);
100 CUSOLVER_CHECK(cusolverDnDgetrf, info, M, N, A, LDA, Workspace.data(), ipiv);
101 }
102 void getrf(int M, int N, dcomplex *A, int LDA, int *ipiv, int &info) {
103 int bufferSize = 0;
104 cusolverDnZgetrf_bufferSize(get_handle(), M, N, cucplx(A), LDA, &bufferSize);
105 auto Workspace = nda::cuvector<dcomplex>(bufferSize);
106 CUSOLVER_CHECK(cusolverDnZgetrf, info, M, N, cucplx(A), LDA, cucplx(Workspace.data()), ipiv);
107 }
108
109 void getrs(char op, int N, int NRHS, double const *A, int LDA, int const *ipiv, double *B, int LDB, int &info) {
110 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), N, NRHS, A, LDA, ipiv, B, LDB);
111 }
112 void getrs(char op, int N, int NRHS, dcomplex const *A, int LDA, int const *ipiv, dcomplex *B, int LDB, int &info) {
113 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), N, NRHS, cucplx(A), LDA, ipiv, cucplx(B), LDB);
114 }
115
116} // namespace nda::lapack::device
#define CUSOLVER_CHECK(X, info,...)
#define NDA_RUNTIME_ERROR
#define AS_STRING(...)
Definition macros.hpp:31