TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
dot.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Miguel Morales, Olivier Parcollet, Nils Wentzell
16
22#pragma once
23
25#include "../concepts.hpp"
26#include "../macros.hpp"
29#include "../traits.hpp"
30
31#ifndef NDA_HAVE_DEVICE
32#include "../device.hpp"
33#endif
34
35#include <complex>
36
37namespace nda::blas {
38
59 template <typename X, typename Y>
60 requires((Scalar<X> or MemoryVector<X>) and (Scalar<Y> or MemoryVector<X>))
61 auto dot(X const &x, Y const &y) {
62 if constexpr (Scalar<X> or Scalar<Y>) {
63 return x * y;
64 } else {
65 // compile-time checks
66 static_assert(have_same_value_type_v<X, Y>, "Error in nda::blas::dot: Incompatible value types");
67 static_assert(mem::have_compatible_addr_space<X, Y>, "Error in nda::blas::dot: Incompatible memory address spaces");
68 static_assert(is_blas_lapack_v<get_value_t<X>>, "Error in nda::blas::dot: Value types incompatible with blas");
69
70 // runtime check
71 EXPECTS(x.shape() == y.shape());
72
74#if defined(NDA_HAVE_DEVICE)
75 return device::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
76#else
78 return get_value_t<X>(0);
79#endif
80 } else {
81 return f77::dot(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
82 }
83 }
84 }
85
101 template <typename X, typename Y>
102 requires((Scalar<X> or MemoryVector<X>) and (Scalar<Y> or MemoryVector<X>))
103 auto dotc(X const &x, Y const &y) {
104 if constexpr (Scalar<X> or Scalar<Y>) {
105 return conj(x) * y;
106 } else {
107 // compile-time checks
108 static_assert(have_same_value_type_v<X, Y>, "Error in nda::blas::dotc: Incompatible value types");
109 static_assert(mem::have_compatible_addr_space<X, Y>, "Error in nda::blas::dotc: Incompatible memory address spaces");
110 static_assert(is_blas_lapack_v<get_value_t<X>>, "Error in nda::blas::dotc: Value types incompatible with blas");
111
112 // runtime check
113 EXPECTS(x.shape() == y.shape());
114
115 if constexpr (!is_complex_v<get_value_t<X>>) {
116 return dot(x, y);
118#if defined(NDA_HAVE_DEVICE)
119 return device::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
120#else
122 return get_value_t<X>(0);
123#endif
124 } else {
125 return f77::dotc(x.size(), x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0]);
126 }
127 }
128 }
129
130 namespace detail {
131
132 // Implementation of the nda::dot_generic and nda::dotc_generic functions.
133 template <bool star, typename X, typename Y>
134 auto _dot_impl(X const &x, Y const &y) {
135 EXPECTS(x.shape() == y.shape());
136 long N = x.shape()[0];
137
138 auto _conj = [](auto z) __attribute__((always_inline)) {
139 if constexpr (star and is_complex_v<decltype(z)>) {
140 return std::conj(z);
141 } else
142 return z;
143 };
144
147 auto *__restrict px = x.data();
148 auto *__restrict py = y.data();
149 auto res = _conj(px[0]) * py[0];
150 for (size_t i = 1; i < N; ++i) { res += _conj(px[i]) * py[i]; }
151 return res;
152 } else {
153 auto res = _conj(x(_linear_index_t{0})) * y(_linear_index_t{0});
154 for (long i = 1; i < N; ++i) { res += _conj(x(_linear_index_t{i})) * y(_linear_index_t{i}); }
155 return res;
156 }
157 } else {
158 auto res = _conj(x(0)) * y(0);
159 for (long i = 1; i < N; ++i) { res += _conj(x(i)) * y(i); }
160 return res;
161 }
162 }
163
164 } // namespace detail
165
175 template <typename X, typename Y>
176 auto dot_generic(X const &x, Y const &y) {
177 if constexpr (Scalar<X> or Scalar<Y>) {
178 return x * y;
179 } else {
180 return detail::_dot_impl<false>(x, y);
181 }
182 }
183
193 template <typename X, typename Y>
194 auto dotc_generic(X const &x, Y const &y) {
195 if constexpr (Scalar<X> or Scalar<Y>) {
196 return conj(x) * y;
197 } else {
198 return detail::_dot_impl<true>(x, y);
199 }
200 }
201
204} // namespace nda::blas
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various BLAS routines.
Check if a given type is a memory vector, i.e. an nda::MemoryArrayOfRank<1>.
Definition concepts.hpp:314
Check if a given type is either an arithmetic or complex type.
Definition concepts.hpp:119
Provides concepts for the nda library.
Provides GPU and non-GPU specific functionality.
auto conj(T t)
Get the complex conjugate of a scalar.
constexpr bool have_same_value_type_v
Constexpr variable that is true if all types in As have the same value type as A0.
Definition traits.hpp:196
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:192
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:163
constexpr bool has_layout_smallest_stride_is_one
Constexpr variable that is true if type A has the smallest_stride_is_one nda::layout_prop_e guaranteed.
Definition traits.hpp:338
auto dot_generic(X const &x, Y const &y)
Generic implementation of nda::blas::dot for types not supported by BLAS/LAPACK.
Definition dot.hpp:176
auto dotc(X const &x, Y const &y)
Interface to the BLAS dotc routine.
Definition dot.hpp:103
auto dotc_generic(X const &x, Y const &y)
Generic implementation of nda::blas::dotc for types not supported by BLAS/LAPACK.
Definition dot.hpp:194
auto dot(X const &x, Y const &y)
Interface to the BLAS dot routine.
Definition dot.hpp:61
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void compile_error_no_gpu()
Trigger a compilation error in case GPU specific functionality is used without configuring the project with GPU support.
Definition device.hpp:47
constexpr bool is_complex_v
Constexpr variable that is true if type T is a std::complex type.
Definition traits.hpp:75
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:102
Macros used in the nda library.
Provides some custom implementations of standard mathematical functions used for lazy,...
Provides type traits for the nda library.