TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.hpp
Go to the documentation of this file.
// Copyright (c) 2022--present, The Simons Foundation
// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
// SPDX-License-Identifier: Apache-2.0
// See LICENSE in the root of this distribution for details.

#pragma once

#include "../tools.hpp"

#include <complex>
14
// Device-side BLAS / batched LAPACK wrappers.
//
// This header only declares the routines; their definitions live in a separate translation unit that is not visible
// here. Judging by the file name (cublas_interface.hpp) they presumably forward to cuBLAS (and possibly MAGMA for the
// variable-size / batched LAPACK routines) -- NOTE(review): confirm against the implementation. Pointer arguments are
// presumably expected to point to device memory; nothing in this header checks that.
//
// Conventions shared by all overloads below (standard BLAS conventions assumed):
//   - Each routine is overloaded for float, double, std::complex<float> and std::complex<double> unless noted.
//   - inc* arguments are the element stride between consecutive vector entries.
//   - ld* arguments are leading dimensions of (BLAS-convention, column-major) matrices.
//   - op / op_a / op_b select the operation applied to a matrix operand, presumably 'N' (none), 'T' (transpose)
//     or 'C' (conjugate transpose).
//   - All sizes, strides and counts are plain `int`, so problem sizes are limited to the 32-bit range.
namespace nda::blas::device {

  /**
   * @brief Set the global synchronization flag of the device wrappers.
   *
   * NOTE(review): presumably controls whether each wrapper waits for the device to finish before returning; the
   * exact effect is defined in the (not visible) implementation.
   *
   * @param do_sync New value of the flag.
   */
  void set_synchronization(bool do_sync) noexcept;

  /**
   * @brief Get the current value of the synchronization flag set by set_synchronization().
   * @return The current flag value.
   */
  bool get_synchronization() noexcept;

  // ---------------------------------------------------------------------------------------------------------------
  // BLAS level 1
  // ---------------------------------------------------------------------------------------------------------------

  /**
   * @brief BLAS axpy: y <- alpha * x + y for vectors of length n.
   * @param n Number of elements.
   * @param alpha Scalar factor.
   * @param x Input vector with stride incx.
   * @param incx Stride of x.
   * @param y In/out vector with stride incy.
   * @param incy Stride of y.
   */
  void axpy(int n, float alpha, const float *x, int incx, float *y, int incy);
  void axpy(int n, std::complex<float> alpha, const std::complex<float> *x, int incx, std::complex<float> *y, int incy);
  void axpy(int n, double alpha, const double *x, int incx, double *y, int incy);
  void axpy(int n, std::complex<double> alpha, const std::complex<double> *x, int incx, std::complex<double> *y, int incy);

  /**
   * @brief BLAS copy: y <- x for vectors of length n.
   * @param n Number of elements.
   * @param x Source vector with stride incx.
   * @param incx Stride of x.
   * @param y Destination vector with stride incy.
   * @param incy Stride of y.
   */
  void copy(int n, const float *x, int incx, float *y, int incy);
  void copy(int n, const std::complex<float> *x, int incx, std::complex<float> *y, int incy);
  void copy(int n, const double *x, int incx, double *y, int incy);
  void copy(int n, const std::complex<double> *x, int incx, std::complex<double> *y, int incy);

  /**
   * @brief BLAS dot products of two vectors of length m.
   *
   * - dot:  sum_i x_i * y_i (no conjugation, i.e. BLAS dotu for complex types).
   * - dotc: sum_i conj(x_i) * y_i (complex types only).
   *
   * @param m Number of elements.
   * @param x First vector with stride incx.
   * @param incx Stride of x.
   * @param y Second vector with stride incy.
   * @param incy Stride of y.
   * @return The scalar result.
   */
  float dot(int m, const float *x, int incx, const float *y, int incy);
  std::complex<float> dot(int m, const std::complex<float> *x, int incx, const std::complex<float> *y, int incy);
  std::complex<float> dotc(int m, const std::complex<float> *x, int incx, const std::complex<float> *y, int incy);
  double dot(int m, const double *x, int incx, const double *y, int incy);
  std::complex<double> dot(int m, const std::complex<double> *x, int incx, const std::complex<double> *y, int incy);
  std::complex<double> dotc(int m, const std::complex<double> *x, int incx, const std::complex<double> *y, int incy);

  /**
   * @brief BLAS scal: x <- alpha * x for a vector of length m.
   * @param m Number of elements.
   * @param alpha Scalar factor.
   * @param x In/out vector with stride incx.
   * @param incx Stride of x.
   */
  void scal(int m, float alpha, float *x, int incx);
  void scal(int m, std::complex<float> alpha, std::complex<float> *x, int incx);
  void scal(int m, double alpha, double *x, int incx);
  void scal(int m, std::complex<double> alpha, std::complex<double> *x, int incx);

