TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.hpp
Go to the documentation of this file.
1// Copyright (c) 2022-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Miguel Morales, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides a C++ interface for the GPU versions of various BLAS routines.
20 */
21
22#pragma once
23
24#include "../tools.hpp"
25
26#ifndef NDA_HAVE_MAGMA
27#include "../../exceptions.hpp"
28#endif // NDA_HAVE_MAGMA
29
30namespace nda::blas::device {
31
32 void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy);
33 void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy);
34
35 void copy(int N, const double *x, int incx, double *Y, int incy);
36 void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy);
37
38 double dot(int M, const double *x, int incx, const double *Y, int incy);
39 dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy);
40 dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy);
41
42 void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
43 int LDC);
44 void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
45 dcomplex *C, int LDC);
46
47 void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
48 double **C, int LDC, int batch_count);
49 void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
50 dcomplex **C, int LDC, int batch_count);
51
52#ifdef NDA_HAVE_MAGMA
53 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
54 double **C, int *LDC, int batch_count);
55 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
56 dcomplex beta, dcomplex **C, int *LDC, int batch_count);
57#else
58 inline void gemm_vbatch(char, char, int, int, int, double, const double **, int *, const double **, int *, double, double **, int *, int) {
59 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
60 }
61 inline void gemm_vbatch(char, char, int *, int *, int *, dcomplex, const dcomplex **, int *, const dcomplex **, int *, dcomplex, dcomplex **, int *,
62 int) {
63 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
64 }
65#endif
66
67 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
68 int strideB, double beta, double *C, int LDC, int strideC, int batch_count);
69 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
70 int LDB, int srideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count);
71
72 void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy);
73 void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy);
74
75 void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA);
76 void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA);
77
78 void scal(int M, double alpha, double *x, int incx);
79 void scal(int M, dcomplex alpha, dcomplex *x, int incx);
80
81 void swap(int N, double *x, int incx, double *Y, int incy); // NOLINT (this is a BLAS swap)
82 void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy); // NOLINT (this is a BLAS swap)
83
84} // namespace nda::blas::device
#define CUBLAS_CHECK(X,...)
#define NDA_RUNTIME_ERROR
#define AS_STRING(...)
Definition macros.hpp:31