18
19
20
24#include "./interface/cxx_interface.hpp"
25#include "../concepts.hpp"
26#include "../macros.hpp"
27#include "../mapped_functions.hpp"
28#include "../mem/address_space.hpp"
29#include "../traits.hpp"
31#ifndef NDA_HAVE_DEVICE
32#include "../device.hpp"
40
41
42
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Interface to the BLAS dot routine.
// - If either argument is a scalar, the `Scalar<X> or Scalar<Y>` branch is
//   taken (its body is not visible in this extract).
// - Otherwise: compile-time checks on value types, memory address spaces and
//   BLAS compatibility, a runtime shape check, then dispatch to device::dot
//   (device-compatible address spaces, only when NDA_HAVE_DEVICE is defined)
//   or to f77::dot (host), passing the size, data pointers and first strides.
// Fix: the second constraint group previously read `MemoryVector<X>`, so Y was
// never required to be a MemoryVector and a (scalar, vector) call was rejected
// although the scalar branch below is written to handle it.
59 template <
typename X,
typename Y>
60 requires((Scalar<X>
or MemoryVector<X>)
and (Scalar<Y>
// constrain Y here, not X (was MemoryVector<X>)
or MemoryVector<Y>))
61 auto dot(X
const &x, Y
const &y) {
62 if constexpr (Scalar<X>
or Scalar<Y>) {
66 static_assert(have_same_value_type_v<X, Y>,
"Error in nda::blas::dot: Incompatible value types");
67 static_assert(mem::have_compatible_addr_space<X, Y>,
"Error in nda::blas::dot: Incompatible memory address spaces");
68 static_assert(is_blas_lapack_v<get_value_t<X>>,
"Error in nda::blas::dot: Value types incompatible with blas");
71 EXPECTS(x.shape() == y.shape());
73 if constexpr (mem::have_device_compatible_addr_space<X, Y>) {
74#if defined(NDA_HAVE_DEVICE)
75 return device::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
// NOTE(review): presumably the fallback when NDA_HAVE_DEVICE is not defined;
// the surrounding #else is not visible in this extract.
78 return get_value_t<X>(0);
81 return f77::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// Interface to the BLAS dotc routine.
// - If either argument is a scalar, the `Scalar<X> or Scalar<Y>` branch is
//   taken (its body is not visible in this extract).
// - Otherwise: compile-time checks on value types, memory address spaces and
//   BLAS compatibility and a runtime shape check. A separate branch handles
//   non-complex value types (its body is not visible here). Complex values
//   dispatch to device::dotc (device-compatible address spaces, only when
//   NDA_HAVE_DEVICE is defined) or to f77::dotc (host).
// Fix: the second constraint group previously read `MemoryVector<X>`, so Y was
// never required to be a MemoryVector and a (scalar, vector) call was rejected
// although the scalar branch below is written to handle it.
101 template <
typename X,
typename Y>
102 requires((Scalar<X>
or MemoryVector<X>)
and (Scalar<Y>
// constrain Y here, not X (was MemoryVector<X>)
or MemoryVector<Y>))
103 auto dotc(X
const &x, Y
const &y) {
104 if constexpr (Scalar<X>
or Scalar<Y>) {
108 static_assert(have_same_value_type_v<X, Y>,
"Error in nda::blas::dotc: Incompatible value types");
109 static_assert(mem::have_compatible_addr_space<X, Y>,
"Error in nda::blas::dotc: Incompatible memory address spaces");
110 static_assert(is_blas_lapack_v<get_value_t<X>>,
"Error in nda::blas::dotc: Value types incompatible with blas");
113 EXPECTS(x.shape() == y.shape());
115 if constexpr (!is_complex_v<get_value_t<X>>) {
117 }
else if constexpr (mem::have_device_compatible_addr_space<X, Y>) {
118#if defined(NDA_HAVE_DEVICE)
119 return device::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
// NOTE(review): presumably the fallback when NDA_HAVE_DEVICE is not defined;
// the surrounding #else is not visible in this extract.
122 return get_value_t<X>(0);
125 return f77::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
// Generic (non-BLAS) dot-product kernel, instantiated as _dot_impl<false> by
// dot_generic and _dot_impl<true> by dotc_generic.
// - _conj is applied to elements of x; it takes a special path only when
//   `star` is set and the element type is complex (branch bodies not visible).
// - When both layouts have smallest stride one and both are regular/view
//   types, the sum is accumulated through raw __restrict pointers; the
//   visible fallback accumulates via element access x(i) / y(i).
// Change: the pointer loop index is now `long`, matching N and the fallback
// loop, instead of `size_t` (signed/unsigned comparison with N).
// NOTE(review): the accumulator is seeded from element 0, so an empty input
// (N == 0) reads out of bounds; confirm callers never pass empty vectors.
133 template <
bool star,
typename X,
typename Y>
134 auto _dot_impl(X
const &x, Y
const &y) {
135 EXPECTS(x.shape() == y.shape());
136 long N = x.shape()[0];
138 auto _conj = [](
auto z)
__attribute__((always_inline)) {
139 if constexpr (star
and is_complex_v<
decltype(z)>) {
145 if constexpr (has_layout_smallest_stride_is_one<X>
and has_layout_smallest_stride_is_one<Y>) {
146 if constexpr (is_regular_or_view_v<X>
and is_regular_or_view_v<Y>) {
147 auto *
__restrict px = x.data();
148 auto *
__restrict py = y.data();
149 auto res = _conj(px[0]) * py[0];
150 for (long i = 1; i < N; ++i) { res += _conj(px[i]) * py[i]; }
158 auto res = _conj(x(0)) * y(0);
159 for (
long i = 1; i < N; ++i) { res += _conj(x(i)) * y(i); }
167
168
169
170
171
172
173
174
// Generic implementation of nda::blas::dot for types not supported by
// BLAS/LAPACK (per the member summary). The signature line and the scalar
// branch body are not visible in this extract; the visible non-scalar path
// forwards to detail::_dot_impl with star = false.
175 template <
typename X,
typename Y>
177 if constexpr (Scalar<X>
or Scalar<Y>) {
180 return detail::_dot_impl<
false>(x, y);
185
186
187
188
189
190
191
192
// Generic implementation of nda::blas::dotc for types not supported by
// BLAS/LAPACK (per the member summary). The signature line and the scalar
// branch body are not visible in this extract; the visible non-scalar path
// forwards to detail::_dot_impl with star = true.
193 template <
typename X,
typename Y>
195 if constexpr (Scalar<X>
or Scalar<Y>) {
198 return detail::_dot_impl<
true>(x, y);
auto dot_generic(X const &x, Y const &y)
Generic implementation of nda::blas::dot for types not supported by BLAS/LAPACK.
auto dotc(X const &x, Y const &y)
Interface to the BLAS dotc routine.
auto dotc_generic(X const &x, Y const &y)
Generic implementation of nda::blas::dotc for types not supported by BLAS/LAPACK.
auto dot(X const &x, Y const &y)
Interface to the BLAS dot routine.
void compile_error_no_gpu()
Trigger a compilation error in case GPU-specific functionality is used without configuring the project with GPU support.
A small wrapper around a single long integer to be used as a linear index.