12#include "./cblas_f77.h"
27 static int const mkl_interface_layer = mkl_set_interface_layer(MKL_INTERFACE_LP64 + MKL_INTERFACE_GNU);
29 inline auto *mklcplx(std::complex<float> *c) {
return reinterpret_cast<MKL_Complex8 *
>(c); }
30 inline auto *mklcplx(std::complex<float>
const *c) {
return reinterpret_cast<const MKL_Complex8 *
>(c); }
31 inline auto *mklcplx(std::complex<float> **c) {
return reinterpret_cast<MKL_Complex8 **
>(c); }
32 inline auto *mklcplx(std::complex<float>
const **c) {
return reinterpret_cast<const MKL_Complex8 **
>(c); }
34 inline auto *mklcplx(std::complex<double> *c) {
return reinterpret_cast<MKL_Complex16 *
>(c); }
35 inline auto *mklcplx(std::complex<double>
const *c) {
return reinterpret_cast<const MKL_Complex16 *
>(c); }
36 inline auto *mklcplx(std::complex<double> **c) {
return reinterpret_cast<MKL_Complex16 **
>(c); }
37 inline auto *mklcplx(std::complex<double>
const **c) {
return reinterpret_cast<const MKL_Complex16 **
>(c); }
45 struct nda_complex_float {
51 struct nda_complex_double {
// Fortran symbol names for the BLAS dot-product routines. F77_GLOBAL (presumably provided by
// cblas_f77.h) maps the lower/upper-case routine name to the platform's Fortran symbol.

// Real dot products.
#define F77_sdot F77_GLOBAL(sdot, SDOT)
#define F77_ddot F77_GLOBAL(ddot, DDOT)

// Complex dot products: "u" = unconjugated, "c" = conjugated (BLAS naming).
#define F77_cdotu F77_GLOBAL(cdotu, CDOTU)
#define F77_cdotc F77_GLOBAL(cdotc, CDOTC)
#define F77_zdotu F77_GLOBAL(zdotu, ZDOTU)
#define F77_zdotc F77_GLOBAL(zdotc, ZDOTC)
67float F77_sdot(FINT,
const float *, FINT,
const float *, FINT);
68nda_complex_float F77_cdotu(FINT,
const float *, FINT,
const float *, FINT);
69nda_complex_float F77_cdotc(FINT,
const float *, FINT,
const float *, FINT);
71double F77_ddot(FINT,
const double *, FINT,
const double *, FINT);
72nda_complex_double F77_zdotu(FINT,
const double *, FINT,
const double *, FINT);
73nda_complex_double F77_zdotc(FINT,
const double *, FINT,
const double *, FINT);
76namespace nda::blas::f77 {
// blacplx: view std::complex buffers as interleaved real arrays (re, im, re, im, ...), which
// is how the Fortran BLAS prototypes in this file take complex arguments. Accessing a
// std::complex<T> through a T* (real part at [0], imaginary part at [1]) is explicitly
// permitted by the standard's array-oriented access rule for std::complex.

// ---- single precision -> float ----
inline float *blacplx(std::complex<float> *c) { return reinterpret_cast<float *>(c); }
inline const float *blacplx(const std::complex<float> *c) { return reinterpret_cast<const float *>(c); }
// Pointer-array overloads reinterpret the array of buffer pointers itself.
inline float **blacplx(std::complex<float> **c) { return reinterpret_cast<float **>(c); }
inline const float **blacplx(const std::complex<float> **c) { return reinterpret_cast<const float **>(c); }

// ---- double precision -> double ----
inline double *blacplx(std::complex<double> *c) { return reinterpret_cast<double *>(c); }
inline const double *blacplx(const std::complex<double> *c) { return reinterpret_cast<const double *>(c); }
// Pointer-array overloads reinterpret the array of buffer pointers itself.
inline double **blacplx(std::complex<double> **c) { return reinterpret_cast<double **>(c); }
inline const double **blacplx(const std::complex<double> **c) { return reinterpret_cast<const double **>(c); }
92 void gemm_batch_impl(
char op_a,
char op_b,
int m,
int n,
int k, T alpha,
const T **a,
int lda,
const T **b,
int ldb, T beta, T **c,
int ldc,
95 const int group_count = 1;
96 if constexpr (std::is_same_v<T, float>) {
97 sgemm_batch(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, &group_count, &batch_count);
98 }
else if constexpr (std::is_same_v<T, double>) {
99 dgemm_batch(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, &group_count, &batch_count);
100 }
else if constexpr (std::is_same_v<T, std::complex<float>>) {
101 cgemm_batch(&op_a, &op_b, &m, &n, &k, mklcplx(&alpha), mklcplx(a), &lda, mklcplx(b), &ldb, mklcplx(&beta), mklcplx(c), &ldc, &group_count,
103 }
else if constexpr (std::is_same_v<T, std::complex<double>>) {
104 zgemm_batch(&op_a, &op_b, &m, &n, &k, mklcplx(&alpha), mklcplx(a), &lda, mklcplx(b), &ldb, mklcplx(&beta), mklcplx(c), &ldc, &group_count,
108 for (
int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m, n, k, alpha, a[i], lda, b[i], ldb, beta, c[i], ldc);
113 template <
typename T>
114 void gemm_vbatch_impl(
char op_a,
char op_b,
int *m,
int *n,
int *k, T alpha,
const T **a,
int *lda,
const T **b,
int *ldb, T beta, T **c,
115 int *ldc,
int batch_count) {
119 nda::vector<T> alphas(batch_count, alpha), betas(batch_count, beta);
120 if constexpr (std::is_same_v<T, float>) {
121 sgemm_batch(ops_a.data(), ops_b.data(), m, n, k, alphas.data(), a, lda, b, ldb, betas.data(), c, ldc, &batch_count, group_size.data());
122 }
else if constexpr (std::is_same_v<T, double>) {
123 dgemm_batch(ops_a.data(), ops_b.data(), m, n, k, alphas.data(), a, lda, b, ldb, betas.data(), c, ldc, &batch_count, group_size.data());
124 }
else if constexpr (std::is_same_v<T, std::complex<float>>) {
125 cgemm_batch(ops_a.data(), ops_b.data(), m, n, k, mklcplx(alphas.data()), mklcplx(a), lda, mklcplx(b), ldb, mklcplx(betas.data()), mklcplx(c),
126 ldc, &batch_count, group_size.data());
127 }
else if constexpr (std::is_same_v<T, std::complex<double>>) {
128 zgemm_batch(ops_a.data(), ops_b.data(), m, n, k, mklcplx(alphas.data()), mklcplx(a), lda, mklcplx(b), ldb, mklcplx(betas.data()), mklcplx(c),
129 ldc, &batch_count, group_size.data());
132 for (
int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m[i], n[i], k[i], alpha, a[i], lda[i], b[i], ldb[i], beta, c[i], ldc[i]);
137 template <
typename T>
138 void gemm_batch_strided_impl(
char op_a,
char op_b,
int m,
int n,
int k, T alpha,
const T *a,
int lda,
int stride_a,
const T *b,
int ldb,
139 int stride_b, T beta, T *c,
int ldc,
int stride_c,
int batch_count) {
140#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
141 if constexpr (std::is_same_v<T, float>) {
142 sgemm_batch_strided(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, &stride_a, b, &ldb, &stride_b, &beta, c, &ldc, &stride_c, &batch_count);
143 }
else if constexpr (std::is_same_v<T, double>) {
144 dgemm_batch_strided(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, &stride_a, b, &ldb, &stride_b, &beta, c, &ldc, &stride_c, &batch_count);
145 }
else if constexpr (std::is_same_v<T, std::complex<float>>) {
146 cgemm_batch_strided(&op_a, &op_b, &m, &n, &k, mklcplx(&alpha), mklcplx(a), &lda, &stride_a, mklcplx(b), &ldb, &stride_b, mklcplx(&beta),
147 mklcplx(c), &ldc, &stride_c, &batch_count);
148 }
else if constexpr (std::is_same_v<T, std::complex<double>>) {
149 zgemm_batch_strided(&op_a, &op_b, &m, &n, &k, mklcplx(&alpha), mklcplx(a), &lda, &stride_a, mklcplx(b), &ldb, &stride_b, mklcplx(&beta),
150 mklcplx(c), &ldc, &stride_c, &batch_count);
153 for (
int i = 0; i < batch_count; ++i)
154 gemm(op_a, op_b, m, n, k, alpha, a + i * stride_a, lda, b + i * stride_b, ldb, beta, c + i * stride_c, ldc);
161 void axpy(
int n,
float alpha,
const float *x,
int incx,
float *y,
int incy) { F77_saxpy(&n, &alpha, x, &incx, y, &incy); }
162 void axpy(
int n, std::complex<float> alpha,
const std::complex<float> *x,
int incx, std::complex<float> *y,
int incy) {
163 F77_caxpy(&n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy);
165 void axpy(
int n,
double alpha,
const double *x,
int incx,
double *y,
int incy) { F77_daxpy(&n, &alpha, x, &incx, y, &incy); }
166 void axpy(
int n, std::complex<double> alpha,
const std::complex<double> *x,
int incx, std::complex<double> *y,
int incy) {
167 F77_zaxpy(&n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy);
171 void copy(
int n,
const float *x,
int incx,
float *y,
int incy) { F77_scopy(&n, x, &incx, y, &incy); }
172 void copy(
int n,
const std::complex<float> *x,
int incx, std::complex<float> *y,
int incy) { F77_ccopy(&n, blacplx(x), &incx, blacplx(y), &incy); }
173 void copy(
int n,
const double *x,
int incx,
double *y,
int incy) { F77_dcopy(&n, x, &incx, y, &incy); }
174 void copy(
int n,
const std::complex<double> *x,
int incx, std::complex<double> *y,
int incy) {
175 F77_zcopy(&n, blacplx(x), &incx, blacplx(y), &incy);
179 float dot(
int m,
const float *x,
int incx,
const float *y,
int incy) {
return F77_sdot(&m, x, &incx, y, &incy); }
180 std::complex<float> dot(
int m,
const std::complex<float> *x,
int incx,
const std::complex<float> *y,
int incy) {
183 cblas_cdotu_sub(m, mklcplx(x), incx, mklcplx(y), incy, &result);
185 auto result = F77_cdotu(&m, blacplx(x), &incx, blacplx(y), &incy);
187 return std::complex<float>{result.real, result.imag};
189 std::complex<float> dotc(
int m,
const std::complex<float> *x,
int incx,
const std::complex<float> *y,
int incy) {
192 cblas_cdotc_sub(m, mklcplx(x), incx, mklcplx(y), incy, &result);
194 auto result = F77_cdotc(&m, blacplx(x), &incx, blacplx(y), &incy);
196 return std::complex<float>{result.real, result.imag};
198 double dot(
int m,
const double *x,
int incx,
const double *y,
int incy) {
return F77_ddot(&m, x, &incx, y, &incy); }
199 std::complex<double> dot(
int m,
const std::complex<double> *x,
int incx,
const std::complex<double> *y,
int incy) {
201 MKL_Complex16 result;
202 cblas_zdotu_sub(m, mklcplx(x), incx, mklcplx(y), incy, &result);
204 auto result = F77_zdotu(&m, blacplx(x), &incx, blacplx(y), &incy);
206 return std::complex<double>{result.real, result.imag};
208 std::complex<double> dotc(
int m,
const std::complex<double> *x,
int incx,
const std::complex<double> *y,
int incy) {
210 MKL_Complex16 result;
211 cblas_zdotc_sub(m, mklcplx(x), incx, mklcplx(y), incy, &result);
213 auto result = F77_zdotc(&m, blacplx(x), &incx, blacplx(y), &incy);
215 return std::complex<double>{result.real, result.imag};
219 void gemm(
char op_a,
char op_b,
int m,
int n,
int k,
float alpha,
const float *a,
int lda,
const float *b,
int ldb,
float beta,
float *c,
int ldc) {
220 F77_sgemm(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
222 void gemm(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<float> alpha,
const std::complex<float> *a,
int lda,
const std::complex<float> *b,
223 int ldb, std::complex<float> beta, std::complex<float> *c,
int ldc) {
224 F77_cgemm(&op_a, &op_b, &m, &n, &k, blacplx(&alpha), blacplx(a), &lda, blacplx(b), &ldb, blacplx(&beta), blacplx(c), &ldc);
226 void gemm(
char op_a,
char op_b,
int m,
int n,
int k,
double alpha,
const double *a,
int lda,
const double *b,
int ldb,
double beta,
double *c,
228 F77_dgemm(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
230 void gemm(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<double> alpha,
const std::complex<double> *a,
int lda,
231 const std::complex<double> *b,
int ldb, std::complex<double> beta, std::complex<double> *c,
int ldc) {
232 F77_zgemm(&op_a, &op_b, &m, &n, &k, blacplx(&alpha), blacplx(a), &lda, blacplx(b), &ldb, blacplx(&beta), blacplx(c), &ldc);
236 void gemm_batch(
char op_a,
char op_b,
int m,
int n,
int k,
float alpha,
const float **a,
int lda,
const float **b,
int ldb,
float beta,
float **c,
237 int ldc,
int batch_count) {
238 gemm_batch_impl(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
240 void gemm_batch(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<float> alpha,
const std::complex<float> **a,
int lda,
241 const std::complex<float> **b,
int ldb, std::complex<float> beta, std::complex<float> **c,
int ldc,
int batch_count) {
242 gemm_batch_impl(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
244 void gemm_batch(
char op_a,
char op_b,
int m,
int n,
int k,
double alpha,
const double **a,
int lda,
const double **b,
int ldb,
double beta,
245 double **c,
int ldc,
int batch_count) {
246 gemm_batch_impl(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
248 void gemm_batch(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<double> alpha,
const std::complex<double> **a,
int lda,
249 const std::complex<double> **b,
int ldb, std::complex<double> beta, std::complex<double> **c,
int ldc,
int batch_count) {
250 gemm_batch_impl(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
254 void gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k,
float alpha,
const float **a,
int *lda,
const float **b,
int *ldb,
float beta,
255 float **c,
int *ldc,
int batch_count) {
256 gemm_vbatch_impl(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
258 void gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k, std::complex<float> alpha,
const std::complex<float> **a,
int *lda,
259 const std::complex<float> **b,
int *ldb, std::complex<float> beta, std::complex<float> **c,
int *ldc,
int batch_count) {
260 gemm_vbatch_impl(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
262 void gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k,
double alpha,
const double **a,
int *lda,
const double **b,
int *ldb,
double beta,
263 double **c,
int *ldc,
int batch_count) {
264 gemm_vbatch_impl(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
266 void gemm_vbatch(
char op_a,
char op_b,
int *m,
int *n,
int *k, std::complex<double> alpha,
const std::complex<double> **a,
int *lda,
267 const std::complex<double> **b,
int *ldb, std::complex<double> beta, std::complex<double> **c,
int *ldc,
int batch_count) {
268 gemm_vbatch_impl(op_a, op_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
272 void gemm_batch_strided(
char op_a,
char op_b,
int m,
int n,
int k,
float alpha,
const float *a,
int lda,
int stride_a,
const float *b,
int ldb,
273 int stride_b,
float beta,
float *c,
int ldc,
int stride_c,
int batch_count) {
274 gemm_batch_strided_impl(op_a, op_b, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_count);
276 void gemm_batch_strided(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<float> alpha,
const std::complex<float> *a,
int lda,
int stride_a,
277 const std::complex<float> *b,
int ldb,
int stride_b, std::complex<float> beta, std::complex<float> *c,
int ldc,
278 int stride_c,
int batch_count) {
279 gemm_batch_strided_impl(op_a, op_b, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_count);
281 void gemm_batch_strided(
char op_a,
char op_b,
int m,
int n,
int k,
double alpha,
const double *a,
int lda,
int stride_a,
const double *b,
int ldb,
282 int stride_b,
double beta,
double *c,
int ldc,
int stride_c,
int batch_count) {
283 gemm_batch_strided_impl(op_a, op_b, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_count);
285 void gemm_batch_strided(
char op_a,
char op_b,
int m,
int n,
int k, std::complex<double> alpha,
const std::complex<double> *a,
int lda,
int stride_a,
286 const std::complex<double> *b,
int ldb,
int stride_b, std::complex<double> beta, std::complex<double> *c,
int ldc,
287 int stride_c,
int batch_count) {
288 gemm_batch_strided_impl(op_a, op_b, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_count);
292 void gemv(
char op,
int m,
int n,
float alpha,
const float *a,
int lda,
const float *x,
int incx,
float beta,
float *y,
int incy) {
293 F77_sgemv(&op, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
295 void gemv(
char op,
int m,
int n, std::complex<float> alpha,
const std::complex<float> *a,
int lda,
const std::complex<float> *x,
int incx,
296 std::complex<float> beta, std::complex<float> *y,
int incy) {
297 F77_cgemv(&op, &m, &n, blacplx(&alpha), blacplx(a), &lda, blacplx(x), &incx, blacplx(&beta), blacplx(y), &incy);
299 void gemv(
char op,
int m,
int n,
double alpha,
const double *a,
int lda,
const double *x,
int incx,
double beta,
double *y,
int incy) {
300 F77_dgemv(&op, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
302 void gemv(
char op,
int m,
int n, std::complex<double> alpha,
const std::complex<double> *a,
int lda,
const std::complex<double> *x,
int incx,
303 std::complex<double> beta, std::complex<double> *y,
int incy) {
304 F77_zgemv(&op, &m, &n, blacplx(&alpha), blacplx(a), &lda, blacplx(x), &incx, blacplx(&beta), blacplx(y), &incy);
308 void ger(
int m,
int n,
float alpha,
const float *x,
int incx,
const float *y,
int incy,
float *a,
int lda) {
309 F77_sger(&m, &n, &alpha, x, &incx, y, &incy, a, &lda);
311 void ger(
int m,
int n, std::complex<float> alpha,
const std::complex<float> *x,
int incx,
const std::complex<float> *y,
int incy,
312 std::complex<float> *a,
int lda) {
313 F77_cgeru(&m, &n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy, blacplx(a), &lda);
315 void gerc(
int m,
int n, std::complex<float> alpha,
const std::complex<float> *x,
int incx,
const std::complex<float> *y,
int incy,
316 std::complex<float> *a,
int lda) {
317 F77_cgerc(&m, &n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy, blacplx(a), &lda);
319 void ger(
int m,
int n,
double alpha,
const double *x,
int incx,
const double *y,
int incy,
double *a,
int lda) {
320 F77_dger(&m, &n, &alpha, x, &incx, y, &incy, a, &lda);
322 void ger(
int m,
int n, std::complex<double> alpha,
const std::complex<double> *x,
int incx,
const std::complex<double> *y,
int incy,
323 std::complex<double> *a,
int lda) {
324 F77_zgeru(&m, &n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy, blacplx(a), &lda);
326 void gerc(
int m,
int n, std::complex<double> alpha,
const std::complex<double> *x,
int incx,
const std::complex<double> *y,
int incy,
327 std::complex<double> *a,
int lda) {
328 F77_zgerc(&m, &n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy, blacplx(a), &lda);
332 void scal(
int m,
float alpha,
float *x,
int incx) { F77_sscal(&m, &alpha, x, &incx); }
333 void scal(
int m, std::complex<float> alpha, std::complex<float> *x,
int incx) { F77_cscal(&m, blacplx(&alpha), blacplx(x), &incx); }
334 void scal(
int m,
double alpha,
double *x,
int incx) { F77_dscal(&m, &alpha, x, &incx); }
335 void scal(
int m, std::complex<double> alpha, std::complex<double> *x,
int incx) { F77_zscal(&m, blacplx(&alpha), blacplx(x), &incx); }
338 void swap(
int n,
float *x,
int incx,
float *y,
int incy) { F77_sswap(&n, x, &incx, y, &incy); }
339 void swap(
int n, std::complex<float> *x,
int incx, std::complex<float> *y,
int incy) {
340 F77_cswap(&n, blacplx(x), &incx, blacplx(y), &incy);
342 void swap(
int n,
double *x,
int incx,
double *y,
int incy) { F77_dswap(&n, x, &incx, y, &incy); }
343 void swap(
int n, std::complex<double> *x,
int incx, std::complex<double> *y,
int incy) {
344 F77_zswap(&n, blacplx(x), &incx, blacplx(y), &incy);
Provides the generic class for arrays.
Provides a C++ interface for various BLAS routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.