#include "./cxx_interface.hpp"
#include "./cblas_f77.h"

#include <cstddef>
#include <vector>
#ifdef NDA_USE_MKL
#include <mkl.h>

// Select MKL's LP64 integer model together with the GNU threading layer.
static int const mkl_interface_layer = mkl_set_interface_layer(MKL_INTERFACE_LP64 + MKL_INTERFACE_GNU);

// Reinterpret std::complex<double> pointers as MKL's complex type.
inline auto *mklcplx(nda::dcomplex *c) { return reinterpret_cast<MKL_Complex16 *>(c); }
inline auto *mklcplx(nda::dcomplex const *c) { return reinterpret_cast<const MKL_Complex16 *>(c); }
inline auto *mklcplx(nda::dcomplex **c) { return reinterpret_cast<MKL_Complex16 **>(c); }
inline auto *mklcplx(nda::dcomplex const **c) { return reinterpret_cast<const MKL_Complex16 **>(c); }
#endif
// Matches the complex value returned by the Fortran zdotu/zdotc routines.
struct nda_complex_double {
  double real;
  double imag;
};
// Declare the dot routines manually, since cblas_f77.h only provides the "_sub" variants.
#define F77_ddot F77_GLOBAL(ddot, DDOT)
#define F77_zdotu F77_GLOBAL(zdotu, ZDOTU)
#define F77_zdotc F77_GLOBAL(zdotc, ZDOTC)
extern "C" {
double F77_ddot(FINT, const double *, FINT, const double *, FINT);
nda_complex_double F77_zdotu(FINT, const double *, FINT, const double *, FINT);
nda_complex_double F77_zdotc(FINT, const double *, FINT, const double *, FINT);
}
namespace nda::blas::f77 {
  // Reinterpret std::complex<double> pointers as the interleaved (re, im)
  // double pairs expected by the Fortran BLAS ABI.
  inline auto *blacplx(dcomplex *c) { return reinterpret_cast<double *>(c); }
  inline auto *blacplx(dcomplex const *c) { return reinterpret_cast<const double *>(c); }
  inline auto **blacplx(dcomplex **c) { return reinterpret_cast<double **>(c); }
  inline auto **blacplx(dcomplex const **c) { return reinterpret_cast<const double **>(c); }
  void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy) { F77_daxpy(&N, &alpha, x, &incx, Y, &incy); }
  void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy) {
    F77_zaxpy(&N, blacplx(&alpha), blacplx(x), &incx, blacplx(Y), &incy);
  }
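  // Usage sketch for axpy (Y <- alpha*x + Y); illustrative only, the data is
  // made up for the example:
  //
  //   double x[3] = {1.0, 2.0, 3.0};
  //   double y[3] = {0.0, 0.0, 0.0};
  //   nda::blas::f77::axpy(3, 2.0, x, 1, y, 1); // y is now {2, 4, 6}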
  void copy(int N, const double *x, int incx, double *Y, int incy) { F77_dcopy(&N, x, &incx, Y, &incy); }
  void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy) { F77_zcopy(&N, blacplx(x), &incx, blacplx(Y), &incy); }
  double dot(int M, const double *x, int incx, const double *Y, int incy) { return F77_ddot(&M, x, &incx, Y, &incy); }

  dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
#ifdef NDA_USE_MKL
    MKL_Complex16 result;
    cblas_zdotu_sub(M, mklcplx(x), incx, mklcplx(Y), incy, &result);
#else
    auto result = F77_zdotu(&M, blacplx(x), &incx, blacplx(Y), &incy);
#endif
    return dcomplex{result.real, result.imag};
  }

  dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
#ifdef NDA_USE_MKL
    MKL_Complex16 result;
    cblas_zdotc_sub(M, mklcplx(x), incx, mklcplx(Y), incy, &result);
#else
    auto result = F77_zdotc(&M, blacplx(x), &incx, blacplx(Y), &incy);
#endif
    return dcomplex{result.real, result.imag};
  }
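  // Note the BLAS naming convention here: the complex dot overload maps to
  // zdotu (no conjugation), while dotc maps to zdotc, which conjugates x.
  // A minimal sketch (illustrative only):
  //
  //   nda::dcomplex x[1] = {{0.0, 1.0}}; // x = i
  //   nda::dcomplex y[1] = {{0.0, 1.0}}; // y = i
  //   auto u = nda::blas::f77::dot(1, x, 1, y, 1);  // x^T y       = -1
  //   auto c = nda::blas::f77::dotc(1, x, 1, y, 1); // conj(x)^T y = +1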
  void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
            int LDC) {
    F77_dgemm(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC);
  }
  void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
            dcomplex *C, int LDC) {
    F77_zgemm(&op_a, &op_b, &M, &N, &K, blacplx(&alpha), blacplx(A), &LDA, blacplx(B), &LDB, blacplx(&beta), blacplx(C), &LDC);
  }
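  // Usage sketch for gemm (C <- alpha*op(A)*op(B) + beta*C). The Fortran ABI
  // expects column-major storage and op codes 'N', 'T' or 'C'. Illustrative
  // only; a 2x2 identity times itself:
  //
  //   double A[4] = {1, 0, 0, 1}; // 2x2 identity, column-major
  //   double C[4] = {0, 0, 0, 0};
  //   nda::blas::f77::gemm('N', 'N', 2, 2, 2, 1.0, A, 2, A, 2, 0.0, C, 2);
  //   // C is now the 2x2 identity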
  void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
                  double **C, int LDC, int batch_count) {
#ifdef NDA_USE_MKL
    const int group_count = 1;
    dgemm_batch(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC, &group_count, &batch_count);
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M, N, K, alpha, A[i], LDA, B[i], LDB, beta, C[i], LDC);
#endif
  }
  void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
                  dcomplex **C, int LDC, int batch_count) {
#ifdef NDA_USE_MKL
    const int group_count = 1;
    zgemm_batch(&op_a, &op_b, &M, &N, &K, mklcplx(&alpha), mklcplx(A), &LDA, mklcplx(B), &LDB, mklcplx(&beta), mklcplx(C), &LDC, &group_count,
                &batch_count);
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M, N, K, alpha, A[i], LDA, B[i], LDB, beta, C[i], LDC);
#endif
  }
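  // Usage sketch for gemm_batch: the caller supplies arrays of matrix
  // pointers, one per batch entry, all sharing the same dimensions.
  // Illustrative only; the matrices are assumed filled elsewhere:
  //
  //   double A0[4], A1[4], B0[4], B1[4], C0[4], C1[4];
  //   const double *A[2] = {A0, A1}, *B[2] = {B0, B1};
  //   double *C[2] = {C0, C1};
  //   nda::blas::f77::gemm_batch('N', 'N', 2, 2, 2, 1.0, A, 2, B, 2, 0.0, C, 2, 2);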
  void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
                   double **C, int *LDC, int batch_count) {
#ifdef NDA_USE_MKL
    auto ops_a = std::vector<char>(batch_count, op_a), ops_b = std::vector<char>(batch_count, op_b);
    auto alphas = std::vector<double>(batch_count, alpha), betas = std::vector<double>(batch_count, beta);
    auto group_size = std::vector<int>(batch_count, 1);
    dgemm_batch(ops_a.data(), ops_b.data(), M, N, K, alphas.data(), A, LDA, B, LDB, betas.data(), C, LDC, &batch_count, group_size.data());
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M[i], N[i], K[i], alpha, A[i], LDA[i], B[i], LDB[i], beta, C[i], LDC[i]);
#endif
  }
  void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
                   dcomplex beta, dcomplex **C, int *LDC, int batch_count) {
#ifdef NDA_USE_MKL
    auto ops_a = std::vector<char>(batch_count, op_a), ops_b = std::vector<char>(batch_count, op_b);
    auto alphas = std::vector<dcomplex>(batch_count, alpha), betas = std::vector<dcomplex>(batch_count, beta);
    auto group_size = std::vector<int>(batch_count, 1);
    zgemm_batch(ops_a.data(), ops_b.data(), M, N, K, mklcplx(alphas.data()), mklcplx(A), LDA, mklcplx(B), LDB, mklcplx(betas.data()), mklcplx(C), LDC,
                &batch_count, group_size.data());
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M[i], N[i], K[i], alpha, A[i], LDA[i], B[i], LDB[i], beta, C[i], LDC[i]);
#endif
  }
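  // Design note on gemm_vbatch: MKL's grouped batch API varies dimensions per
  // *group*, not per matrix, so a fully heterogeneous batch is expressed as
  // batch_count groups of size 1 (hence the group_size vector of ones above).
  // Without MKL, the wrapper degrades to a plain loop of gemm calls. Sketch of
  // a two-entry variable-size batch, assuming A, B, C are arrays of two
  // pointers to suitably sized matrices (illustrative only):
  //
  //   int M[2] = {2, 3}, N[2] = {2, 3}, K[2] = {2, 3};
  //   int LDA[2] = {2, 3}, LDB[2] = {2, 3}, LDC[2] = {2, 3};
  //   nda::blas::f77::gemm_vbatch('N', 'N', M, N, K, 1.0, A, LDA, B, LDB, 0.0, C, LDC, 2);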
  void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
                          int strideB, double beta, double *C, int LDC, int strideC, int batch_count) {
#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
    dgemm_batch_strided(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, &strideA, B, &LDB, &strideB, &beta, C, &LDC, &strideC, &batch_count);
#else
    for (int i = 0; i < batch_count; ++i)
      gemm(op_a, op_b, M, N, K, alpha, A + static_cast<ptrdiff_t>(i * strideA), LDA, B + static_cast<ptrdiff_t>(i * strideB), LDB, beta,
           C + static_cast<ptrdiff_t>(i * strideC), LDC);
#endif
  }
  void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
                          int LDB, int strideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count) {
#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
    zgemm_batch_strided(&op_a, &op_b, &M, &N, &K, mklcplx(&alpha), mklcplx(A), &LDA, &strideA, mklcplx(B), &LDB, &strideB, mklcplx(&beta), mklcplx(C),
                        &LDC, &strideC, &batch_count);
#else
    for (int i = 0; i < batch_count; ++i)
      gemm(op_a, op_b, M, N, K, alpha, A + static_cast<ptrdiff_t>(i * strideA), LDA, B + static_cast<ptrdiff_t>(i * strideB), LDB, beta,
           C + static_cast<ptrdiff_t>(i * strideC), LDC);
#endif
  }
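  // Usage sketch for gemm_batch_strided: all matrices live in one contiguous
  // allocation with a fixed stride between consecutive batch entries, e.g. a
  // 3D array viewed as batch_count column-major 2x2 slices. Illustrative only;
  // the buffers are assumed filled elsewhere:
  //
  //   double A[8], B[8], C[8]; // 2 batches of 2x2 matrices, stride 4
  //   nda::blas::f77::gemm_batch_strided('N', 'N', 2, 2, 2, 1.0, A, 2, 4, B, 2, 4, 0.0, C, 2, 4, 2);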
  void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy) {
    F77_dgemv(&op, &M, &N, &alpha, A, &LDA, x, &incx, &beta, Y, &incy);
  }
  void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy) {
    F77_zgemv(&op, &M, &N, blacplx(&alpha), blacplx(A), &LDA, blacplx(x), &incx, blacplx(&beta), blacplx(Y), &incy);
  }
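  // Usage sketch for gemv (Y <- alpha*op(A)*x + beta*Y), column-major A.
  // Illustrative only:
  //
  //   double A[4] = {1, 0, 0, 1}; // 2x2 identity, column-major
  //   double x[2] = {3.0, 4.0}, y[2] = {0.0, 0.0};
  //   nda::blas::f77::gemv('N', 2, 2, 1.0, A, 2, x, 1, 0.0, y, 1); // y = {3, 4}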
  void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA) {
    F77_dger(&M, &N, &alpha, x, &incx, Y, &incy, A, &LDA);
  }
  void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA) {
    F77_zgeru(&M, &N, blacplx(&alpha), blacplx(x), &incx, blacplx(Y), &incy, blacplx(A), &LDA);
  }
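  // Note: the complex ger overload maps to zgeru, the unconjugated rank-1
  // update A <- alpha*x*y^T + A. Usage sketch (illustrative only):
  //
  //   double x[2] = {1.0, 2.0}, y[2] = {3.0, 4.0};
  //   double A[4] = {0, 0, 0, 0}; // 2x2, column-major
  //   nda::blas::f77::ger(2, 2, 1.0, x, 1, y, 1, A, 2); // A = x*y^T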
  void scal(int M, double alpha, double *x, int incx) { F77_dscal(&M, &alpha, x, &incx); }
  void scal(int M, dcomplex alpha, dcomplex *x, int incx) { F77_zscal(&M, blacplx(&alpha), blacplx(x), &incx); }
  void swap(int N, double *x, int incx, double *Y, int incy) { F77_dswap(&N, x, &incx, Y, &incy); }
  void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy) { F77_zswap(&N, blacplx(x), &incx, blacplx(Y), &incy); }

} // namespace nda::blas::f77