#include "./cxx_interface.hpp"
#include "./cblas_f77.h"

#include <cstddef>
#include <vector>

// MKL-specific helpers: MKL_Complex16 and mkl_set_interface_layer come from <mkl.h>,
// so this block is only compiled when MKL is enabled.
#if defined(NDA_USE_MKL)
#include <mkl.h>

// when using the MKL single dynamic library (mkl_rt), select the LP64 integer
// and GNU Fortran interface layer at runtime
static int const mkl_interface_layer = mkl_set_interface_layer(MKL_INTERFACE_LP64 + MKL_INTERFACE_GNU);

// reinterpret nda::dcomplex (std::complex<double>) pointers as MKL_Complex16 pointers
inline auto *mklcplx(nda::dcomplex *c) { return reinterpret_cast<MKL_Complex16 *>(c); }
inline auto *mklcplx(nda::dcomplex const *c) { return reinterpret_cast<const MKL_Complex16 *>(c); }
inline auto *mklcplx(nda::dcomplex **c) { return reinterpret_cast<MKL_Complex16 **>(c); }
inline auto *mklcplx(nda::dcomplex const **c) { return reinterpret_cast<const MKL_Complex16 **>(c); }
#endif

// plain struct with the layout of a Fortran 'double complex', used as the
// return type of the zdotu/zdotc declarations below
struct nda_complex_double {
  double real;
  double imag;
};

// the dot routines are not part of cblas_f77.h, so map their Fortran names manually
#define F77_ddot F77_GLOBAL(ddot, DDOT)
#define F77_zdotu F77_GLOBAL(zdotu, ZDOTU)
#define F77_zdotc F77_GLOBAL(zdotc, ZDOTC)
extern "C" {
double F77_ddot(FINT, const double *, FINT, const double *, FINT);
nda_complex_double F77_zdotu(FINT, const double *, FINT, const double *, FINT);
nda_complex_double F77_zdotc(FINT, const double *, FINT, const double *, FINT);
}

namespace nda::blas::f77 {

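  // std::complex<double> is layout-compatible with double[2] (the standard
  // explicitly permits reinterpreting a std::complex<T>* as a T* for element
  // access), so these casts are a valid way to pass complex arrays through
  // the Fortran BLAS interface.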
  inline auto *blacplx(dcomplex *c) { return reinterpret_cast<double *>(c); }
  inline auto *blacplx(dcomplex const *c) { return reinterpret_cast<const double *>(c); }
  inline auto **blacplx(dcomplex **c) { return reinterpret_cast<double **>(c); }
  inline auto **blacplx(dcomplex const **c) { return reinterpret_cast<const double **>(c); }

  void axpy(int n, double alpha, const double *x, int incx, double *y, int incy) { F77_daxpy(&n, &alpha, x, &incx, y, &incy); }
  void axpy(int n, dcomplex alpha, const dcomplex *x, int incx, dcomplex *y, int incy) {
    F77_zaxpy(&n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy);
  }

  void copy(int n, const double *x, int incx, double *y, int incy) { F77_dcopy(&n, x, &incx, y, &incy); }
  void copy(int n, const dcomplex *x, int incx, dcomplex *y, int incy) { F77_zcopy(&n, blacplx(x), &incx, blacplx(y), &incy); }

  double dot(int m, const double *x, int incx, const double *y, int incy) { return F77_ddot(&m, x, &incx, y, &incy); }
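
  // For the complex dot products we avoid calling F77_zdotu/F77_zdotc directly
  // when MKL is available: Fortran functions returning 'double complex' do not
  // have a uniform C ABI across compilers, whereas the cblas_*_sub variants
  // return the result through an output argument.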
  dcomplex dot(int m, const dcomplex *x, int incx, const dcomplex *y, int incy) {
#if defined(NDA_USE_MKL)
    MKL_Complex16 result;
    cblas_zdotu_sub(m, mklcplx(x), incx, mklcplx(y), incy, &result);
#else
    auto result = F77_zdotu(&m, blacplx(x), &incx, blacplx(y), &incy);
#endif
    return dcomplex{result.real, result.imag};
  }
  dcomplex dotc(int m, const dcomplex *x, int incx, const dcomplex *y, int incy) {
#if defined(NDA_USE_MKL)
    MKL_Complex16 result;
    cblas_zdotc_sub(m, mklcplx(x), incx, mklcplx(y), incy, &result);
#else
    auto result = F77_zdotc(&m, blacplx(x), &incx, blacplx(y), &incy);
#endif
    return dcomplex{result.real, result.imag};
  }

  void gemm(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, const double *b, int ldb, double beta, double *c,
            int ldc) {
    F77_dgemm(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
  }
  void gemm(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *b, int ldb, dcomplex beta,
            dcomplex *c, int ldc) {
    F77_zgemm(&op_a, &op_b, &m, &n, &k, blacplx(&alpha), blacplx(a), &lda, blacplx(b), &ldb, blacplx(&beta), blacplx(c), &ldc);
  }
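
  // MKL's ?gemm_batch API operates on groups of GEMMs that share identical
  // parameters; here all matrices share op/size/alpha/beta, so a single group
  // of size batch_count suffices. Without MKL we simply loop over the batch.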
  void gemm_batch(char op_a, char op_b, int m, int n, int k, double alpha, const double **a, int lda, const double **b, int ldb, double beta,
                  double **c, int ldc, int batch_count) {
#if defined(NDA_USE_MKL)
    const int group_count = 1;
    dgemm_batch(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, &group_count, &batch_count);
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m, n, k, alpha, a[i], lda, b[i], ldb, beta, c[i], ldc);
#endif
  }
  void gemm_batch(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex **a, int lda, const dcomplex **b, int ldb, dcomplex beta,
                  dcomplex **c, int ldc, int batch_count) {
#if defined(NDA_USE_MKL)
    const int group_count = 1;
    zgemm_batch(&op_a, &op_b, &m, &n, &k, mklcplx(&alpha), mklcplx(a), &lda, mklcplx(b), &ldb, mklcplx(&beta), mklcplx(c), &ldc, &group_count,
                &batch_count);
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m, n, k, alpha, a[i], lda, b[i], ldb, beta, c[i], ldc);
#endif
  }
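
  // Variable-sized batch: every GEMM may have its own m/n/k and leading
  // dimensions. MKL expects per-group arrays for op/alpha/beta, so we replicate
  // the (shared) scalars into vectors and treat the batch as batch_count groups
  // of size 1.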
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, double alpha, const double **a, int *lda, const double **b, int *ldb, double beta,
                   double **c, int *ldc, int batch_count) {
#if defined(NDA_USE_MKL)
    auto ops_a = std::vector<char>(batch_count, op_a), ops_b = std::vector<char>(batch_count, op_b);
    auto alphas = std::vector<double>(batch_count, alpha), betas = std::vector<double>(batch_count, beta);
    auto group_size = std::vector<int>(batch_count, 1);
    dgemm_batch(ops_a.data(), ops_b.data(), m, n, k, alphas.data(), a, lda, b, ldb, betas.data(), c, ldc, &batch_count, group_size.data());
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m[i], n[i], k[i], alpha, a[i], lda[i], b[i], ldb[i], beta, c[i], ldc[i]);
#endif
  }
  void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, dcomplex alpha, const dcomplex **a, int *lda, const dcomplex **b, int *ldb,
                   dcomplex beta, dcomplex **c, int *ldc, int batch_count) {
#if defined(NDA_USE_MKL)
    auto ops_a = std::vector<char>(batch_count, op_a), ops_b = std::vector<char>(batch_count, op_b);
    auto alphas = std::vector<dcomplex>(batch_count, alpha), betas = std::vector<dcomplex>(batch_count, beta);
    auto group_size = std::vector<int>(batch_count, 1);
    zgemm_batch(ops_a.data(), ops_b.data(), m, n, k, mklcplx(alphas.data()), mklcplx(a), lda, mklcplx(b), ldb, mklcplx(betas.data()), mklcplx(c), ldc,
                &batch_count, group_size.data());
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m[i], n[i], k[i], alpha, a[i], lda[i], b[i], ldb[i], beta, c[i], ldc[i]);
#endif
  }
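
  // Strided batch: all matrices sit in one contiguous allocation at a fixed
  // stride. The dedicated ?gemm_batch_strided routines were introduced with
  // MKL 2020 update 2, hence the version check; older setups fall back to a
  // plain loop over the batch.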
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, int stride_a, const double *b, int ldb,
                          int stride_b, double beta, double *c, int ldc, int stride_c, int batch_count) {
#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
    dgemm_batch_strided(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, &stride_a, b, &ldb, &stride_b, &beta, c, &ldc, &stride_c, &batch_count);
#else
    for (int i = 0; i < batch_count; ++i)
      gemm(op_a, op_b, m, n, k, alpha, a + static_cast<ptrdiff_t>(i * stride_a), lda, b + static_cast<ptrdiff_t>(i * stride_b), ldb, beta,
           c + static_cast<ptrdiff_t>(i * stride_c), ldc);
#endif
  }
  void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, int stride_a, const dcomplex *b,
                          int ldb, int stride_b, dcomplex beta, dcomplex *c, int ldc, int stride_c, int batch_count) {
#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
    zgemm_batch_strided(&op_a, &op_b, &m, &n, &k, mklcplx(&alpha), mklcplx(a), &lda, &stride_a, mklcplx(b), &ldb, &stride_b, mklcplx(&beta),
                        mklcplx(c), &ldc, &stride_c, &batch_count);
#else
    for (int i = 0; i < batch_count; ++i)
      gemm(op_a, op_b, m, n, k, alpha, a + static_cast<ptrdiff_t>(i * stride_a), lda, b + static_cast<ptrdiff_t>(i * stride_b), ldb, beta,
           c + static_cast<ptrdiff_t>(i * stride_c), ldc);
#endif
  }

  void gemv(char op, int m, int n, double alpha, const double *a, int lda, const double *x, int incx, double beta, double *y, int incy) {
    F77_dgemv(&op, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
  }
  void gemv(char op, int m, int n, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *x, int incx, dcomplex beta, dcomplex *y, int incy) {
    F77_zgemv(&op, &m, &n, blacplx(&alpha), blacplx(a), &lda, blacplx(x), &incx, blacplx(&beta), blacplx(y), &incy);
  }

  void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda) {
    F77_dger(&m, &n, &alpha, x, &incx, y, &incy, a, &lda);
  }
  // rank-1 update without conjugation (zgeru)
  void ger(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda) {
    F77_zgeru(&m, &n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy, blacplx(a), &lda);
  }
  // rank-1 update with conjugated y (zgerc)
  void gerc(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda) {
    F77_zgerc(&m, &n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy, blacplx(a), &lda);
  }

  void scal(int m, double alpha, double *x, int incx) { F77_dscal(&m, &alpha, x, &incx); }
  void scal(int m, dcomplex alpha, dcomplex *x, int incx) { F77_zscal(&m, blacplx(&alpha), blacplx(x), &incx); }

  void swap(int n, double *x, int incx, double *y, int incy) { F77_dswap(&n, x, &incx, y, &incy); }
  void swap(int n, dcomplex *x, int incx, dcomplex *y, int incy) { F77_zswap(&n, blacplx(x), &incx, blacplx(y), &incy); }

} // namespace nda::blas::f77
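
// A minimal usage sketch (not part of the library): multiplying two 2x2
// column-major matrices through the dgemm wrapper above. The std::vector
// setup is illustrative only.
//
//   std::vector<double> A = {1, 2, 3, 4}, B = {5, 6, 7, 8}, C(4, 0.0);
//   nda::blas::f77::gemm('N', 'N', 2, 2, 2, 1.0, A.data(), 2, B.data(), 2, 0.0, C.data(), 2);
//   // C now holds A * B in column-major order.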