TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cxx_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11// Extracted from Reference Lapack (https://github.com/Reference-LAPACK):
12#include "./cblas_f77.h"
13#include "./cxx_interface.hpp"
14#include "../tools.hpp"
15
16#include <cstddef>
17
#ifdef NDA_USE_MKL
#include "../../basic_array.hpp"

#include <mkl.h>

namespace nda::blas {

#ifdef NDA_USE_MKL_RT
  // When linking against the MKL single dynamic library (mkl_rt), pin the
  // LP64 + GNU-threading interface once at static-initialization time.
  static int const mkl_interface_layer = mkl_set_interface_layer(MKL_INTERFACE_LP64 + MKL_INTERFACE_GNU);
#endif
  // Reinterpret std::complex<double> pointers as MKL_Complex16 pointers for the
  // MKL C interface; both types are layout-compatible with two adjacent doubles.
  inline auto *mklcplx(nda::dcomplex *c) { return reinterpret_cast<MKL_Complex16 *>(c); } // NOLINT
  inline auto *mklcplx(nda::dcomplex const *c) { return reinterpret_cast<const MKL_Complex16 *>(c); } // NOLINT
  inline auto *mklcplx(nda::dcomplex **c) { return reinterpret_cast<MKL_Complex16 **>(c); } // NOLINT
  inline auto *mklcplx(nda::dcomplex const **c) { return reinterpret_cast<const MKL_Complex16 **>(c); } // NOLINT

} // namespace nda::blas
#endif
36
namespace {

  // POD mirror of a Fortran COMPLEX*16 value, used as the return type of the
  // zdotu/zdotc routines declared below: two adjacent doubles (re, im).
  struct nda_complex_double {
    double real;
    double imag;
  };

