TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cublas_interface.hpp
Go to the documentation of this file.
1// Copyright (c) 2022-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Miguel Morales, Nils Wentzell
16
22#pragma once
23
24#include "../tools.hpp"
25
26#ifndef NDA_HAVE_MAGMA
27#include "../../exceptions.hpp"
28#endif // NDA_HAVE_MAGMA
29
30namespace nda::blas::device {
31
32 void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy);
33 void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy);
34
35 void copy(int N, const double *x, int incx, double *Y, int incy);
36 void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy);
37
38 double dot(int M, const double *x, int incx, const double *Y, int incy);
39 dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy);
40 dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy);
41
42 void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
43 int LDC);
44 void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
45 dcomplex *C, int LDC);
46
47 void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
48 double **C, int LDC, int batch_count);
49 void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
50 dcomplex **C, int LDC, int batch_count);
51
52#ifdef NDA_HAVE_MAGMA
53 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
54 double **C, int *LDC, int batch_count);
55 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
56 dcomplex beta, dcomplex **C, int *LDC, int batch_count);
57#else
58 inline void gemm_vbatch(char, char, int, int, int, double, const double **, int *, const double **, int *, double, double **, int *, int) {
59 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
60 }
61 inline void gemm_vbatch(char, char, int *, int *, int *, dcomplex, const dcomplex **, int *, const dcomplex **, int *, dcomplex, dcomplex **, int *,
62 int) {
63 NDA_RUNTIME_ERROR << "nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON";
64 }
65#endif
66
67 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
68 int strideB, double beta, double *C, int LDC, int strideC, int batch_count);
69 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
70 int LDB, int srideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count);
71
72 void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy);
73 void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy);
74
75 void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA);
76 void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA);
77
78 void scal(int M, double alpha, double *x, int incx);
79 void scal(int M, dcomplex alpha, dcomplex *x, int incx);
80
81 void swap(int N, double *x, int incx, double *Y, int incy); // NOLINT (this is a BLAS swap)
82 void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy); // NOLINT (this is a BLAS swap)
83
84} // namespace nda::blas::device
void swap(nda::basic_array_view< V1, R1, LP1, A1, AP1, OP1 > &a, nda::basic_array_view< V2, R2, LP2, A2, AP2, OP2 > &b)=delete
std::swap is deleted for nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
void gemm_batch_strided(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Implements a strided batched version of nda::blas::gemm taking 3-dimensional arrays as arguments.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:39
void gemm_vbatch(get_value_t< A > alpha, std::vector< A > const &va, std::vector< B > const &vb, get_value_t< A > beta, std::vector< C > &vc)
Wrapper of nda::blas::gemm_batch that allows variable sized matrices.
auto dotc(X const &x, Y const &y)
Interface to the BLAS dotc routine.
Definition dot.hpp:103
void gemv(get_value_t< A > alpha, A const &a, X const &x, get_value_t< A > beta, Y &&y)
Interface to the BLAS gemv routine.
Definition gemv.hpp:89
auto dot(X const &x, Y const &y)
Interface to the BLAS dot routine.
Definition dot.hpp:61
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS gemm routine.
Definition gemm.hpp:101
void ger(get_value_t< X > alpha, X const &x, Y const &y, M &&m)
Interface to the BLAS ger routine.
Definition ger.hpp:68
void gemm_batch(get_value_t< A > alpha, std::vector< A > const &va, std::vector< B > const &vb, get_value_t< A > beta, std::vector< C > &vc)
Implements a batched version of nda::blas::gemm taking vectors of matrices as arguments.
Provides various traits and utilities for the BLAS interface.