12#include "./cblas_f77.h"
27 static int const mkl_interface_layer = mkl_set_interface_layer(MKL_INTERFACE_LP64 + MKL_INTERFACE_GNU);
29 inline auto *mklcplx(
nda::dcomplex *c) {
return reinterpret_cast<MKL_Complex16 *
>(c); }
30 inline auto *mklcplx(
nda::dcomplex const *c) {
return reinterpret_cast<const MKL_Complex16 *
>(c); }
31 inline auto *mklcplx(
nda::dcomplex **c) {
return reinterpret_cast<MKL_Complex16 **
>(c); }
32 inline auto *mklcplx(
nda::dcomplex const **c) {
return reinterpret_cast<const MKL_Complex16 **
>(c); }
// Plain-old-data mirror of a Fortran COMPLEX*16 value. Used as the ABI-safe
// return type of the zdotu/zdotc Fortran symbols declared below; the members
// are read back as result.real / result.imag when building the std::complex
// return value. (Members reconstructed from that usage — layout must stay
// two contiguous doubles.)
struct nda_complex_double {
  double real; // real part
  double imag; // imaginary part
};
48#define F77_ddot F77_GLOBAL(ddot, DDOT)
49#define F77_zdotu F77_GLOBAL(zdotu, ZDOTU)
50#define F77_zdotc F77_GLOBAL(zdotc, ZDOTC)
52double F77_ddot(FINT,
const double *, FINT,
const double *, FINT);
53nda_complex_double F77_zdotu(FINT,
const double *, FINT,
const double *, FINT);
54nda_complex_double F77_zdotc(FINT,
const double *, FINT,
const double *, FINT);
57namespace nda::blas::f77 {
59 inline auto *blacplx(
dcomplex *c) {
return reinterpret_cast<double *
>(c); }
60 inline auto *blacplx(
dcomplex const *c) {
return reinterpret_cast<const double *
>(c); }
61 inline auto **blacplx(
dcomplex **c) {
return reinterpret_cast<double **
>(c); }
62 inline auto **blacplx(
dcomplex const **c) {
return reinterpret_cast<const double **
>(c); }
64 void axpy(
int N,
double alpha,
const double *x,
int incx,
double *Y,
int incy) { F77_daxpy(&N, &alpha, x, &incx, Y, &incy); }
66 F77_zaxpy(&N, blacplx(&alpha), blacplx(x), &incx, blacplx(Y), &incy);
70 void copy(
int N,
const double *x,
int incx,
double *Y,
int incy) { F77_dcopy(&N, x, &incx, Y, &incy); }
71 void copy(
int N,
const dcomplex *x,
int incx,
dcomplex *Y,
int incy) { F77_zcopy(&N, blacplx(x), &incx, blacplx(Y), &incy); }
73 double dot(
int M,
const double *x,
int incx,
const double *Y,
int incy) {
return F77_ddot(&M, x, &incx, Y, &incy); }
77 cblas_zdotu_sub(M, mklcplx(x), incx, mklcplx(Y), incy, &result);
79 auto result = F77_zdotu(&M, blacplx(x), &incx, blacplx(Y), &incy);
81 return dcomplex{result.real, result.imag};
86 cblas_zdotc_sub(M, mklcplx(x), incx, mklcplx(Y), incy, &result);
88 auto result = F77_zdotc(&M, blacplx(x), &incx, blacplx(Y), &incy);
90 return dcomplex{result.real, result.imag};
93 void gemm(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double *A,
int LDA,
const double *B,
int LDB,
double beta,
double *C,
95 F77_dgemm(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC);
97 void gemm(
char op_a,
char op_b,
int M,
int N,
int K,
dcomplex alpha,
const dcomplex *A,
int LDA,
const dcomplex *B,
int LDB,
dcomplex beta,
99 F77_zgemm(&op_a, &op_b, &M, &N, &K, blacplx(&alpha), blacplx(A), &LDA, blacplx(B), &LDB, blacplx(&beta), blacplx(C), &LDC);
102 void gemm_batch(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double **A,
int LDA,
const double **B,
int LDB,
double beta,
103 double **C,
int LDC,
int batch_count) {
105 const int group_count = 1;
106 dgemm_batch(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC, &group_count, &batch_count);
108 for (
int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M, N, K, alpha, A[i], LDA, B[i], LDB, beta, C[i], LDC);
111 void gemm_batch(
char op_a,
char op_b,
int M,
int N,
int K,
dcomplex alpha,
const dcomplex **A,
int LDA,
const dcomplex **B,
int LDB,
dcomplex beta,
112 dcomplex **C,
int LDC,
int batch_count) {
114 const int group_count = 1;
115 zgemm_batch(&op_a, &op_b, &M, &N, &K, mklcplx(&alpha), mklcplx(A), &LDA, mklcplx(B), &LDB, mklcplx(&beta), mklcplx(C), &LDC, &group_count,
118 for (
int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M, N, K, alpha, A[i], LDA, B[i], LDB, beta, C[i], LDC);
122 void gemm_vbatch(
char op_a,
char op_b,
int *M,
int *N,
int *K,
double alpha,
const double **A,
int *LDA,
const double **B,
int *LDB,
double beta,
123 double **C,
int *LDC,
int batch_count) {
128 dgemm_batch(ops_a.data(), ops_b.data(), M, N, K, alphas.data(), A, LDA, B, LDB, betas.data(), C, LDC, &batch_count, group_size.data());
130 for (
int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M[i], N[i], K[i], alpha, A[i], LDA[i], B[i], LDB[i], beta, C[i], LDC[i]);
133 void gemm_vbatch(
char op_a,
char op_b,
int *M,
int *N,
int *K,
dcomplex alpha,
const dcomplex **A,
int *LDA,
const dcomplex **B,
int *LDB,
139 zgemm_batch(ops_a.data(), ops_b.data(), M, N, K, mklcplx(alphas.data()), mklcplx(A), LDA, mklcplx(B), LDB, mklcplx(betas.data()), mklcplx(C), LDC,
140 &batch_count, group_size.data());
142 for (
int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M[i], N[i], K[i], alpha, A[i], LDA[i], B[i], LDB[i], beta, C[i], LDC[i]);
146 void gemm_batch_strided(
char op_a,
char op_b,
int M,
int N,
int K,
double alpha,
const double *A,
int LDA,
int strideA,
const double *B,
int LDB,
147 int strideB,
double beta,
double *C,
int LDC,
int strideC,
int batch_count) {
148#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
149 dgemm_batch_strided(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, &strideA, B, &LDB, &strideB, &beta, C, &LDC, &strideC, &batch_count);
151 for (
int i = 0; i < batch_count; ++i)
152 gemm(op_a, op_b, M, N, K, alpha, A +
static_cast<ptrdiff_t
>(i * strideA), LDA, B +
static_cast<ptrdiff_t
>(i * strideB), LDB, beta,
153 C +
static_cast<ptrdiff_t
>(i * strideC), LDC);
156 void gemm_batch_strided(
char op_a,
char op_b,
int M,
int N,
int K,
dcomplex alpha,
const dcomplex *A,
int LDA,
int strideA,
const dcomplex *B,
157 int LDB,
int strideB,
dcomplex beta,
dcomplex *C,
int LDC,
int strideC,
int batch_count) {
158#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
159 zgemm_batch_strided(&op_a, &op_b, &M, &N, &K, mklcplx(&alpha), mklcplx(A), &LDA, &strideA, mklcplx(B), &LDB, &strideB, mklcplx(&beta), mklcplx(C),
160 &LDC, &strideC, &batch_count);
162 for (
int i = 0; i < batch_count; ++i)
163 gemm(op_a, op_b, M, N, K, alpha, A +
static_cast<ptrdiff_t
>(i * strideA), LDA, B +
static_cast<ptrdiff_t
>(i * strideB), LDB, beta,
164 C +
static_cast<ptrdiff_t
>(i * strideC), LDC);
168 void gemv(
char op,
int M,
int N,
double alpha,
const double *A,
int LDA,
const double *x,
int incx,
double beta,
double *Y,
int incy) {
169 F77_dgemv(&op, &M, &N, &alpha, A, &LDA, x, &incx, &beta, Y, &incy);
171 void gemv(
char op,
int M,
int N,
dcomplex alpha,
const dcomplex *A,
int LDA,
const dcomplex *x,
int incx,
dcomplex beta,
dcomplex *Y,
int incy) {
172 F77_zgemv(&op, &M, &N, blacplx(&alpha), blacplx(A), &LDA, blacplx(x), &incx, blacplx(&beta), blacplx(Y), &incy);
175 void ger(
int M,
int N,
double alpha,
const double *x,
int incx,
const double *Y,
int incy,
double *A,
int LDA) {
176 F77_dger(&M, &N, &alpha, x, &incx, Y, &incy, A, &LDA);
179 F77_zgeru(&M, &N, blacplx(&alpha), blacplx(x), &incx, blacplx(Y), &incy, blacplx(A), &LDA);
182 void scal(
int M,
double alpha,
double *x,
int incx) { F77_dscal(&M, &alpha, x, &incx); }
183 void scal(
int M,
dcomplex alpha,
dcomplex *x,
int incx) { F77_zscal(&M, blacplx(&alpha), blacplx(x), &incx); }
185 void swap(
int N,
double *x,
int incx,
double *Y,
int incy) { F77_dswap(&N, x, &incx, Y, &incy); }
187 F77_zswap(&N, blacplx(x), &incx, blacplx(Y), &incy);
// Cross-reference notes captured from the documentation listing (not code):
// - Provides the generic class for arrays.
// - Provides a C++ interface for various BLAS routines.
// - Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
// - basic_array<ValueType, 1, C_layout, 'V', ContainerPolicy> vector:
//   alias template of an nda::basic_array with rank 1 and a 'V' algebra.
// - std::complex<double> dcomplex: alias for the std::complex<double> type.