  // Guard the ABI assumption: the struct must be exactly two packed doubles,
  // otherwise the extern "C" zdot{u,c} declarations would corrupt the result.
  static_assert(sizeof(nda_complex_double) == 2 * sizeof(double), "nda_complex_double must have the layout of a Fortran COMPLEX*16");

} // namespace
46
// manually define dot routines since cblas_f77.h uses "_sub" to wrap the Fortran routines
#define F77_ddot F77_GLOBAL(ddot, DDOT)
#define F77_zdotu F77_GLOBAL(zdotu, ZDOTU)
#define F77_zdotc F77_GLOBAL(zdotc, ZDOTC)
// Fortran BLAS dot-product entry points with C linkage. Complex vector
// arguments are passed as plain double pointers (interleaved re/im); the
// complex return value is modeled by the two-double struct defined above.
// NOTE(review): FINT comes from cblas_f77.h — presumably a pointer-to-Fortran-
// integer type; confirm against that header.
extern "C" {
double F77_ddot(FINT, const double *, FINT, const double *, FINT);
nda_complex_double F77_zdotu(FINT, const double *, FINT, const double *, FINT);
nda_complex_double F77_zdotc(FINT, const double *, FINT, const double *, FINT);
}
56
57namespace nda::blas::f77 {
58
59 inline auto *blacplx(dcomplex *c) { return reinterpret_cast<double *>(c); } // NOLINT
60 inline auto *blacplx(dcomplex const *c) { return reinterpret_cast<const double *>(c); } // NOLINT
61 inline auto **blacplx(dcomplex **c) { return reinterpret_cast<double **>(c); } // NOLINT
62 inline auto **blacplx(dcomplex const **c) { return reinterpret_cast<const double **>(c); } // NOLINT
63
64 void axpy(int N, double alpha, const double *x, int incx, double *Y, int incy) { F77_daxpy(&N, &alpha, x, &incx, Y, &incy); }
65 void axpy(int N, dcomplex alpha, const dcomplex *x, int incx, dcomplex *Y, int incy) {
66 F77_zaxpy(&N, blacplx(&alpha), blacplx(x), &incx, blacplx(Y), &incy);
67 }
68
69 // No Const In Wrapping!
70 void copy(int N, const double *x, int incx, double *Y, int incy) { F77_dcopy(&N, x, &incx, Y, &incy); }
71 void copy(int N, const dcomplex *x, int incx, dcomplex *Y, int incy) { F77_zcopy(&N, blacplx(x), &incx, blacplx(Y), &incy); }
72
73 double dot(int M, const double *x, int incx, const double *Y, int incy) { return F77_ddot(&M, x, &incx, Y, &incy); }
74 dcomplex dot(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
75#ifdef NDA_USE_MKL
76 MKL_Complex16 result;
77 cblas_zdotu_sub(M, mklcplx(x), incx, mklcplx(Y), incy, &result);
78#else
79 auto result = F77_zdotu(&M, blacplx(x), &incx, blacplx(Y), &incy);
80#endif
81 return dcomplex{result.real, result.imag};
82 }
83 dcomplex dotc(int M, const dcomplex *x, int incx, const dcomplex *Y, int incy) {
84#ifdef NDA_USE_MKL
85 MKL_Complex16 result;
86 cblas_zdotc_sub(M, mklcplx(x), incx, mklcplx(Y), incy, &result);
87#else
88 auto result = F77_zdotc(&M, blacplx(x), &incx, blacplx(Y), &incy);
89#endif
90 return dcomplex{result.real, result.imag};
91 }
92
93 void gemm(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C,
94 int LDC) {
95 F77_dgemm(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC);
96 }
97 void gemm(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *B, int LDB, dcomplex beta,
98 dcomplex *C, int LDC) {
99 F77_zgemm(&op_a, &op_b, &M, &N, &K, blacplx(&alpha), blacplx(A), &LDA, blacplx(B), &LDB, blacplx(&beta), blacplx(C), &LDC);
100 }
101
102 void gemm_batch(char op_a, char op_b, int M, int N, int K, double alpha, const double **A, int LDA, const double **B, int LDB, double beta,
103 double **C, int LDC, int batch_count) {
104#ifdef NDA_USE_MKL
105 const int group_count = 1;
106 dgemm_batch(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC, &group_count, &batch_count);
107#else // Fallback to loop
108 for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M, N, K, alpha, A[i], LDA, B[i], LDB, beta, C[i], LDC);
109#endif
110 }
111 void gemm_batch(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex **A, int LDA, const dcomplex **B, int LDB, dcomplex beta,
112 dcomplex **C, int LDC, int batch_count) {
113#ifdef NDA_USE_MKL
114 const int group_count = 1;
115 zgemm_batch(&op_a, &op_b, &M, &N, &K, mklcplx(&alpha), mklcplx(A), &LDA, mklcplx(B), &LDB, mklcplx(&beta), mklcplx(C), &LDC, &group_count,
116 &batch_count);
117#else
118 for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M, N, K, alpha, A[i], LDA, B[i], LDB, beta, C[i], LDC);
119#endif
120 }
121
122 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, double alpha, const double **A, int *LDA, const double **B, int *LDB, double beta,
123 double **C, int *LDC, int batch_count) {
124#ifdef NDA_USE_MKL
125 nda::vector<int> group_size(batch_count, 1);
126 nda::vector<char> ops_a(batch_count, op_a), ops_b(batch_count, op_b);
127 nda::vector<double> alphas(batch_count, alpha), betas(batch_count, beta);
128 dgemm_batch(ops_a.data(), ops_b.data(), M, N, K, alphas.data(), A, LDA, B, LDB, betas.data(), C, LDC, &batch_count, group_size.data());
129#else
130 for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M[i], N[i], K[i], alpha, A[i], LDA[i], B[i], LDB[i], beta, C[i], LDC[i]);
131#endif
132 }
133 void gemm_vbatch(char op_a, char op_b, int *M, int *N, int *K, dcomplex alpha, const dcomplex **A, int *LDA, const dcomplex **B, int *LDB,
134 dcomplex beta, dcomplex **C, int *LDC, int batch_count) {
135#ifdef NDA_USE_MKL
136 nda::vector<int> group_size(batch_count, 1);
137 nda::vector<char> ops_a(batch_count, op_a), ops_b(batch_count, op_b);
138 nda::vector<dcomplex> alphas(batch_count, alpha), betas(batch_count, beta);
139 zgemm_batch(ops_a.data(), ops_b.data(), M, N, K, mklcplx(alphas.data()), mklcplx(A), LDA, mklcplx(B), LDB, mklcplx(betas.data()), mklcplx(C), LDC,
140 &batch_count, group_size.data());
141#else
142 for (int i = 0; i < batch_count; ++i) gemm(op_a, op_b, M[i], N[i], K[i], alpha, A[i], LDA[i], B[i], LDB[i], beta, C[i], LDC[i]);
143#endif
144 }
145
146 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, double alpha, const double *A, int LDA, int strideA, const double *B, int LDB,
147 int strideB, double beta, double *C, int LDC, int strideC, int batch_count) {
148#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
149 dgemm_batch_strided(&op_a, &op_b, &M, &N, &K, &alpha, A, &LDA, &strideA, B, &LDB, &strideB, &beta, C, &LDC, &strideC, &batch_count);
150#else
151 for (int i = 0; i < batch_count; ++i)
152 gemm(op_a, op_b, M, N, K, alpha, A + static_cast<ptrdiff_t>(i * strideA), LDA, B + static_cast<ptrdiff_t>(i * strideB), LDB, beta,
153 C + static_cast<ptrdiff_t>(i * strideC), LDC);
154#endif
155 }
156 void gemm_batch_strided(char op_a, char op_b, int M, int N, int K, dcomplex alpha, const dcomplex *A, int LDA, int strideA, const dcomplex *B,
157 int LDB, int strideB, dcomplex beta, dcomplex *C, int LDC, int strideC, int batch_count) {
158#if defined(NDA_USE_MKL) && INTEL_MKL_VERSION >= 20200002
159 zgemm_batch_strided(&op_a, &op_b, &M, &N, &K, mklcplx(&alpha), mklcplx(A), &LDA, &strideA, mklcplx(B), &LDB, &strideB, mklcplx(&beta), mklcplx(C),
160 &LDC, &strideC, &batch_count);
161#else
162 for (int i = 0; i < batch_count; ++i)
163 gemm(op_a, op_b, M, N, K, alpha, A + static_cast<ptrdiff_t>(i * strideA), LDA, B + static_cast<ptrdiff_t>(i * strideB), LDB, beta,
164 C + static_cast<ptrdiff_t>(i * strideC), LDC);
165#endif
166 }
167
168 void gemv(char op, int M, int N, double alpha, const double *A, int LDA, const double *x, int incx, double beta, double *Y, int incy) {
169 F77_dgemv(&op, &M, &N, &alpha, A, &LDA, x, &incx, &beta, Y, &incy);
170 }
171 void gemv(char op, int M, int N, dcomplex alpha, const dcomplex *A, int LDA, const dcomplex *x, int incx, dcomplex beta, dcomplex *Y, int incy) {
172 F77_zgemv(&op, &M, &N, blacplx(&alpha), blacplx(A), &LDA, blacplx(x), &incx, blacplx(&beta), blacplx(Y), &incy);
173 }
174
175 void ger(int M, int N, double alpha, const double *x, int incx, const double *Y, int incy, double *A, int LDA) {
176 F77_dger(&M, &N, &alpha, x, &incx, Y, &incy, A, &LDA);
177 }
178 void ger(int M, int N, dcomplex alpha, const dcomplex *x, int incx, const dcomplex *Y, int incy, dcomplex *A, int LDA) {
179 F77_zgeru(&M, &N, blacplx(&alpha), blacplx(x), &incx, blacplx(Y), &incy, blacplx(A), &LDA);
180 }
181
182 void scal(int M, double alpha, double *x, int incx) { F77_dscal(&M, &alpha, x, &incx); }
183 void scal(int M, dcomplex alpha, dcomplex *x, int incx) { F77_zscal(&M, blacplx(&alpha), blacplx(x), &incx); }
184
185 void swap(int N, double *x, int incx, double *Y, int incy) { F77_dswap(&N, x, &incx, Y, &incy); } // NOLINT (this is a BLAS swap)
186 void swap(int N, dcomplex *x, int incx, dcomplex *Y, int incy) { // NOLINT (this is a BLAS swap)
187 F77_zswap(&N, blacplx(x), &incx, blacplx(Y), &incy);
188 }
189
190} // namespace nda::blas::f77
Provides the generic class for arrays.
Provides a C++ interface for various BLAS routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_vi...
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
std::complex< double > dcomplex
Alias for std::complex<double> type.
Definition tools.hpp:28
Provides various traits and utilities for the BLAS interface.