TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.hpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../tools.hpp"
14
15#ifndef NDA_HAVE_CUDA
16#error "CUDA support is not enabled in this build of nda. Please configure and install nda with -DCUDASupport=ON"
17#endif
18
19#ifndef NDA_HAVE_MAGMA
20#include "../../exceptions.hpp"
21#endif // NDA_HAVE_MAGMA
22
23namespace nda::blas::device {
24
// axpy: prototypes for real (double) and complex (dcomplex) data. Only the declarations are
// visible in this header; the definitions live elsewhere.
// NOTE(review): presumably the BLAS axpy operation (y <- alpha * x + y) over n elements with
// strides incx / incy — confirm against the implementation.
25 void axpy(int n, double alpha, const double *x, int incx, double *y, int incy);
26 void axpy(int n, dcomplex alpha, const dcomplex *x, int incx, dcomplex *y, int incy);
27
// copy: prototypes for real (double) and complex (dcomplex) data. x is read-only (const) and
// y is writable; the definitions are not visible in this header.
// NOTE(review): presumably the BLAS copy operation (y <- x) over n elements with strides
// incx / incy — confirm against the implementation.
28 void copy(int n, const double *x, int incx, double *y, int incy);
29 void copy(int n, const dcomplex *x, int incx, dcomplex *y, int incy);
30
// dot / dotc: prototypes returning a scalar of the same type as the input vectors. Both inputs
// are read-only (const). Note that the element count is named m here, whereas axpy, copy and
// swap name it n.
// NOTE(review): presumably dot is the unconjugated product sum(x_i * y_i) and dotc conjugates
// its first argument, as in BLAS dotu / dotc — confirm against the implementation.
31 double dot(int m, const double *x, int incx, const double *y, int incy);
32 dcomplex dot(int m, const dcomplex *x, int incx, const dcomplex *y, int incy);
33 dcomplex dotc(int m, const dcomplex *x, int incx, const dcomplex *y, int incy);
34
// gemm: prototypes for real (double) and complex (dcomplex) data. a and b are read-only (const),
// c is writable; the definitions are not visible in this header.
// NOTE(review): presumably the BLAS gemm operation C <- alpha * op(A) * op(B) + beta * C, with
// op_a / op_b being the usual BLAS operation characters and lda / ldb / ldc the leading
// dimensions — confirm against the implementation.
35 void gemm(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, const double *b, int ldb, double beta, double *c,
36 int ldc);
37 void gemm(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *b, int ldb, dcomplex beta,
38 dcomplex *c, int ldc);
39
// gemm_batch: prototypes for real (double) and complex (dcomplex) data. a, b and c are arrays of
// pointers, while the dimensions (m, n, k) and leading dimensions (lda, ldb, ldc) are single
// values, so they are shared by the whole batch. batch_count is passed separately.
// NOTE(review): presumably performs batch_count independent gemm operations, one per pointer
// triple a[i], b[i], c[i] — confirm against the implementation.
40 void gemm_batch(char op_a, char op_b, int m, int n, int k, double alpha, const double **a, int lda, const double **b, int ldb, double beta,
41 double **c, int ldc, int batch_count);
42 void gemm_batch(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex **a, int lda, const dcomplex **b, int ldb, dcomplex beta,
43 dcomplex **c, int ldc, int batch_count);
44
45#ifdef NDA_HAVE_MAGMA
// gemm_vbatch: prototypes that are only declared when NDA_HAVE_MAGMA is defined. Unlike
// gemm_batch, the dimensions (m, n, k) and leading dimensions (lda, ldb, ldc) are passed as
// int pointers rather than single values.
// NOTE(review): presumably each of those pointers refers to an array with one entry per batch
// member (variable-size batched gemm provided by Magma) — confirm the required array length and
// memory location against the implementation.
46 void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, double alpha, const double **a, int *lda, const double **b, int *ldb, double beta,
47 double **c, int *ldc, int batch_count);
48 void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, dcomplex alpha, const dcomplex **a, int *lda, const dcomplex **b, int *ldb,
49 dcomplex beta, dcomplex **c, int *ldc, int batch_count);
50#else
// Fallback overloads of gemm_vbatch, compiled only when NDA_HAVE_MAGMA is not defined.
// Every argument is ignored (the parameters are unnamed); each overload only reports, through
// NDA_RUNTIME_ERROR, that Magma is required and which configure option enables it.
// The message spells the function name as gemm_vbatch so that it matches the identifier the
// caller actually used and can be searched for.
51 inline void gemm_vbatch(char, char, int *, int *, int *, double, const double **, int *, const double **, int *, double, double **, int *, int) {
52 NDA_RUNTIME_ERROR << "nda::blas::device::gemm_vbatch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
53 }
54 inline void gemm_vbatch(char, char, int *, int *, int *, dcomplex, const dcomplex **, int *, const dcomplex **, int *, dcomplex, dcomplex **, int *,
55 int) {
56 NDA_RUNTIME_ERROR << "nda::blas::device::gemm_vbatch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
57 }
58#endif
59
// gemm_batch_strided: prototypes for real (double) and complex (dcomplex) data. In contrast to
// gemm_batch, a, b and c are single base pointers, each accompanied by an integer stride
// (stride_a, stride_b, stride_c).
// NOTE(review): presumably batch member i uses the matrices at a + i * stride_a,
// b + i * stride_b and c + i * stride_c, with strides counted in elements — confirm against
// the implementation.
60 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, int stride_a, const double *b, int ldb,
61 int stride_b, double beta, double *c, int ldc, int stride_c, int batch_count);
62 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, int stride_a, const dcomplex *b,
63 int ldb, int stride_b, dcomplex beta, dcomplex *c, int ldc, int stride_c, int batch_count);
64
// gemv: prototypes for real (double) and complex (dcomplex) data. a and x are read-only (const),
// y is writable; the definitions are not visible in this header.
// NOTE(review): presumably the BLAS gemv operation y <- alpha * op(A) * x + beta * y, with op
// being the usual BLAS operation character — confirm against the implementation.
65 void gemv(char op, int m, int n, double alpha, const double *a, int lda, const double *x, int incx, double beta, double *y, int incy);
66 void gemv(char op, int m, int n, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *x, int incx, dcomplex beta, dcomplex *y, int incy);
67
// ger / gerc: prototypes taking two read-only (const) vectors x and y and a writable matrix a
// with leading dimension lda. gerc is declared for complex data only.
// NOTE(review): presumably the BLAS rank-1 update A <- alpha * x * y^T + A, with gerc using the
// conjugate of y (y^H) — confirm against the implementation.
68 void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda);
69 void ger(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda);
70 void gerc(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda);
71
// scal: prototypes for real (double) and complex (dcomplex) data. x is the only vector argument
// and it is writable. The element count is named m here, whereas axpy, copy and swap name it n.
// NOTE(review): presumably the BLAS scal operation (x <- alpha * x) with stride incx — confirm
// against the implementation.
72 void scal(int m, double alpha, double *x, int incx);
73 void scal(int m, dcomplex alpha, dcomplex *x, int incx);
74
// swap: prototypes for real (double) and complex (dcomplex) data. Both x and y are writable.
// The trailing NOLINT markers are kept on the declarations themselves; their note states that
// this is the BLAS swap.
// NOTE(review): presumably exchanges the n elements of x and y using strides incx / incy —
// confirm against the implementation.
75 void swap(int n, double *x, int incx, double *y, int incy); // NOLINT (this is a BLAS swap)
76 void swap(int n, dcomplex *x, int incx, dcomplex *y, int incy); // NOLINT (this is a BLAS swap)
77
78} // namespace nda::blas::device
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:28
Provides various traits and utilities for the BLAS interface.