TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cxx_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11// Extracted from Reference Lapack (https://github.com/Reference-LAPACK):
12#include "./cblas_f77.h"
13#include "./cxx_interface.hpp"
14#include "../tools.hpp"
15
16#include <cstddef>
17
18#ifdef NDA_USE_MKL
19#include "../../basic_array.hpp"
21
22#include <mkl.h>
23
24namespace nda::blas {
25
26#ifdef NDA_USE_MKL_RT
27 static int const mkl_interface_layer = mkl_set_interface_layer(MKL_INTERFACE_LP64 + MKL_INTERFACE_GNU);
28#endif
29 inline auto *mklcplx(nda::dcomplex *c) { return reinterpret_cast<MKL_Complex16 *>(c); } // NOLINT
30 inline auto *mklcplx(nda::dcomplex const *c) { return reinterpret_cast<const MKL_Complex16 *>(c); } // NOLINT
31 inline auto *mklcplx(nda::dcomplex **c) { return reinterpret_cast<MKL_Complex16 **>(c); } // NOLINT
32 inline auto *mklcplx(nda::dcomplex const **c) { return reinterpret_cast<const MKL_Complex16 **>(c); } // NOLINT
33
34} // namespace nda::blas
35#endif
36
namespace {

  // Return type of the hand-declared Fortran complex dot routines (F77_zdotu / F77_zdotc), which are called in the
  // non-MKL build. Holds the (real, imag) parts as two consecutive doubles.
  // NOTE(review): this relies on the platform ABI returning a Fortran COMPLEX*16 the same way as a C struct of two
  // doubles -- confirm when porting to a new architecture.
  struct nda_complex_double {
    double real, imag;
  };

} // namespace
46
47// manually define dot routines since cblas_f77.h uses "_sub" to wrap the Fortran routines
48#define F77_ddot F77_GLOBAL(ddot, DDOT)
49#define F77_zdotu F77_GLOBAL(zdotu, ZDOTU)
50#define F77_zdotc F77_GLOBAL(zdotc, ZDOTC)
51extern "C" {
52double F77_ddot(FINT, const double *, FINT, const double *, FINT);
53nda_complex_double F77_zdotu(FINT, const double *, FINT, const double *, FINT);
54nda_complex_double F77_zdotc(FINT, const double *, FINT, const double *, FINT);
55}
56
57namespace nda::blas::f77 {
58
59 inline auto *blacplx(dcomplex *c) { return reinterpret_cast<double *>(c); } // NOLINT
60 inline auto *blacplx(dcomplex const *c) { return reinterpret_cast<const double *>(c); } // NOLINT
61 inline auto **blacplx(dcomplex **c) { return reinterpret_cast<double **>(c); } // NOLINT
62 inline auto **blacplx(dcomplex const **c) { return reinterpret_cast<const double **>(c); } // NOLINT
63
64 void axpy(int n, double alpha, const double *x, int incx, double *y, int incy) { F77_daxpy(&n, &alpha, x, &incx, y, &incy); }
65 void axpy(int n, dcomplex alpha, const dcomplex *x, int incx, dcomplex *y, int incy) {
66 F77_zaxpy(&n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy);
67 }
68
69 // No Const In Wrapping!
70 void copy(int n, const double *x, int incx, double *y, int incy) { F77_dcopy(&n, x, &incx, y, &incy); }
71 void copy(int n, const dcomplex *x, int incx, dcomplex *y, int incy) { F77_zcopy(&n, blacplx(x), &incx, blacplx(y), &incy); }
72
73 double dot(int m, const double *x, int incx, const double *y, int incy) { return F77_ddot(&m, x, &incx, y, &incy); }
74 dcomplex dot(int m, const dcomplex *x, int incx, const dcomplex *y, int incy) {
75#ifdef NDA_USE_MKL
76 MKL_Complex16 result;
77 cblas_zdotu_sub(m, mklcplx(x), incx, mklcplx(y), incy, &result);
78#else
79 auto result = F77_zdotu(&m, blacplx(x), &incx, blacplx(y), &incy);
80#endif
81 return dcomplex{result.real, result.imag};
82 }
83 dcomplex dotc(int m, const dcomplex *x, int incx, const dcomplex *y, int incy) {
84#ifdef NDA_USE_MKL
85 MKL_Complex16 result;
86 cblas_zdotc_sub(m, mklcplx(x), incx, mklcplx(y), incy, &result);
87#else
88 auto result = F77_zdotc(&m, blacplx(x), &incx, blacplx(y), &incy);
89#endif
90 return dcomplex{result.real, result.imag};
91 }
92
93 void gemm(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, const double *b, int ldb, double beta, double *c,
94 int ldc) {
95 F77_dgemm(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
96 }
97 void gemm(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *b, int ldb, dcomplex beta,
98 dcomplex *c, int ldc) {
99 F77_zgemm(&op_a, &op_b, &m, &n, &k, blacplx(&alpha), blacplx(a), &lda, blacplx(b), &ldb, blacplx(&beta), blacplx(c), &ldc);
100 }
101
102 void gemm_batch(char op_a, char op_b, int m, int n, int k, double alpha, const double **a, int lda, const double **b, int ldb, double beta,
103 double **c, int ldc, int batch_count) {
104#ifdef NDA_USE_MKL
105 const int group_count = 1;
106 dgemm_batch(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, &group_count, &batch_count);
107#else // Fallback to loop
108 for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m, n, k, alpha, a[i], lda, b[i], ldb, beta, c[i], ldc);
109#endif
110 }
111 void gemm_batch(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex **a, int lda, const dcomplex **b, int ldb, dcomplex beta,
112 dcomplex **c, int ldc, int batch_count) {
113#ifdef NDA_USE_MKL
114 const int group_count = 1;
115 zgemm_batch(&op_a, &op_b, &m, &n, &k, mklcplx(&alpha), mklcplx(a), &lda, mklcplx(b), &ldb, mklcplx(&beta), mklcplx(c), &ldc, &group_count,
116 &batch_count);
117#else
118 for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m, n, k, alpha, a[i], lda, b[i], ldb, beta, c[i], ldc);
119#endif
120 }
121
122 void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, double alpha, const double **a, int *lda, const double **b, int *ldb, double beta,
123 double **c, int *ldc, int batch_count) {
124#ifdef NDA_USE_MKL
125 nda::vector<int> group_size(batch_count, 1);
126 nda::vector<char> ops_a(batch_count, op_a), ops_b(batch_count, op_b);
127 nda::vector<double> alphas(batch_count, alpha), betas(batch_count, beta);
128 dgemm_batch(ops_a.data(), ops_b.data(), m, n, k, alphas.data(), a, lda, b, ldb, betas.data(), c, ldc, &batch_count, group_size.data());
129#else
130 for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m[i], n[i], k[i], alpha, a[i], lda[i], b[i], ldb[i], beta, c[i], ldc[i]);
131#endif
132 }
133 void gemm_vbatch(char op_a, char op_b, int *m, int *n, int *k, dcomplex alpha, const dcomplex **a, int *lda, const dcomplex **b, int *ldb,
134 dcomplex beta, dcomplex **c, int *ldc, int batch_count) {
135#ifdef NDA_USE_MKL
136 nda::vector<int> group_size(batch_count, 1);
137 nda::vector<char> ops_a(batch_count, op_a), ops_b(batch_count, op_b);
138 nda::vector<dcomplex> alphas(batch_count, alpha), betas(batch_count, beta);
139 zgemm_batch(ops_a.data(), ops_b.data(), m, n, k, mklcplx(alphas.data()), mklcplx(a), lda, mklcplx(b), ldb, mklcplx(betas.data()), mklcplx(c), ldc,
140 &batch_count, group_size.data());
141#else
142 for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, m[i], n[i], k[i], alpha, a[i], lda[i], b[i], ldb[i], beta, c[i], ldc[i]);
143#endif
144 }
145
146 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, double alpha, const double *a, int lda, int stride_a, const double *b, int ldb,
147 int stride_b, double beta, double *c, int ldc, int stride_c, int batch_count) {
148#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
149 dgemm_batch_strided(&op_a, &op_b, &m, &n, &k, &alpha, a, &lda, &stride_a, b, &ldb, &stride_b, &beta, c, &ldc, &stride_c, &batch_count);
150#else
151 for (int i = 0; i < batch_count; ++i)
152 gemm(op_a, op_b, m, n, k, alpha, a + static_cast<ptrdiff_t>(i * stride_a), lda, b + static_cast<ptrdiff_t>(i * stride_b), ldb, beta,
153 c + static_cast<ptrdiff_t>(i * stride_c), ldc);
154#endif
155 }
156 void gemm_batch_strided(char op_a, char op_b, int m, int n, int k, dcomplex alpha, const dcomplex *a, int lda, int stride_a, const dcomplex *b,
157 int ldb, int stride_b, dcomplex beta, dcomplex *c, int ldc, int stride_c, int batch_count) {
158#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
159 zgemm_batch_strided(&op_a, &op_b, &m, &n, &k, mklcplx(&alpha), mklcplx(a), &lda, &stride_a, mklcplx(b), &ldb, &stride_b, mklcplx(&beta), mklcplx(c),
160 &ldc, &stride_c, &batch_count);
161#else
162 for (int i = 0; i < batch_count; ++i)
163 gemm(op_a, op_b, m, n, k, alpha, a + static_cast<ptrdiff_t>(i * stride_a), lda, b + static_cast<ptrdiff_t>(i * stride_b), ldb, beta,
164 c + static_cast<ptrdiff_t>(i * stride_c), ldc);
165#endif
166 }
167
168 void gemv(char op, int m, int n, double alpha, const double *a, int lda, const double *x, int incx, double beta, double *y, int incy) {
169 F77_dgemv(&op, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
170 }
171 void gemv(char op, int m, int n, dcomplex alpha, const dcomplex *a, int lda, const dcomplex *x, int incx, dcomplex beta, dcomplex *y, int incy) {
172 F77_zgemv(&op, &m, &n, blacplx(&alpha), blacplx(a), &lda, blacplx(x), &incx, blacplx(&beta), blacplx(y), &incy);
173 }
174
175 void ger(int m, int n, double alpha, const double *x, int incx, const double *y, int incy, double *a, int lda) {
176 F77_dger(&m, &n, &alpha, x, &incx, y, &incy, a, &lda);
177 }
178 void ger(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda) {
179 F77_zgeru(&m, &n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy, blacplx(a), &lda);
180 }
181 void gerc(int m, int n, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *y, int incy, dcomplex *a, int lda) {
182 F77_zgerc(&m, &n, blacplx(&alpha), blacplx(x), &incx, blacplx(y), &incy, blacplx(a), &lda);
183 }
184
185 void scal(int m, double alpha, double *x, int incx) { F77_dscal(&m, &alpha, x, &incx); }
186 void scal(int m, dcomplex alpha, dcomplex *x, int incx) { F77_zscal(&m, blacplx(&alpha), blacplx(x), &incx); }
187
188 void swap(int n, double *x, int incx, double *y, int incy) { F77_dswap(&n, x, &incx, y, &incy); } // NOLINT (this is a BLAS swap)
189 void swap(int n, dcomplex *x, int incx, dcomplex *y, int incy) { // NOLINT (this is a BLAS swap)
190 F77_zswap(&n, blacplx(x), &incx, blacplx(y), &incy);
191 }
192
193} // namespace nda::blas::f77
Provides the generic class for arrays.
Provides a C++ interface for various BLAS routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:28
Provides various traits and utilities for the BLAS interface.