TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cusolver_interface.hpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
#include "../../blas/tools.hpp"

#include <complex>
#include <limits>
#include <stdexcept>
14
namespace nda::lapack::device {

  // NOTE(review): Apart from getri_buffer_size(), this header only declares functions; their
  // definitions live in a separate source file that is not visible here. Judging by the file name
  // (cusolver_interface.hpp) and the namespace, they presumably wrap cuSOLVER routines that operate
  // on device memory, so pointer arguments are presumably device pointers -- confirm against the
  // implementation. Names and parameter lists are modelled on the like-named LAPACK routines.

  /**
   * @brief Set a global synchronization flag used by the device LAPACK wrappers.
   *
   * @details NOTE(review): presumably controls whether each call synchronizes with the device before
   * returning -- confirm in the implementation.
   *
   * @param do_sync New value of the flag.
   */
  void set_synchronization(bool do_sync) noexcept;

  /**
   * @brief Query the synchronization flag (presumably the value last passed to set_synchronization()).
   * @return Current value of the flag.
   */
  [[nodiscard]] bool get_synchronization() noexcept;

  /**
   * @brief Query the workspace size needed by gesvd() for an m x n matrix.
   *
   * @details The trailing pointer parameter is unnamed; it selects the overload by element type.
   * NOTE(review): whether the size is counted in elements or bytes is not visible here -- confirm.
   *
   * @param m Number of rows.
   * @param n Number of columns.
   * @return Required workspace size (the `lwork` argument of gesvd()).
   */
  [[nodiscard]] int gesvd_buffer_size(int m, int n, float *);
  [[nodiscard]] int gesvd_buffer_size(int m, int n, std::complex<float> *);
  [[nodiscard]] int gesvd_buffer_size(int m, int n, double *);
  [[nodiscard]] int gesvd_buffer_size(int m, int n, std::complex<double> *);

  /**
   * @brief Singular value decomposition of an m x n matrix (LAPACK-style xGESVD).
   *
   * @details The singular values `s` are always real (float/double), even for complex `a`.
   * `work`/`lwork` is presumably sized with gesvd_buffer_size(). The result status is written to `info`.
   */
  void gesvd(char jobu, char jobvt, int m, int n, float *a, int lda, float *s, float *u, int ldu, float *vt, int ldvt, float *work, int lwork,
             float *rwork, int &info);
  void gesvd(char jobu, char jobvt, int m, int n, std::complex<float> *a, int lda, float *s, std::complex<float> *u, int ldu, std::complex<float> *vt,
             int ldvt, std::complex<float> *work, int lwork, float *rwork, int &info);
  void gesvd(char jobu, char jobvt, int m, int n, double *a, int lda, double *s, double *u, int ldu, double *vt, int ldvt, double *work, int lwork,
             double *rwork, int &info);
  void gesvd(char jobu, char jobvt, int m, int n, std::complex<double> *a, int lda, double *s, std::complex<double> *u, int ldu,
             std::complex<double> *vt, int ldvt, std::complex<double> *work, int lwork, double *rwork, int &info);

  /**
   * @brief Query the workspace size needed by getrf() for an m x n matrix `a` with leading dimension `lda`.
   * @return Required workspace size.
   */
  [[nodiscard]] int getrf_buffer_size(int m, int n, float *a, int lda);
  [[nodiscard]] int getrf_buffer_size(int m, int n, std::complex<float> *a, int lda);
  [[nodiscard]] int getrf_buffer_size(int m, int n, double *a, int lda);
  [[nodiscard]] int getrf_buffer_size(int m, int n, std::complex<double> *a, int lda);

  /**
   * @brief LU factorization of an m x n matrix (LAPACK-style xGETRF).
   *
   * @details There is no `lwork` argument. `work` is presumably expected to hold getrf_buffer_size()
   * elements. Pivot indices go to `ipiv` and the status goes to `info`.
   */
  void getrf(int m, int n, float *a, int lda, float *work, int *ipiv, int &info);
  void getrf(int m, int n, std::complex<float> *a, int lda, std::complex<float> *work, int *ipiv, int &info);
  void getrf(int m, int n, double *a, int lda, double *work, int *ipiv, int &info);
  void getrf(int m, int n, std::complex<double> *a, int lda, std::complex<double> *work, int *ipiv, int &info);

  /**
   * @brief Solve a linear system with `nrhs` right-hand sides `b`, using a factorization `a`/`ipiv`
   * (LAPACK-style xGETRS; the factors presumably come from getrf()).
   *
   * @details `a` and `ipiv` are read-only. `b` is overwritten, presumably with the solution.
   * `op` presumably selects the transposition of `a`.
   */
  void getrs(char op, int n, int nrhs, float const *a, int lda, int const *ipiv, float *b, int ldb, int &info);
  void getrs(char op, int n, int nrhs, std::complex<float> const *a, int lda, int const *ipiv, std::complex<float> *b, int ldb, int &info);
  void getrs(char op, int n, int nrhs, double const *a, int lda, int const *ipiv, double *b, int ldb, int &info);
  void getrs(char op, int n, int nrhs, std::complex<double> const *a, int lda, int const *ipiv, std::complex<double> *b, int ldb, int &info);