  /**
   * @brief BLAS swap: exchange the contents of two vectors of length n.
   * @param n Number of elements.
   * @param x First vector with stride incx.
   * @param incx Stride of x.
   * @param y Second vector with stride incy.
   * @param incy Stride of y.
   */
  void swap(int n, float *x, int incx, float *y, int incy);                                               // NOLINT (this is a BLAS swap)
  void swap(int n, std::complex<float> *x, int incx, std::complex<float> *y, int incy);                   // NOLINT (this is a BLAS swap)
  void swap(int n, double *x, int incx, double *y, int incy);                                             // NOLINT (this is a BLAS swap)
  void swap(int n, std::complex<double> *x, int incx, std::complex<double> *y, int incy);                 // NOLINT (this is a BLAS swap)

  // ---------------------------------------------------------------------------------------------------------------
  // BLAS level 2
  // ---------------------------------------------------------------------------------------------------------------

  /**
   * @brief BLAS gemv: y <- alpha * op(A) * x + beta * y, where A is an m x n matrix.
   * @param op Operation applied to A (see the conventions at the top of this namespace).
   * @param m Number of rows of A.
   * @param n Number of columns of A.
   * @param alpha Scalar factor for op(A) * x.
   * @param a Matrix A with leading dimension lda.
   * @param lda Leading dimension of A.
   * @param x Input vector with stride incx.
   * @param incx Stride of x.
   * @param beta Scalar factor for y.
   * @param y In/out vector with stride incy.
   * @param incy Stride of y.
   */
  void gemv(char op, int m, int n, float alpha, const float *a, int lda, const float *x, int incx, float beta, float *y, int incy);
  void gemv(char op, int m, int n, std::complex<float> alpha, const std::complex<float> *a, int lda, const std::complex<float> *x, int incx,
            std::complex<float> beta, std::complex<float> *y, int incy);
  void gemv(char op, int m, int n, double alpha, const double *a, int lda, const double *x, int incx, double beta, double *y, int incy);
  void gemv(char op, int m, int n, std::complex<double> alpha, const std::complex<double> *a, int lda, const std::complex<double> *x, int incx,
            std::complex<double> beta, std::complex<double> *y, int incy);

