TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
dot.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
14#include "../concepts.hpp"
15#include "../macros.hpp"
18#include "../traits.hpp"
19
20#ifndef NDA_HAVE_DEVICE
21#include "../device.hpp"
22#endif
23
24#include <complex>
25
26namespace nda::blas {
27
32
48 template <typename X, typename Y>
49 requires((Scalar<X> or MemoryVector<X>) and (Scalar<Y> or MemoryVector<X>))
50 auto dot(X const &x, Y const &y) {
51 if constexpr (Scalar<X> or Scalar<Y>) {
52 return x * y;
53 } else {
54 // compile-time checks
55 static_assert(have_same_value_type_v<X, Y>, "Error in nda::blas::dot: Incompatible value types");
56 static_assert(mem::have_compatible_addr_space<X, Y>, "Error in nda::blas::dot: Incompatible memory address spaces");
57 static_assert(is_blas_lapack_v<get_value_t<X>>, "Error in nda::blas::dot: Value types incompatible with blas");
58
59 // runtime check
60 EXPECTS(x.shape() == y.shape());
61
63#if defined(NDA_HAVE_DEVICE)
64 return device::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
65#else
67 return get_value_t<X>(0);
68#endif
69 } else {
70 return f77::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
71 }
72 }
73 }
74
90 template <typename X, typename Y>
92 auto dotc(X const &x, Y const &y) {
93 if constexpr (Scalar<X> or Scalar<Y>) {
94 return conj(x) * y;
95 } else {
96 // compile-time checks
97 static_assert(have_same_value_type_v<X, Y>, "Error in nda::blas::dotc: Incompatible value types");
98 static_assert(mem::have_compatible_addr_space<X, Y>, "Error in nda::blas::dotc: Incompatible memory address spaces");
99 static_assert(is_blas_lapack_v<get_value_t<X>>, "Error in nda::blas::dotc: Value types incompatible with blas");
100
101 // runtime check
102 EXPECTS(x.shape() == y.shape());
103
104 if constexpr (!is_complex_v<get_value_t<X>>) {
105 return dot(x, y);
107#if defined(NDA_HAVE_DEVICE)
108 return device::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
109#else
111 return get_value_t<X>(0);
112#endif
113 } else {
114 return f77::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
115 }
116 }
117 }
118
119 namespace detail {
120
121 // Implementation of the nda::dot_generic and nda::dotc_generic functions.
122 template <bool star, typename X, typename Y>
123 auto _dot_impl(X const &x, Y const &y) {
124 EXPECTS(x.shape() == y.shape());
125 long N = x.shape()[0];
126
127 auto _conj = [](auto z) __attribute__((always_inline)) {
128 if constexpr (star and is_complex_v<decltype(z)>) {
129 return std::conj(z);
130 } else
131 return z;
132 };
133
136 auto *__restrict px = x.data();
137 auto *__restrict py = y.data();
138 auto res = _conj(px[0]) * py[0];
139 for (size_t i = 1; i < N; ++i) { res += _conj(px[i]) * py[i]; }
140 return res;
141 } else {
142 auto res = _conj(x(_linear_index_t{0})) * y(_linear_index_t{0});
143 for (long i = 1; i < N; ++i) { res += _conj(x(_linear_index_t{i})) * y(_linear_index_t{i}); }
144 return res;
145 }
146 } else {
147 auto res = _conj(x(0)) * y(0);
148 for (long i = 1; i < N; ++i) { res += _conj(x(i)) * y(i); }
149 return res;
150 }
151 }
152
153 } // namespace detail
154
164 template <typename X, typename Y>
165 auto dot_generic(X const &x, Y const &y) {
166 if constexpr (Scalar<X> or Scalar<Y>) {
167 return x * y;
168 } else {
169 return detail::_dot_impl<false>(x, y);
170 }
171 }
172
182 template <typename X, typename Y>
183 auto dotc_generic(X const &x, Y const &y) {
184 if constexpr (Scalar<X> or Scalar<Y>) {
185 return conj(x) * y;
186 } else {
187 return detail::_dot_impl<true>(x, y);
188 }
189 }
190
192
193} // namespace nda::blas
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various BLAS routines.
Check if a given type is a memory vector, i.e. an nda::MemoryArrayOfRank<1>.
Definition concepts.hpp:314
Check if a given type is either an arithmetic or complex type.
Definition concepts.hpp:108
Provides concepts for the nda library.
Provides GPU and non-GPU specific functionality.
decltype(auto) conj(A &&a)
Function conj for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types with a com...
constexpr bool have_same_value_type_v
Constexpr variable that is true if all types in As have the same value type as A0.
Definition traits.hpp:186
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:182
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:153
constexpr bool has_layout_smallest_stride_is_one
Constexpr variable that is true if type A has the smallest_stride_is_one nda::layout_prop_e guarantee...
Definition traits.hpp:328
auto dot_generic(X const &x, Y const &y)
Generic implementation of nda::blas::dot for types not supported by BLAS/LAPACK.
Definition dot.hpp:165
auto dotc(X const &x, Y const &y)
Interface to the BLAS dotc routine.
Definition dot.hpp:92
auto dotc_generic(X const &x, Y const &y)
Generic implementation of nda::blas::dotc for types not supported by BLAS/LAPACK.
Definition dot.hpp:183
auto dot(X const &x, Y const &y)
Interface to the BLAS dot routine.
Definition dot.hpp:50
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void compile_error_no_gpu()
Trigger a compilation error in case GPU specific functionality is used without configuring the projec...
Definition device.hpp:36
constexpr bool is_complex_v
Constexpr variable that is true if type T is a std::complex type.
Definition traits.hpp:65
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:92
Macros used in the nda library.
Provides some custom implementations of standard mathematical functions used for lazy,...
Provides type traits for the nda library.