TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
dot.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Miguel Morales, Olivier Parcollet, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides generic interfaces to the BLAS `dot` and `dotc` routines, together with fallback implementations for value types not supported by BLAS.
20 */
21
22#pragma once
23
24#include "./interface/cxx_interface.hpp"
25#include "../concepts.hpp"
26#include "../macros.hpp"
27#include "../mapped_functions.hpp"
28#include "../mem/address_space.hpp"
29#include "../traits.hpp"
30
31#ifndef NDA_HAVE_DEVICE
32#include "../device.hpp"
33#endif
34
35#include <complex>
36
37namespace nda::blas {
38
39 /**
40 * @addtogroup linalg_blas
41 * @{
42 */
43
44 /**
45 * @brief Interface to the BLAS `dot` routine.
46 *
47 * @details This function forms the dot product of two vectors. It calculates
48 * - \f$ \mathbf{x}^T \mathbf{y} \f$ in case that both \f$ \mathbf{x} \f$ and \f$ \mathbf{y} \f$ are vectors,
49 * - \f$ x \mathbf{y} \f$ in case that \f$ x \f$ is a scalar and \f$ \mathbf{y} \f$ is a vector,
50 * - \f$ \mathbf{x} y \f$ in case that \f$ \mathbf{x} \f$ is a vector and \f$ y \f$ is a scalar or
51 * - \f$ x y \f$ in case that both \f$ x \f$ and \f$ y \f$ are scalars.
52 *
53 * @tparam X nda::MemoryVector or nda::Scalar type.
54 * @tparam Y nda::MemoryVector or nda::Scalar type.
55 * @param x Input vector/scalar.
56 * @param y Input vector/scalar.
57 * @return Vector/scalar result of the dot product.
58 */
59 template <typename X, typename Y>
60 requires((Scalar<X> or MemoryVector<X>) and (Scalar<Y> or MemoryVector<X>))
61 auto dot(X const &x, Y const &y) {
62 if constexpr (Scalar<X> or Scalar<Y>) {
63 return x * y;
64 } else {
65 // compile-time checks
66 static_assert(have_same_value_type_v<X, Y>, "Error in nda::blas::dot: Incompatible value types");
67 static_assert(mem::have_compatible_addr_space<X, Y>, "Error in nda::blas::dot: Incompatible memory address spaces");
68 static_assert(is_blas_lapack_v<get_value_t<X>>, "Error in nda::blas::dot: Value types incompatible with blas");
69
70 // runtime check
71 EXPECTS(x.shape() == y.shape());
72
73 if constexpr (mem::have_device_compatible_addr_space<X, Y>) {
74#if defined(NDA_HAVE_DEVICE)
75 return device::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
76#else
78 return get_value_t<X>(0);
79#endif
80 } else {
81 return f77::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
82 }
83 }
84 }
85
86 /**
87 * @brief Interface to the BLAS `dotc` routine.
88 *
89 * @details This function forms the dot product of two vectors. It calculates
90 * - \f$ \mathbf{x}^H \mathbf{y} \f$ in case that both \f$ \mathbf{x} \f$ and \f$ \mathbf{y} \f$ are vectors,
91 * - \f$ \bar{x} \mathbf{y} \f$ in case that \f$ x \f$ is a scalar and \f$ \mathbf{y} \f$ is a vector,
92 * - \f$ \mathbf{x}^H y \f$ in case that \f$ \mathbf{x} \f$ is a vector and \f$ y \f$ is a scalar or
93 * - \f$ \bar{x} y \f$ in case that both \f$ x \f$ and \f$ y \f$ are scalars.
94 *
95 * @tparam X nda::MemoryVector or nda::Scalar type.
96 * @tparam Y nda::MemoryVector or nda::Scalar type.
97 * @param x Input vector/scalar.
98 * @param y Input vector/scalar.
99 * @return Vector/scalar result of the dot product.
100 */
101 template <typename X, typename Y>
102 requires((Scalar<X> or MemoryVector<X>) and (Scalar<Y> or MemoryVector<X>))
103 auto dotc(X const &x, Y const &y) {
104 if constexpr (Scalar<X> or Scalar<Y>) {
105 return conj(x) * y;
106 } else {
107 // compile-time checks
108 static_assert(have_same_value_type_v<X, Y>, "Error in nda::blas::dotc: Incompatible value types");
109 static_assert(mem::have_compatible_addr_space<X, Y>, "Error in nda::blas::dotc: Incompatible memory address spaces");
110 static_assert(is_blas_lapack_v<get_value_t<X>>, "Error in nda::blas::dotc: Value types incompatible with blas");
111
112 // runtime check
113 EXPECTS(x.shape() == y.shape());
114
115 if constexpr (!is_complex_v<get_value_t<X>>) {
116 return dot(x, y);
117 } else if constexpr (mem::have_device_compatible_addr_space<X, Y>) {
118#if defined(NDA_HAVE_DEVICE)
119 return device::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
120#else
122 return get_value_t<X>(0);
123#endif
124 } else {
125 return f77::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
126 }
127 }
128 }
129
130 namespace detail {
131
132 // Implementation of the nda::dot_generic and nda::dotc_generic functions.
133 template <bool star, typename X, typename Y>
134 auto _dot_impl(X const &x, Y const &y) {
135 EXPECTS(x.shape() == y.shape());
136 long N = x.shape()[0];
137
138 auto _conj = [](auto z) __attribute__((always_inline)) {
139 if constexpr (star and is_complex_v<decltype(z)>) {
140 return std::conj(z);
141 } else
142 return z;
143 };
144
145 if constexpr (has_layout_smallest_stride_is_one<X> and has_layout_smallest_stride_is_one<Y>) {
146 if constexpr (is_regular_or_view_v<X> and is_regular_or_view_v<Y>) {
147 auto *__restrict px = x.data();
148 auto *__restrict py = y.data();
149 auto res = _conj(px[0]) * py[0];
150 for (size_t i = 1; i < N; ++i) { res += _conj(px[i]) * py[i]; }
151 return res;
152 } else {
153 auto res = _conj(x(_linear_index_t{0})) * y(_linear_index_t{0});
154 for (long i = 1; i < N; ++i) { res += _conj(x(_linear_index_t{i})) * y(_linear_index_t{i}); }
155 return res;
156 }
157 } else {
158 auto res = _conj(x(0)) * y(0);
159 for (long i = 1; i < N; ++i) { res += _conj(x(i)) * y(i); }
160 return res;
161 }
162 }
163
164 } // namespace detail
165
166 /**
167 * @brief Generic implementation of nda::blas::dot for types not supported by BLAS/LAPACK.
168 *
169 * @tparam X Vector/Scalar type.
170 * @tparam Y Vector/Scalar type.
171 * @param x Input vector/scalar.
172 * @param y Input vector/scalar.
173 * @return Vector/scalar result of the dot product.
174 */
175 template <typename X, typename Y>
176 auto dot_generic(X const &x, Y const &y) {
177 if constexpr (Scalar<X> or Scalar<Y>) {
178 return x * y;
179 } else {
180 return detail::_dot_impl<false>(x, y);
181 }
182 }
183
184 /**
185 * @brief Generic implementation of nda::blas::dotc for types not supported by BLAS/LAPACK.
186 *
187 * @tparam X Vector/Scalar type.
188 * @tparam Y Vector/Scalar type.
189 * @param x Input vector/scalar.
190 * @param y Input vector/scalar.
191 * @return Vector/scalar result of the dot product.
192 */
193 template <typename X, typename Y>
194 auto dotc_generic(X const &x, Y const &y) {
195 if constexpr (Scalar<X> or Scalar<Y>) {
196 return conj(x) * y;
197 } else {
198 return detail::_dot_impl<true>(x, y);
199 }
200 }
201
202 /** @} */
203
204} // namespace nda::blas
auto dot_generic(X const &x, Y const &y)
Generic implementation of nda::blas::dot for types not supported by BLAS/LAPACK.
Definition dot.hpp:176
auto dotc(X const &x, Y const &y)
Interface to the BLAS dotc routine.
Definition dot.hpp:103
auto dotc_generic(X const &x, Y const &y)
Generic implementation of nda::blas::dotc for types not supported by BLAS/LAPACK.
Definition dot.hpp:194
auto dot(X const &x, Y const &y)
Interface to the BLAS dot routine.
Definition dot.hpp:61
void compile_error_no_gpu()
Trigger a compilation error in case GPU specific functionality is used without configuring the projec...
Definition device.hpp:47
#define EXPECTS(X)
Definition macros.hpp:59
A small wrapper around a single long integer to be used as a linear index.
Definition traits.hpp:343