TRIQS/nda
2.0.0
Multi-dimensional array library for C++
Toggle main menu visibility
Loading...
Searching...
No Matches
cublas_interface.hpp
Go to the documentation of this file.
// Copyright (c) 2022--present, The Simons Foundation
// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
// SPDX-License-Identifier: Apache-2.0
// See LICENSE in the root of this distribution for details.
#pragma once
12
13
#include "
../tools.hpp
"
14
15
#ifndef NDA_HAVE_CUDA
16
#error "CUDA support is not enabled in this build of nda. Please configure and install nda with -DCUDASupport=ON"
17
#endif
18
19
#ifndef NDA_HAVE_MAGMA
20
#include "
../../exceptions.hpp
"
21
#endif
// NDA_HAVE_MAGMA
22
23
namespace
nda::blas::device {
24
25
void
axpy(
int
n,
double
alpha,
const
double
*x,
int
incx,
double
*y,
int
incy);
26
void
axpy(
int
n,
dcomplex
alpha,
const
dcomplex
*x,
int
incx,
dcomplex
*y,
int
incy);
27
28
void
copy(
int
n,
const
double
*x,
int
incx,
double
*y,
int
incy);
29
void
copy(
int
n,
const
dcomplex
*x,
int
incx,
dcomplex
*y,
int
incy);
30
31
double
dot(
int
m,
const
double
*x,
int
incx,
const
double
*y,
int
incy);
32
dcomplex
dot(
int
m,
const
dcomplex
*x,
int
incx,
const
dcomplex
*y,
int
incy);
33
dcomplex
dotc(
int
m,
const
dcomplex
*x,
int
incx,
const
dcomplex
*y,
int
incy);
34
35
void
gemm(
char
op_a,
char
op_b,
int
m,
int
n,
int
k,
double
alpha,
const
double
*a,
int
lda,
const
double
*b,
int
ldb,
double
beta,
double
*c,
36
int
ldc);
37
void
gemm(
char
op_a,
char
op_b,
int
m,
int
n,
int
k,
dcomplex
alpha,
const
dcomplex
*a,
int
lda,
const
dcomplex
*b,
int
ldb,
dcomplex
beta,
38
dcomplex
*c,
int
ldc);
39
40
void
gemm_batch(
char
op_a,
char
op_b,
int
m,
int
n,
int
k,
double
alpha,
const
double
**a,
int
lda,
const
double
**b,
int
ldb,
double
beta,
41
double
**c,
int
ldc,
int
batch_count);
42
void
gemm_batch(
char
op_a,
char
op_b,
int
m,
int
n,
int
k,
dcomplex
alpha,
const
dcomplex
**a,
int
lda,
const
dcomplex
**b,
int
ldb,
dcomplex
beta,
43
dcomplex
**c,
int
ldc,
int
batch_count);
44
45
#ifdef NDA_HAVE_MAGMA
46
void
gemm_vbatch(
char
op_a,
char
op_b,
int
*m,
int
*n,
int
*k,
double
alpha,
const
double
**a,
int
*lda,
const
double
**b,
int
*ldb,
double
beta,
47
double
**c,
int
*ldc,
int
batch_count);
48
void
gemm_vbatch(
char
op_a,
char
op_b,
int
*m,
int
*n,
int
*k,
dcomplex
alpha,
const
dcomplex
**a,
int
*lda,
const
dcomplex
**b,
int
*ldb,
49
dcomplex
beta,
dcomplex
**c,
int
*ldc,
int
batch_count);
50
#else
51
inline
void
gemm_vbatch(
char
,
char
,
int
*,
int
*,
int
*,
double
,
const
double
**,
int
*,
const
double
**,
int
*,
double
,
double
**,
int
*,
int
) {
52
NDA_RUNTIME_ERROR <<
"nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON"
;
53
}
54
inline
void
gemm_vbatch(
char
,
char
,
int
*,
int
*,
int
*,
dcomplex
,
const
dcomplex
**,
int
*,
const
dcomplex
**,
int
*,
dcomplex
,
dcomplex
**,
int
*,
55
int
) {
56
NDA_RUNTIME_ERROR <<
"nda::blas::device::gemmv_batch requires Magma [https://icl.cs.utk.edu/magma/]. Configure nda with -DUse_Magma=ON"
;
57
}
58
#endif
59
60
void
gemm_batch_strided(
char
op_a,
char
op_b,
int
m,
int
n,
int
k,
double
alpha,
const
double
*a,
int
lda,
int
stride_a,
const
double
*b,
int
ldb,
61
int
stride_b,
double
beta,
double
*c,
int
ldc,
int
stride_c,
int
batch_count);
62
void
gemm_batch_strided(
char
op_a,
char
op_b,
int
m,
int
n,
int
k,
dcomplex
alpha,
const
dcomplex
*a,
int
lda,
int
stride_a,
const
dcomplex
*b,
63
int
ldb,
int
stride_b,
dcomplex
beta,
dcomplex
*c,
int
ldc,
int
stride_c,
int
batch_count);
64
65
void
gemv(
char
op,
int
m,
int
n,
double
alpha,
const
double
*a,
int
lda,
const
double
*x,
int
incx,
double
beta,
double
*y,
int
incy);
66
void
gemv(
char
op,
int
m,
int
n,
dcomplex
alpha,
const
dcomplex
*a,
int
lda,
const
dcomplex
*x,
int
incx,
dcomplex
beta,
dcomplex
*y,
int
incy);
67
68
void
ger(
int
m,
int
n,
double
alpha,
const
double
*x,
int
incx,
const
double
*y,
int
incy,
double
*a,
int
lda);
69
void
ger(
int
m,
int
n,
dcomplex
alpha,
const
dcomplex
*x,
int
incx,
const
dcomplex
*y,
int
incy,
dcomplex
*a,
int
lda);
70
void
gerc(
int
m,
int
n,
dcomplex
alpha,
const
dcomplex
*x,
int
incx,
const
dcomplex
*y,
int
incy,
dcomplex
*a,
int
lda);
71
72
void
scal(
int
m,
double
alpha,
double
*x,
int
incx);
73
void
scal(
int
m,
dcomplex
alpha,
dcomplex
*x,
int
incx);
74
75
void
swap(
int
n,
double
*x,
int
incx,
double
*y,
int
incy);
// NOLINT (this is a BLAS swap)
76
void
swap(
int
n,
dcomplex
*x,
int
incx,
dcomplex
*y,
int
incy);
// NOLINT (this is a BLAS swap)
77
78
}
// namespace nda::blas::device
exceptions.hpp
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
nda::dcomplex
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition
tools.hpp:28
tools.hpp
Provides various traits and utilities for the BLAS interface.
nda
blas
interface
cublas_interface.hpp
Generated by
1.17.0