  /**
   * @brief Workspace size for getri(): always `n * n`.
   *
   * @details The matrix pointer and leading dimension are unused. They only keep the signature
   * parallel to the other `*_buffer_size` queries. The product is computed in `long long`, because
   * `n * n` in `int` overflows (undefined behavior) for |n| > 46340.
   *
   * @tparam T Element type (only used for overload/signature uniformity).
   * @param n Order of the square matrix.
   * @return `n * n`.
   * @throws std::overflow_error if `n * n` does not fit into an `int`.
   */
  template <typename T>
  [[nodiscard]] int getri_buffer_size(int n, T * /*a*/, int /*lda*/) {
    auto const size = static_cast<long long>(n) * static_cast<long long>(n);
    if (size > static_cast<long long>(std::numeric_limits<int>::max())) {
      throw std::overflow_error("nda::lapack::device::getri_buffer_size: workspace size n * n does not fit into an int");
    }
    return static_cast<int>(size);
  }

  /**
   * @brief Matrix inverse from a factorization `a`/`ipiv` (LAPACK-style xGETRI; the factors
   * presumably come from getrf()).
   *
   * @details `work`/`lwork` is presumably sized with getri_buffer_size().
   */
  void getri(int n, float *a, int lda, int const *ipiv, float *work, int lwork, int &info);
  void getri(int n, std::complex<float> *a, int lda, int const *ipiv, std::complex<float> *work, int lwork, int &info);
  void getri(int n, double *a, int lda, int const *ipiv, double *work, int lwork, int &info);
  void getri(int n, std::complex<double> *a, int lda, int const *ipiv, std::complex<double> *work, int lwork, int &info);

  /**
   * @brief Query the workspace size needed by geqrf() for an m x n matrix `a` with leading dimension `lda`.
   * @return Required workspace size.
   */
  [[nodiscard]] int geqrf_buffer_size(int m, int n, float *a, int lda);
  [[nodiscard]] int geqrf_buffer_size(int m, int n, std::complex<float> *a, int lda);
  [[nodiscard]] int geqrf_buffer_size(int m, int n, double *a, int lda);
  [[nodiscard]] int geqrf_buffer_size(int m, int n, std::complex<double> *a, int lda);

  /**
   * @brief QR factorization of an m x n matrix (LAPACK-style xGEQRF).
   *
   * @details `tau` receives scalar factors of the same element type as `a`, presumably the
   * Householder scalars consumed by orgqr()/ungqr().
   */
  void geqrf(int m, int n, float *a, int lda, float *tau, float *work, int lwork, int &info);
  void geqrf(int m, int n, std::complex<float> *a, int lda, std::complex<float> *tau, std::complex<float> *work, int lwork, int &info);
  void geqrf(int m, int n, double *a, int lda, double *tau, double *work, int lwork, int &info);
  void geqrf(int m, int n, std::complex<double> *a, int lda, std::complex<double> *tau, std::complex<double> *work, int lwork, int &info);

  /**
   * @brief Query the workspace size needed by orgqr() (real element types only).
   * @return Required workspace size.
   */
  [[nodiscard]] int orgqr_buffer_size(int m, int n, int k, float const *a, int lda, float const *tau);
  [[nodiscard]] int orgqr_buffer_size(int m, int n, int k, double const *a, int lda, double const *tau);

  /**
   * @brief Form the orthogonal factor from geqrf() output, in place in `a` (LAPACK-style xORGQR,
   * real element types only; see ungqr() for complex types).
   */
  void orgqr(int m, int n, int k, float *a, int lda, float const *tau, float *work, int lwork, int &info);
  void orgqr(int m, int n, int k, double *a, int lda, double const *tau, double *work, int lwork, int &info);

  /**
   * @brief Query the workspace size needed by ungqr() (complex element types only).
   * @return Required workspace size.
   */
  [[nodiscard]] int ungqr_buffer_size(int m, int n, int k, std::complex<float> const *a, int lda, std::complex<float> const *tau);
  [[nodiscard]] int ungqr_buffer_size(int m, int n, int k, std::complex<double> const *a, int lda, std::complex<double> const *tau);

  /**
   * @brief Form the unitary factor from geqrf() output, in place in `a` (LAPACK-style xUNGQR,
   * complex element types only; see orgqr() for real types).
   */
  void ungqr(int m, int n, int k, std::complex<float> *a, int lda, std::complex<float> const *tau, std::complex<float> *work, int lwork, int &info);
  void ungqr(int m, int n, int k, std::complex<double> *a, int lda, std::complex<double> const *tau, std::complex<double> *work, int lwork,
             int &info);

} // namespace nda::lapack::device
Provides various traits and utilities for the BLAS interface.