  /**
   * @brief BLAS rank-1 updates of an m x n matrix A.
   *
   * - ger:  A <- alpha * x * y^T + A (no conjugation, i.e. BLAS geru for complex types).
   * - gerc: A <- alpha * x * y^H + A (complex types only).
   *
   * @param m Number of rows of A (length of x).
   * @param n Number of columns of A (length of y).
   * @param alpha Scalar factor.
   * @param x Vector with stride incx.
   * @param incx Stride of x.
   * @param y Vector with stride incy.
   * @param incy Stride of y.
   * @param a In/out matrix A with leading dimension lda.
   * @param lda Leading dimension of A.
   */
  void ger(int m, int n, float alpha, const float *x, int incx, const float *y, int incy, float *a, int lda);
  void ger(int m, int n, std::complex<float> alpha, const std::complex<float> *x, int incx, const std::complex<float> *y, int incy,
           std::complex<float> *a, int lda);
  void gerc(int m, int n, std::complex<float> alpha, const std::complex<float> *x, int incx, const std::complex<float> *y, int incy,
            std::complex<float> *a, int lda);
  void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda);
  void ger(int m, int n, std::complex<double> alpha, const std::complex<double> *x, int incx, const std::complex<double> *y, int incy,
           std::complex<double> *a, int lda);
  void gerc(int m, int n, std::complex<double> alpha, const std::complex<double> *x, int incx, const std::complex<double> *y, int incy,
            std::complex<double> *a, int lda);

  // ---------------------------------------------------------------------------------------------------------------
  // BLAS level 3 (single and batched)
  // ---------------------------------------------------------------------------------------------------------------

  /**
   * @brief BLAS gemm: C <- alpha * op_a(A) * op_b(B) + beta * C.
   *
   * op_a(A) is m x k, op_b(B) is k x n and C is m x n.
   *
   * @param op_a Operation applied to A.
   * @param op_b Operation applied to B.
   * @param m Number of rows of op_a(A) and C.
   * @param n Number of columns of op_b(B) and C.
   * @param k Inner dimension.
   * @param alpha Scalar factor for the product.
   * @param a Matrix A with leading dimension lda.
   * @param lda Leading dimension of A.
   * @param b Matrix B with leading dimension ldb.
   * @param ldb Leading dimension of B.
   * @param beta Scalar factor for C.
   * @param c In/out matrix C with leading dimension ldc.
   * @param ldc Leading dimension of C.
   */
  void gemm(char op_a, char op_b, int m, int n, int k, float alpha, const float *a, int lda, const float *b, int ldb, float beta, float *c, int ldc);
  void gemm(char op_a, char op_b, int m, int n, int k, std::complex<float> alpha, const std::complex<float> *a, int lda, const std::complex<float> *b,
            int ldb, std::complex<float> beta, std::complex<float> *c, int ldc);
  void gemm(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, const double *b, int ldb, double beta, double *c,
            int ldc);
  void gemm(char op_a, char op_b, int m, int n, int k, std::complex<double> alpha, const std::complex<double> *a, int lda,
            const std::complex<double> *b, int ldb, std::complex<double> beta, std::complex<double> *c, int ldc);

  /**
   * @brief Batched gemm with uniform sizes: C_i <- alpha * op_a(A_i) * op_b(B_i) + beta * C_i for i < batch_count.
   *
   * a, b and c are arrays of batch_count pointers, one per matrix. All problems share m, n, k, the leading
   * dimensions and the scalars.
   *
   * @param batch_count Number of independent gemm problems.
   */
  void gemm_batch(char op_a, char op_b, int m, int n, int k, float alpha, const float **a, int lda, const float **b, int ldb, float beta, float **c,
                  int ldc, int batch_count);
  void gemm_batch(char op_a, char op_b, int m, int n, int k, std::complex<float> alpha, const std::complex<float> **a, int lda,
                  const std::complex<float> **b, int ldb, std::complex<float> beta, std::complex<float> **c, int ldc, int batch_count);
  void gemm_batch(char op_a, char op_b, int m, int n, int k, double alpha, const double **a, int lda, const double **b, int ldb, double beta,
                  double **c, int ldc, int batch_count);
  void gemm_batch(char op_a, char op_b, int m, int n, int k, std::complex<double> alpha, const std::complex<double> **a, int lda,
                  const std::complex<double> **b, int ldb, std::complex<double> beta, std::complex<double> **c, int ldc, int batch_count);

  /**
   * @brief Batched gemm with variable sizes: like gemm_batch, but m, n, k, lda, ldb and ldc are arrays.
   *
   * Entry i of each array describes problem i. The scalars alpha and beta are shared by all problems.
   *
   * NOTE(review): some backends (e.g. MAGMA's vbatched routines) require the size arrays to hold batch_count + 1
   * entries and to live in device memory -- confirm the requirement against the implementation.
   *
   * @param batch_count Number of independent gemm problems.
   */
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, float alpha, const float **a, int *lda, const float **b, int *ldb, float beta,
                   float **c, int *ldc, int batch_count);
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, std::complex<float> alpha, const std::complex<float> **a, int *lda,
                   const std::complex<float> **b, int *ldb, std::complex<float> beta, std::complex<float> **c, int *ldc, int batch_count);
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, double alpha, const double **a, int *lda, const double **b, int *ldb, double beta,
                   double **c, int *ldc, int batch_count);
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, std::complex<double> alpha, const std::complex<double> **a, int *lda,
                   const std::complex<double> **b, int *ldb, std::complex<double> beta, std::complex<double> **c, int *ldc, int batch_count);

  /**
   * @brief Strided batched gemm with uniform sizes.
   *
   * Matrix i of the batch starts at a + i * stride_a, b + i * stride_b and c + i * stride_c respectively. The
   * strides count elements, not bytes.
   *
   * @param stride_a Element offset between consecutive A matrices.
   * @param stride_b Element offset between consecutive B matrices.
   * @param stride_c Element offset between consecutive C matrices.
   * @param batch_count Number of independent gemm problems.
   */
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, float alpha, const float *a, int lda, int stride_a, const float *b, int ldb,
                          int stride_b, float beta, float *c, int ldc, int stride_c, int batch_count);
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, std::complex<float> alpha, const std::complex<float> *a, int lda, int stride_a,
                          const std::complex<float> *b, int ldb, int stride_b, std::complex<float> beta, std::complex<float> *c, int ldc,
                          int stride_c, int batch_count);
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, int stride_a, const double *b, int ldb,
                          int stride_b, double beta, double *c, int ldc, int stride_c, int batch_count);
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, std::complex<double> alpha, const std::complex<double> *a, int lda, int stride_a,
                          const std::complex<double> *b, int ldb, int stride_b, std::complex<double> beta, std::complex<double> *c, int ldc,
                          int stride_c, int batch_count);

  // ---------------------------------------------------------------------------------------------------------------
  // Batched LAPACK-style routines
  // ---------------------------------------------------------------------------------------------------------------

  /**
   * @brief Batched LU factorization (getrf) of batch_size n x n matrices, in place.
   *
   * ipiv_array is a single contiguous buffer shared by the whole batch (presumably n pivots per matrix -- confirm).
   * info_array receives one status value per matrix.
   *
   * @param n Order of each matrix.
   * @param a_array Array of batch_size pointers to the matrices (overwritten by their LU factors).
   * @param lda Leading dimension of each matrix.
   * @param ipiv_array Output pivot indices for the whole batch.
   * @param info_array Output per-matrix status.
   * @param batch_size Number of matrices.
   */
  void getrf_batch(int n, float **a_array, int lda, int *ipiv_array, int *info_array, int batch_size);
  void getrf_batch(int n, std::complex<float> **a_array, int lda, int *ipiv_array, int *info_array, int batch_size);
  void getrf_batch(int n, double **a_array, int lda, int *ipiv_array, int *info_array, int batch_size);
  void getrf_batch(int n, std::complex<double> **a_array, int lda, int *ipiv_array, int *info_array, int batch_size);

  /**
   * @brief Batched matrix inversion (getri) from LU factors computed by getrf_batch.
   *
   * The inverses are written out of place into c_array; a_array and ipiv_array are the factors and pivots
   * produced by getrf_batch.
   *
   * @param n Order of each matrix.
   * @param a_array Array of batch_size pointers to the LU factors.
   * @param lda Leading dimension of each factor.
   * @param ipiv_array Pivot indices for the whole batch.
   * @param c_array Array of batch_size pointers receiving the inverses.
   * @param ldc Leading dimension of each inverse.
   * @param info_array Output per-matrix status.
   * @param batch_size Number of matrices.
   */
  void getri_batch(int n, float **a_array, int lda, int const *ipiv_array, float **c_array, int ldc, int *info_array, int batch_size);
  void getri_batch(int n, std::complex<float> **a_array, int lda, int const *ipiv_array, std::complex<float> **c_array, int ldc, int *info_array,
                   int batch_size);
  void getri_batch(int n, double **a_array, int lda, int const *ipiv_array, double **c_array, int ldc, int *info_array, int batch_size);
  void getri_batch(int n, std::complex<double> **a_array, int lda, int const *ipiv_array, std::complex<double> **c_array, int ldc, int *info_array,
                   int batch_size);

  /**
   * @brief Batched linear solve (getrs): solve op(A_i) * X_i = B_i using LU factors from getrf_batch.
   *
   * Each B_i (n x nrhs) is overwritten by the solution X_i. Unlike getrf_batch/getri_batch, a single status value
   * is reported for the whole call through @p info.
   *
   * @param op Operation applied to each A_i.
   * @param n Order of each A_i.
   * @param nrhs Number of right-hand sides per problem.
   * @param a_array Array of batch_size pointers to the LU factors.
   * @param lda Leading dimension of each factor.
   * @param ipiv_array Pivot indices for the whole batch.
   * @param b_array Array of batch_size pointers to the right-hand sides / solutions.
   * @param ldb Leading dimension of each B_i.
   * @param info Output status of the call.
   * @param batch_size Number of problems.
   */
  void getrs_batch(char op, int n, int nrhs, const float **a_array, int lda, int const *ipiv_array, float **b_array, int ldb, int &info,
                   int batch_size);
  void getrs_batch(char op, int n, int nrhs, const std::complex<float> **a_array, int lda, int const *ipiv_array, std::complex<float> **b_array,
                   int ldb, int &info, int batch_size);
  void getrs_batch(char op, int n, int nrhs, const double **a_array, int lda, int const *ipiv_array, double **b_array, int ldb, int &info,
                   int batch_size);
  void getrs_batch(char op, int n, int nrhs, const std::complex<double> **a_array, int lda, int const *ipiv_array, std::complex<double> **b_array,
                   int ldb, int &info, int batch_size);

  /**
   * @brief Batched QR factorization (geqrf), in place.
   *
   * Each matrix is overwritten by its R factor and Householder reflectors, and the reflector scalars go to the
   * corresponding tau_array entry. A single status value is reported through @p info.
   *
   * NOTE(review): the parameters are declared as (n, m), whereas LAPACK's geqrf takes (m, n) = (rows, cols).
   * Check the implementation before relying on which argument is the row count.
   *
   * @param a_array Array of batch_size pointers to the matrices.
   * @param lda Leading dimension of each matrix.
   * @param tau_array Array of batch_size pointers receiving the reflector scalars.
   * @param info Output status of the call.
   * @param batch_size Number of matrices.
   */
  void geqrf_batch(int n, int m, float **a_array, int lda, float **tau_array, int &info, int batch_size);
  void geqrf_batch(int n, int m, std::complex<float> **a_array, int lda, std::complex<float> **tau_array, int &info, int batch_size);
  void geqrf_batch(int n, int m, double **a_array, int lda, double **tau_array, int &info, int batch_size);
  void geqrf_batch(int n, int m, std::complex<double> **a_array, int lda, std::complex<double> **tau_array, int &info, int batch_size);

} // namespace nda::blas::device
Provides various traits and utilities for the BLAS interface.