TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
dot.hpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../blas/dot.hpp"
14#include "../concepts.hpp"
15#include "../macros.hpp"
17#include "../traits.hpp"
18
19#include <complex>
20#include <cstddef>
21
22namespace nda::linalg {
23
28
// Generic loop-based dot product of two vectors. dot() and dotc() fall back to it when the
// corresponding nda::blas call is not viable for the given argument types.
//
// star : if true, each element of x is passed through std::conj before multiplication, but only
//        when that element type is complex; with star == false nothing is conjugated.
// x, y : vectors whose value types satisfy Scalar and whose address spaces are host compatible
//        (see the requires clause). Their sizes must agree (checked with EXPECTS).
// Returns cond_conj(x_0) * y_0 + ... + cond_conj(x_{N-1}) * y_{N-1}, or a zero of that product
// type when N == 0.
43 template <bool star = false, Vector X, Vector Y>
44 requires(Scalar<get_value_t<X>> and Scalar<get_value_t<Y>> and mem::have_host_compatible_addr_space<X, Y>)
45 auto dot_generic(X const &x, Y const &y) {
46 // check the dimensions of the input arrays/views
47 EXPECTS(x.size() == y.size());
48
49 // conditional conjugation
// takes its argument by value, so decltype(z) is the plain element type tested by is_complex_v;
// __attribute__((always_inline)) is the GCC/Clang request to inline this helper into the loops
50 auto cond_conj = [](auto z) __attribute__((always_inline)) {
51 if constexpr (star and is_complex_v<decltype(z)>) {
52 return std::conj(z);
53 } else {
54 return z;
55 }
56 };
57
58 // early return for zero-sized vectors
59 long const N = x.size();
// the operand of decltype is unevaluated: x(0) and y(0) are never actually called here, the
// expression only names the result type so that a zero of it can be returned
60 if (N == 0) return decltype(cond_conj(x(0)) * y(0)){0};
61
62 // loop over vectors and sum up element-wise products
// NOTE(review): the two `if constexpr (...)` lines that open the branches below (listing lines
// 63-64) are missing from this extracted listing. The symbols referenced by this file include
// has_layout_smallest_stride_is_one and is_regular_or_view_v, so presumably those traits select
// between the raw-pointer, _linear_index_t and plain-index loops. Confirm against the real header.
// Every branch seeds `res` with element 0 and accumulates from i = 1, so the accumulator has the
// exact product type without constructing a separate zero.
// Raw-pointer branch: reads the underlying data as px[i] / py[i] (contiguous indexing, presumably
// valid only for unit-stride storage). __restrict lets the compiler assume px and py do not alias;
// both are only read here.
65 auto *__restrict px = x.data();
66 auto *__restrict py = y.data();
67 auto res = cond_conj(px[0]) * py[0];
// NOTE(review): this loop compares a size_t index with the signed long N (-Wsign-compare);
// the two loops below use `long i` — consider the same here for consistency.
68 for (size_t i = 1; i < N; ++i) res += cond_conj(px[i]) * py[i];
69 return res;
70 } else {
// linear-index branch: accesses elements through the _linear_index_t wrapper
71 auto res = cond_conj(x(_linear_index_t{0})) * y(_linear_index_t{0});
72 for (long i = 1; i < N; ++i) res += cond_conj(x(_linear_index_t{i})) * y(_linear_index_t{i});
73 return res;
74 }
75 } else {
// fallback branch: ordinary element access x(i) / y(i)
76 auto res = cond_conj(x(0)) * y(0);
77 for (long i = 1; i < N; ++i) res += cond_conj(x(i)) * y(i);
78 return res;
79 }
80 }
81
102 template <typename X, typename Y>
103 requires((Scalar<X> and Scalar<Y>) or (Vector<X> and Vector<Y>))
104 auto dot(X const &x, Y const &y) {
105 if constexpr (Scalar<X>) {
106 return x * y;
107 } else if constexpr (requires { nda::blas::dot(x, y); }) {
108 return nda::blas::dot(x, y);
109 } else {
110 return dot_generic<false>(x, y);
111 }
112 }
113
133 template <typename X, typename Y>
134 requires((Scalar<X> and Scalar<Y>) or (Vector<X> and Vector<Y>))
135 auto dotc(X const &x, Y const &y) {
136 if constexpr (Scalar<X>) {
137 if constexpr (is_complex_v<X>) return std::conj(x) * y;
138 return x * y;
139 } else if constexpr (requires { nda::blas::dotc(x, y); }) {
140 return nda::blas::dotc(x, y);
141 } else {
142 return dot_generic<true>(x, y);
143 }
144 }
145
147
148} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a generic interface to the BLAS dot, dotu and dotc routines.
Check if a given type is either an arithmetic or complex type.
Definition concepts.hpp:108
Check if a given type is a vector, i.e. an nda::ArrayOfRank<1>.
Definition concepts.hpp:298
Provides concepts for the nda library.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:153
constexpr bool has_layout_smallest_stride_is_one
Constexpr variable that is true if type A has the smallest_stride_is_one nda::layout_prop_e guarantee...
Definition traits.hpp:328
auto dot(X const &x, Y const &y)
Interface to the BLAS dot and dotu routines.
Definition dot.hpp:48
auto dotc(X const &x, Y const &y)
Interface to the BLAS dotc routine.
Definition dot.hpp:83
auto dotc(X const &x, Y const &y)
Compute the dotc (LHS operand is conjugated) product of two nda::vector objects or the product of two...
Definition dot.hpp:135
auto dot(X const &x, Y const &y)
Compute the dot product of two nda::vector objects or the product of two scalars.
Definition dot.hpp:104
auto dot_generic(X const &x, Y const &y)
Generic loop-based dot product implementation for vectors.
Definition dot.hpp:45
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
constexpr bool is_complex_v
Constexpr variable that is true if type T is a std::complex type.
Definition traits.hpp:65
Macros used in the nda library.
A small wrapper around a single long integer to be used as a linear index.
Definition traits.hpp:333
Provides type traits for the nda library.