TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
dot.hpp
Go to the documentation of this file.
1// Copyright (c) 2022--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../blas/dot.hpp"
14#include "../concepts.hpp"
15#include "../macros.hpp"
17#include "../traits.hpp"
18
19#include <complex>
20#include <cstddef>
21
22namespace nda::linalg {
23
28
47 template <bool star = false, Vector X, Vector Y>
48 requires(Scalar<get_value_t<X>> and Scalar<get_value_t<Y>> and mem::have_host_compatible_addr_space<X, Y>)
49 auto dot_generic(X const &x, Y const &y) {
50 // check the dimensions of the input arrays/views
51 EXPECTS(x.size() == y.size());
52
53 // conditional conjugation
54 auto cond_conj = [](auto z) __attribute__((always_inline)) {
55 if constexpr (star and is_complex_v<decltype(z)>) {
56 return std::conj(z);
57 } else {
58 return z;
59 }
60 };
61
62 // early return for zero-sized vectors
63 long const N = x.size();
64 if (N == 0) return decltype(cond_conj(x(0)) * y(0)){0};
65
66 // loop over vectors and sum up element-wise products
69 auto *__restrict px = x.data();
70 auto *__restrict py = y.data();
71 auto res = cond_conj(px[0]) * py[0];
72 for (size_t i = 1; i < N; ++i) res += cond_conj(px[i]) * py[i];
73 return res;
74 } else {
75 auto res = cond_conj(x(_linear_index_t{0})) * y(_linear_index_t{0});
76 for (long i = 1; i < N; ++i) res += cond_conj(x(_linear_index_t{i})) * y(_linear_index_t{i});
77 return res;
78 }
79 } else {
80 auto res = cond_conj(x(0)) * y(0);
81 for (long i = 1; i < N; ++i) res += cond_conj(x(i)) * y(i);
82 return res;
83 }
84 }
85
106 template <typename X, typename Y>
107 requires((Scalar<X> and Scalar<Y>) or (Vector<X> and Vector<Y>))
108 auto dot(X const &x, Y const &y) {
109 if constexpr (Scalar<X>) {
110 return x * y;
111 } else if constexpr (requires { nda::blas::dot(x, y); }) {
112 return nda::blas::dot(x, y);
113 } else {
114 return dot_generic<false>(x, y);
115 }
116 }
117
137 template <typename X, typename Y>
138 requires((Scalar<X> and Scalar<Y>) or (Vector<X> and Vector<Y>))
139 auto dotc(X const &x, Y const &y) {
140 if constexpr (Scalar<X>) {
141 if constexpr (is_complex_v<X>) return std::conj(x) * y;
142 return x * y;
143 } else if constexpr (requires { nda::blas::dotc(x, y); }) {
144 return nda::blas::dotc(x, y);
145 } else {
146 return dot_generic<true>(x, y);
147 }
148 }
149
151
152} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a generic interface to the BLAS/cuBLAS dot, dotu and dotc routines.
Check if a given type is either an arithmetic or complex type.
Definition concepts.hpp:83
Check if a given type is a vector, i.e. an nda::ArrayOfRank<1>.
Definition concepts.hpp:280
Provides concepts for the nda library.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:183
constexpr bool has_layout_smallest_stride_is_one
Constexpr variable that is true if type A has the smallest_stride_is_one nda::layout_prop_e guarantee...
Definition traits.hpp:367
auto dotc(X const &x, Y const &y)
Interface to the BLAS/cuBLAS dotc routine.
Definition dot.hpp:72
auto dot(X const &x, Y const &y)
Interface to the BLAS/cuBLAS dot and dotu routines.
Definition dot.hpp:44
auto dotc(X const &x, Y const &y)
Compute the dotc (LHS operand is conjugated) product of two nda::vector objects or the product of two scalars.
Definition dot.hpp:139
auto dot(X const &x, Y const &y)
Compute the dot product of two nda::vector objects or the product of two scalars.
Definition dot.hpp:108
auto dot_generic(X const &x, Y const &y)
Generic loop-based dot product implementation for vectors.
Definition dot.hpp:49
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
constexpr bool is_complex_v
Constexpr variable that is true if type T is a std::complex type.
Definition traits.hpp:65
Macros used in the nda library.
A small wrapper around a single long integer to be used as a linear index.
Definition traits.hpp:372
Provides type traits for the nda library.