TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cusolver_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
12#include "../../basic_array.hpp"
13#include "../../blas/tools.hpp"
15#include "../../device.hpp"
16#include "../../exceptions.hpp"
17#include "../../macros.hpp"
19#include "../../mem/handle.hpp"
20
21#include <cusolverDn.h>
22
23#include <string>
24
25namespace nda::lapack::device {
26
27 // Local function to get unique CuSolver handle.
28 inline cusolverDnHandle_t &get_handle() {
29 struct handle_storage_t { // RAII for handle
30 handle_storage_t() { cusolverDnCreate(&handle); }
31 ~handle_storage_t() { cusolverDnDestroy(handle); }
32 cusolverDnHandle_t handle = {};
33 };
34 static auto sto = handle_storage_t{};
35 return sto.handle;
36 }
37
38 // Get an integer pointer in unified memory to return info from lapack routines.
39 int *get_info_ptr() {
40 static auto info_u_handle = mem::handle_heap<int, mem::mallocator<mem::Unified>>(1);
41 return info_u_handle.data();
42 }
43
// Global option to turn on/off the cudaDeviceSynchronize after cusolver library calls.
static bool synchronize = true; // NOLINT (global option is on purpose)

// Macro to check cusolver calls.
//
// CUSOLVER_CHECK(X, info, args...) does the following:
//   - calls X(get_handle(), args..., get_info_ptr());
//   - throws (via NDA_RUNTIME_ERROR) if X does not return CUSOLVER_STATUS_SUCCESS;
//   - calls cudaDeviceSynchronize if `synchronize` is set, throwing if that fails;
//   - finally copies the info value written by X into `info`.
//
// The expansion is wrapped in do { ... } while (0), so it behaves like a single statement (safe after an un-braced
// if/else). It also keeps its locals (err, errsync) out of the calling scope, so several checked calls can appear
// in the same block.
//
// NOTE(review): when `synchronize` is false, *get_info_ptr() is read on the host without waiting for the device,
// so `info` may not yet reflect the result of X. Confirm this is acceptable for callers that inspect `info`.
#define CUSOLVER_CHECK(X, info, ...) \
  do { \
    auto err = X(get_handle(), __VA_ARGS__, get_info_ptr()); \
    if (err != CUSOLVER_STATUS_SUCCESS) { NDA_RUNTIME_ERROR << AS_STRING(X) << " failed with error code " << std::to_string(err); } \
    if (synchronize) { \
      auto errsync = cudaDeviceSynchronize(); \
      if (errsync != cudaSuccess) { \
        NDA_RUNTIME_ERROR << " cudaDeviceSynchronize failed after call to: " << AS_STRING(X) " \n " \
                          << " cudaGetErrorName: " << std::string(cudaGetErrorName(errsync)) << "\n" \
                          << " cudaGetErrorString: " << std::string(cudaGetErrorString(errsync)) << "\n"; \
      } \
    } \
    info = *get_info_ptr(); \
  } while (0)
60
61 void gesvd(char JOBU, char JOBVT, int M, int N, double *A, int LDA, double *S, double *U, int LDU, double *VT, int LDVT, double *WORK, int LWORK,
62 double *RWORK, int &INFO) {
63 // Replicate behavior of Netlib gesvd
64 if (LWORK == -1) {
65 int bufferSize = 0;
66 cusolverDnDgesvd_bufferSize(get_handle(), M, N, &bufferSize);
67 *WORK = bufferSize;
68 } else {
69 CUSOLVER_CHECK(cusolverDnDgesvd, INFO, JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK);
70 }
71 }
72 void gesvd(char JOBU, char JOBVT, int M, int N, dcomplex *A, int LDA, double *S, dcomplex *U, int LDU, dcomplex *VT, int LDVT, dcomplex *WORK,
73 int LWORK, double *RWORK, int &INFO) {
74 // Replicate behavior of Netlib gesvd
75 if (LWORK == -1) {
76 int bufferSize = 0;
77 cusolverDnZgesvd_bufferSize(get_handle(), M, N, &bufferSize);
78 *WORK = bufferSize;
79 } else {
80 CUSOLVER_CHECK(cusolverDnZgesvd, INFO, JOBU, JOBVT, M, N, cucplx(A), LDA, S, cucplx(U), LDU, cucplx(VT), LDVT, cucplx(WORK), LWORK,
81 RWORK); // NOLINT
82 }
83 }
84
85 void getrf(int M, int N, double *A, int LDA, int *ipiv, int &info) {
86 int bufferSize = 0;
87 cusolverDnDgetrf_bufferSize(get_handle(), M, N, A, LDA, &bufferSize);
88 auto Workspace = nda::cuvector<double>(bufferSize);
89 CUSOLVER_CHECK(cusolverDnDgetrf, info, M, N, A, LDA, Workspace.data(), ipiv);
90 }
91 void getrf(int M, int N, dcomplex *A, int LDA, int *ipiv, int &info) {
92 int bufferSize = 0;
93 cusolverDnZgetrf_bufferSize(get_handle(), M, N, cucplx(A), LDA, &bufferSize);
94 auto Workspace = nda::cuvector<dcomplex>(bufferSize);
95 CUSOLVER_CHECK(cusolverDnZgetrf, info, M, N, cucplx(A), LDA, cucplx(Workspace.data()), ipiv);
96 }
97
98 void getrs(char op, int N, int NRHS, double const *A, int LDA, int const *ipiv, double *B, int LDB, int &info) {
99 CUSOLVER_CHECK(cusolverDnDgetrs, info, get_cublas_op(op), N, NRHS, A, LDA, ipiv, B, LDB);
100 }
101 void getrs(char op, int N, int NRHS, dcomplex const *A, int LDA, int const *ipiv, dcomplex *B, int LDB, int &info) {
102 CUSOLVER_CHECK(cusolverDnZgetrs, info, get_cublas_op(op), N, NRHS, cucplx(A), LDA, ipiv, cucplx(B), LDB);
103 }
104
105} // namespace nda::lapack::device
Provides custom allocators for the nda library.
Provides the generic class for arrays.
Provides a C++ interface for the GPU versions of various LAPACK routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_vi...
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
basic_array< ValueType, 1, C_layout, 'V', heap< mem::Device > > cuvector
Similar to nda::vector except the memory is stored on the device.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:28
Provides various handles to take care of memory management for nda::basic_array and nda::basic_array_...
Macros used in the nda library.
Provides various traits and utilities for the BLAS interface.