TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.hpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../tools.hpp"
14
15#ifndef NDA_HAVE_CUDA
16#error "CUDA support is not enabled in this build of nda. Please configure and install nda with -DCUDASupport=ON"
17#endif
18
19#ifndef NDA_HAVE_MAGMA
20#include "../../exceptions.hpp"
21#endif // NDA_HAVE_MAGMA
22
23namespace nda::blas::device {
24
25 void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy);
26 void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy);
27
28 void copy(int N, const double *x, int incx, double *Y, int incy);
29 void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy);
30
31 double dot(int M, const double *x, int incx, const double *Y, int incy);
32 dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy);
33 dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy);
34
35 void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
36 int LDC);
37 void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
38 dcomplex *C, int LDC);
39
40 void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
41 double **C, int LDC, int batch_count);
42 void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
43 dcomplex **C, int LDC, int batch_count);
44
45#ifdef NDA_HAVE_MAGMA
46 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
47 double **C, int *LDC, int batch_count);
48 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
49 dcomplex beta, dcomplex **C, int *LDC, int batch_count);
50#else
51 inline void gemm_vbatch(char, char, int *, int *, int *, double, const double **, int *, const double **, int *, double, double **, int *, int) {
52 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
53 }
54 inline void gemm_vbatch(char, char, int *, int *, int *, dcomplex, const dcomplex **, int *, const dcomplex **, int *, dcomplex, dcomplex **, int *,
55 int) {
56 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
57 }
58#endif
59
60 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
61 int strideB, double beta, double *C, int LDC, int strideC, int batch_count);
62 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
63 int LDB, int srideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count);
64
65 void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy);
66 void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy);
67
68 void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA);
69 void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA);
70
71 void scal(int M, double alpha, double *x, int incx);
72 void scal(int M, dcomplex alpha, dcomplex *x, int incx);
73
74 void swap(int N, double *x, int incx, double *Y, int incy); // NOLINT (this is a BLAS swap)
75 void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy); // NOLINT (this is a BLAS swap)
76
77} // namespace nda::blas::device
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:28
Provides various traits and utilities for the BLAS interface.