TRIQS/nda 1.3.0
Multi-dimensional array library for C++
cxx_interface.cpp
// Copyright (c) 2019-2023 Simons Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Authors: Thomas Hahn, Henri Menke, Miguel Morales, Olivier Parcollet, Nils Wentzell

/**
 * @file
 * @brief Implementation details for blas/interface/cxx_interface.hpp.
 */

// Extracted from Reference LAPACK (https://github.com/Reference-LAPACK):
#include "./cblas_f77.h"
#include "./cxx_interface.hpp"
#include "../tools.hpp"

#include <cstddef>

#ifdef NDA_USE_MKL
#include "../../basic_array.hpp"
#include "../../declarations.hpp"

#include <mkl.h>

namespace nda::blas {

#ifdef NDA_USE_MKL_RT
  static int const mkl_interface_layer = mkl_set_interface_layer(MKL_INTERFACE_LP64 + MKL_INTERFACE_GNU);
#endif
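  // Reinterpret pointers to nda::dcomplex (std::complex<double>) as pointers to MKL_Complex16; both types
  // store two consecutive doubles, so only the pointer type changes.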
  inline auto *mklcplx(nda::dcomplex *c) { return reinterpret_cast<MKL_Complex16 *>(c); }               // NOLINT
  inline auto *mklcplx(nda::dcomplex const *c) { return reinterpret_cast<const MKL_Complex16 *>(c); }   // NOLINT
  inline auto *mklcplx(nda::dcomplex **c) { return reinterpret_cast<MKL_Complex16 **>(c); }             // NOLINT
  inline auto *mklcplx(nda::dcomplex const **c) { return reinterpret_cast<const MKL_Complex16 **>(c); } // NOLINT

} // namespace nda::blas
#endif

namespace {

  // complex struct which is returned by BLAS functions
  struct nda_complex_double {
    double real;
    double imag;
  };

} // namespace

// manually declare the Fortran dot routines, since cblas_f77.h only wraps them via their "_sub" CBLAS variants
#define F77_ddot F77_GLOBAL(ddot, DDOT)
#define F77_zdotu F77_GLOBAL(zdotu, ZDOTU)
#define F77_zdotc F77_GLOBAL(zdotc, ZDOTC)
extern "C" {
double F77_ddot(FINT, const double *, FINT, const double *, FINT);
nda_complex_double F77_zdotu(FINT, const double *, FINT, const double *, FINT);
nda_complex_double F77_zdotc(FINT, const double *, FINT, const double *, FINT);
}
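
// Note: zdotu/zdotc return a complex value by value, a convention that is not uniform across Fortran
// compilers and ABIs; when MKL is available, dot/dotc below use cblas_zdotu_sub/cblas_zdotc_sub, which
// return the result through an output argument instead.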

namespace nda::blas::f77 {

  inline auto *blacplx(dcomplex *c) { return reinterpret_cast<double *>(c); }                // NOLINT
  inline auto *blacplx(dcomplex const *c) { return reinterpret_cast<const double *>(c); }    // NOLINT
  inline auto **blacplx(dcomplex **c) { return reinterpret_cast<double **>(c); }             // NOLINT
  inline auto **blacplx(dcomplex const **c) { return reinterpret_cast<const double **>(c); } // NOLINT

  void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy) { F77_daxpy(&N, &alpha, x, &incx, Y, &incy); }
  void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy) {
    F77_zaxpy(&N, blacplx(&alpha), blacplx(x), &incx, blacplx(Y), &incy);
  }

  // note: the underlying Fortran interface carries no const qualifiers
  void copy(int N, const double *x, int incx, double *Y, int incy) { F77_dcopy(&N, x, &incx, Y, &incy); }
  void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy) { F77_zcopy(&N, blacplx(x), &incx, blacplx(Y), &incy); }

  double dot(int M, const double *x, int incx, const double *Y, int incy) { return F77_ddot(&M, x, &incx, Y, &incy); }
  dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
#ifdef NDA_USE_MKL
    MKL_Complex16 result;
    cblas_zdotu_sub(M, mklcplx(x), incx, mklcplx(Y), incy, &result);
#else
    auto result = F77_zdotu(&M, blacplx(x), &incx, blacplx(Y), &incy);
#endif
    return dcomplex{result.real, result.imag};
  }
  dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
#ifdef NDA_USE_MKL
    MKL_Complex16 result;
    cblas_zdotc_sub(M, mklcplx(x), incx, mklcplx(Y), incy, &result);
#else
    auto result = F77_zdotc(&M, blacplx(x), &incx, blacplx(Y), &incy);
#endif
    return dcomplex{result.real, result.imag};
  }

  void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
            int LDC) {
    F77_dgemm(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC);
  }
  void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
            dcomplex *C, int LDC) {
    F77_zgemm(&op_a, &op_b, &M, &N, &K, blacplx(&alpha), blacplx(A), &LDA, blacplx(B), &LDB, blacplx(&beta), blacplx(C), &LDC);
  }

  void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
                  double **C, int LDC, int batch_count) {
#ifdef NDA_USE_MKL
    const int group_count = 1;
    dgemm_batch(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC, &group_count, &batch_count);
#else // Fallback to loop
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M, N, K, alpha, A[i], LDA, B[i], LDB, beta, C[i], LDC);
#endif
  }
  void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
                  dcomplex **C, int LDC, int batch_count) {
#ifdef NDA_USE_MKL
    const int group_count = 1;
    zgemm_batch(&op_a, &op_b, &M, &N, &K, mklcplx(&alpha), mklcplx(A), &LDA, mklcplx(B), &LDB, mklcplx(&beta), mklcplx(C), &LDC, &group_count,
                &batch_count);
#else // Fallback to loop
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M, N, K, alpha, A[i], LDA, B[i], LDB, beta, C[i], LDC);
#endif
  }

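  // The variable-size batched gemm below maps onto MKL's group API with one group of size 1 per matrix
  // (group_count = batch_count), so that dimensions and leading dimensions may differ between matrices.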
  void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
                   double **C, int *LDC, int batch_count) {
#ifdef NDA_USE_MKL
    nda::vector<int> group_size(batch_count, 1);
    nda::vector<char> ops_a(batch_count, op_a), ops_b(batch_count, op_b);
    nda::vector<double> alphas(batch_count, alpha), betas(batch_count, beta);
    dgemm_batch(ops_a.data(), ops_b.data(), M, N, K, alphas.data(), A, LDA, B, LDB, betas.data(), C, LDC, &batch_count, group_size.data());
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M[i], N[i], K[i], alpha, A[i], LDA[i], B[i], LDB[i], beta, C[i], LDC[i]);
#endif
  }
  void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
                   dcomplex beta, dcomplex **C, int *LDC, int batch_count) {
#ifdef NDA_USE_MKL
    nda::vector<int> group_size(batch_count, 1);
    nda::vector<char> ops_a(batch_count, op_a), ops_b(batch_count, op_b);
    nda::vector<dcomplex> alphas(batch_count, alpha), betas(batch_count, beta);
    zgemm_batch(ops_a.data(), ops_b.data(), M, N, K, mklcplx(alphas.data()), mklcplx(A), LDA, mklcplx(B), LDB, mklcplx(betas.data()), mklcplx(C), LDC,
                &batch_count, group_size.data());
#else
    for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M[i], N[i], K[i], alpha, A[i], LDA[i], B[i], LDB[i], beta, C[i], LDC[i]);
#endif
  }

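  // The strided batched gemm uses MKL's [dz]gemm_batch_strided, available from MKL 2020 Update 2 on;
  // older MKL versions and non-MKL builds fall back to a plain loop over gemm.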
  void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
                          int strideB, double beta, double *C, int LDC, int strideC, int batch_count) {
#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
    dgemm_batch_strided(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, &strideA, B, &LDB, &strideB, &beta, C, &LDC, &strideC, &batch_count);
#else
    for (int i = 0; i < batch_count; ++i)
      gemm(op_a, op_b, M, N, K, alpha, A + static_cast<ptrdiff_t>(i) * strideA, LDA, B + static_cast<ptrdiff_t>(i) * strideB, LDB, beta,
           C + static_cast<ptrdiff_t>(i) * strideC, LDC);
#endif
  }
  void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
                          int LDB, int strideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count) {
#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
    zgemm_batch_strided(&op_a, &op_b, &M, &N, &K, mklcplx(&alpha), mklcplx(A), &LDA, &strideA, mklcplx(B), &LDB, &strideB, mklcplx(&beta), mklcplx(C),
                        &LDC, &strideC, &batch_count);
#else
    for (int i = 0; i < batch_count; ++i)
      gemm(op_a, op_b, M, N, K, alpha, A + static_cast<ptrdiff_t>(i) * strideA, LDA, B + static_cast<ptrdiff_t>(i) * strideB, LDB, beta,
           C + static_cast<ptrdiff_t>(i) * strideC, LDC);
#endif
  }

  void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy) {
    F77_dgemv(&op, &M, &N, &alpha, A, &LDA, x, &incx, &beta, Y, &incy);
  }
  void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy) {
    F77_zgemv(&op, &M, &N, blacplx(&alpha), blacplx(A), &LDA, blacplx(x), &incx, blacplx(&beta), blacplx(Y), &incy);
  }

  void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA) {
    F77_dger(&M, &N, &alpha, x, &incx, Y, &incy, A, &LDA);
  }
  void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA) {
    F77_zgeru(&M, &N, blacplx(&alpha), blacplx(x), &incx, blacplx(Y), &incy, blacplx(A), &LDA);
  }

  void scal(int M, double alpha, double *x, int incx) { F77_dscal(&M, &alpha, x, &incx); }
  void scal(int M, dcomplex alpha, dcomplex *x, int incx) { F77_zscal(&M, blacplx(&alpha), blacplx(x), &incx); }

  void swap(int N, double *x, int incx, double *Y, int incy) { F77_dswap(&N, x, &incx, Y, &incy); } // NOLINT (this is a BLAS swap)
  void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy) {                                  // NOLINT (this is a BLAS swap)
    F77_zswap(&N, blacplx(x), &incx, blacplx(Y), &incy);
  }

} // namespace nda::blas::f77
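
For orientation, here is a minimal, hypothetical usage sketch of the wrappers defined above. It is not part of this file; the include path is an assumption about where the header is installed, and it requires linking against an nda build with a BLAS backend. Matrices are stored column-major, as the underlying Fortran routines expect.

// Hypothetical usage example: multiply two 2x2 column-major matrices and take a dot product
// through the low-level nda::blas::f77 wrappers.
#include <nda/blas/interface/cxx_interface.hpp> // assumed install location of this header
#include <vector>

int main() {
  int n = 2;
  std::vector<double> A = {1.0, 2.0, 3.0, 4.0}; // 2x2, column-major
  std::vector<double> B = {1.0, 0.0, 0.0, 1.0}; // 2x2 identity
  std::vector<double> C(4, 0.0);

  // C = 1.0 * A * B + 0.0 * C
  nda::blas::f77::gemm('N', 'N', n, n, n, 1.0, A.data(), n, B.data(), n, 0.0, C.data(), n);

  // d = sum_i A[i] * B[i] = 1 + 4 = 5
  double d = nda::blas::f77::dot(4, A.data(), 1, B.data(), 1);
  return (d == 5.0 && C[0] == 1.0) ? 0 : 1;
}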