TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.hpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../tools.hpp"
14
15#ifndef NDA_HAVE_MAGMA
16#include "../../exceptions.hpp"
17#endif // NDA_HAVE_MAGMA
18
19namespace nda::blas::device {
20
21 void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy);
22 void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy);
23
24 void copy(int N, const double *x, int incx, double *Y, int incy);
25 void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy);
26
27 double dot(int M, const double *x, int incx, const double *Y, int incy);
28 dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy);
29 dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy);
30
31 void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
32 int LDC);
33 void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
34 dcomplex *C, int LDC);
35
36 void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
37 double **C, int LDC, int batch_count);
38 void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
39 dcomplex **C, int LDC, int batch_count);
40
41#ifdef NDA_HAVE_MAGMA
42 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
43 double **C, int *LDC, int batch_count);
44 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
45 dcomplex beta, dcomplex **C, int *LDC, int batch_count);
46#else
47 inline void gemm_vbatch(char, char, int *, int *, int *, double, const double **, int *, const double **, int *, double, double **, int *, int) {
48 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
49 }
50 inline void gemm_vbatch(char, char, int *, int *, int *, dcomplex, const dcomplex **, int *, const dcomplex **, int *, dcomplex, dcomplex **, int *,
51 int) {
52 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
53 }
54#endif
55
56 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
57 int strideB, double beta, double *C, int LDC, int strideC, int batch_count);
58 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
59 int LDB, int srideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count);
60
61 void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy);
62 void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy);
63
64 void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA);
65 void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA);
66
67 void scal(int M, double alpha, double *x, int incx);
68 void scal(int M, dcomplex alpha, dcomplex *x, int incx);
69
70 void swap(int N, double *x, int incx, double *Y, int incy); // NOLINT (this is a BLAS swap)
71 void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy); // NOLINT (this is a BLAS swap)
72
73} // namespace nda::blas::device
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:28
Provides various traits and utilities for the BLAS interface.