TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
map.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2022 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Olivier Parcollet, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides lazy function calls on arrays/views.
20 */
21
22#pragma once
23
24#include "./concepts.hpp"
25#include "./layout/range.hpp"
26#include "./macros.hpp"
27#include "./traits.hpp"
28
29#include <cstddef>
30#include <utility>
31#include <tuple>
32
33namespace nda {
34
35 /// @cond
36 // Forward declarations.
37 template <typename F, Array... A>
38 struct expr_call;
39
40 template <class F>
41 struct mapped;
42 /// @endcond
43
namespace detail {

  // Implementation of the nda::get_algebra trait for function call expressions.
  //
  // Yields the first algebra tag `x0` when every remaining tag compares equal to it; as soon as one tag differs,
  // the result is 'N'. With a single argument the result is simply `x0`.
  template <typename... Char>
  constexpr char _impl_find_common_algebra(char x0, Char... x) {
    constexpr char mismatch_tag = 'N';
    // left fold seeded with `true` so that an empty pack counts as "all equal"
    const bool all_agree = (true && ... && (x == x0));
    if (all_agree) return x0;
    return mismatch_tag;
  }

} // namespace detail
53
54 /**
55 * @ingroup av_utils
56 * @brief Get the resulting algebra of a function call expression involving arrays/views.
57 *
58 * @details If one of the algebras of the arguments is different, the resulting algebra is 'N'.
59 *
60 * @tparam F Callable object of the expression.
61 * @tparam As nda::Array argument types.
62 */
63 template <typename F, Array... As>
64 constexpr char get_algebra<expr_call<F, As...>> = detail::_impl_find_common_algebra(get_algebra<As>...);
65
66 /**
67 * @addtogroup av_math
68 * @{
69 */
70
71 /**
72 * @brief A lazy function call expression on arrays/views.
73 *
74 * @details The lazy expression call fulfils the nda::Array concept and can therefore be assigned to other
75 * nda::basic_array or nda::basic_array_view objects. For example:
76 *
77 * @code{.cpp}
78 * nda::matrix<int> mat{{1, 2}, {3, 4}};
79 * nda::matrix<int> pmat = nda::pow(mat, 2);
80 * @endcode
81 *
82 * Here, `nda::pow(mat, 2)` returns a lazy expression call object which is then used in the constructor of `pmat`.
83 *
84 * The callable object should take the array/view elements as arguments.
85 *
86 * @tparam F Callable type.
87 * @tparam As nda::Array argument types.
88 */
89 template <typename F, Array... As>
90 struct expr_call {
91 /// Callable object of the expression.
92 F f;
93
94 /// Tuple containing the nda::Array arguments.
95 std::tuple<const As...> a;
96
97 private:
98 // Implementation of the function call operator.
99 template <size_t... Is, typename... Args>
100 [[gnu::always_inline]] [[nodiscard]] auto _call(std::index_sequence<Is...>, Args const &...args) const {
101 // if args contains a range, we need to return an expr_call on the resulting slice
102 if constexpr ((is_range_or_ellipsis<Args> or ... or false)) {
103 return mapped<F>{f}(std::get<Is>(a)(args...)...);
104 } else {
105 return f(std::get<Is>(a)(args...)...);
106 }
107 }
108
109 // Implementation of the subscript operator.
110 template <size_t... Is, typename Arg>
111 [[gnu::always_inline]] auto _call_bra(std::index_sequence<Is...>, Arg const &arg) const {
112 return f(std::get<Is>(a)[arg]...);
113 }
114
115 public:
116 /**
117 * @brief Function call operator.
118 *
119 * @details The arguments (usually multi-dimensional indices) are passed to all the nda::Array objects stored in the
120 * tuple and the results are then passed to the callable object.
121 *
122 * If the arguments contain a range, a new lazy function call expression is returned.
123 *
124 * @tparam Args Argument types.
125 * @param args Function call arguments.
126 * @return The result of the function call (depends on the callable and the arguments).
127 */
128 template <typename... Args>
129 auto operator()(Args const &...args) const {
130 return _call(std::make_index_sequence<sizeof...(As)>{}, args...);
131 }
132
133 /**
134 * @brief Subscript operator.
135 *
136 * @details The argument (usually a 1-dimensional index) is passed to all the nda::Array objects stored in the tuple
137 * and the results are then passed to the callable object.
138 *
139 * If the argument is a range, a new lazy function call expression is returned.
140 *
141 * @tparam Arg Argument types.
142 * @param arg Subscript argument.
143 * @return The result of the subscript operation (depends on the callable and the arguments).
144 */
145 template <typename Arg>
146 auto operator[](Arg const &arg) const {
147 return _call_bra(std::make_index_sequence<sizeof...(As)>{}, arg);
148 }
149
150 // FIXME copy needed for the && case only. Overload ?
151 /**
152 * @brief Get the shape of the nda::Array objects.
153 * @return `std::array<long, Rank>` object specifying the shape of each nda::Array object.
154 */
155 [[nodiscard]] auto shape() const { return std::get<0>(a).shape(); }
156
157 /**
158 * @brief Get the total size of the nda::Array objects.
159 * @return Number of elements contained in each nda::Array object.
160 */
161 [[nodiscard]] long size() const { return std::get<0>(a).size(); }
162 };
163
164 /**
165 * @brief Functor that is returned by the nda::map function.
166 * @tparam F Callable type.
167 */
168 template <class F>
169 struct mapped {
170 /// Callable object.
171 F f;
172
173 /**
174 * @brief Function call operator that returns a lazy function call expression.
175 *
176 * @tparam A0 First nda::Array argument type.
177 * @tparam As Rest of the nda::Array argument types.
178 * @param a0 First nda::Array argument.
179 * @param as Rest of the nda::Array arguments.
180 * @return A lazy nda::expr_call object.
181 */
182 template <Array A0, Array... As>
183 expr_call<F, A0, As...> operator()(A0 &&a0, As &&...as) const {
184 EXPECTS(((as.shape() == a0.shape()) && ...)); // same shape
185 return {f, {std::forward<A0>(a0), std::forward<As>(as)...}};
186 }
187 };
188
189 /**
190 * @brief Create a lazy function call expression on arrays/views.
191 *
192 * @details The callable should take the array/view elements as arguments.
193 *
194 * @tparam F Callable type.
195 * @param f Callable object.
196 * @return A lazy nda::mapped object.
197 */
198 template <class F>
199 mapped<F> map(F f) {
200 return {std::move(f)};
201 }
202
203 /** @} */
204
205} // namespace nda
#define CUBLAS_CHECK(X,...)
#define NDA_RUNTIME_ERROR
mapped< F > map(F f)
Create a lazy function call expression on arrays/views.
Definition map.hpp:199
constexpr bool is_regular_v
Constexpr variable that is true if type A is a regular array, i.e. an nda::basic_array.
Definition traits.hpp:145
constexpr char get_algebra
Constexpr variable that specifies the algebra of a type.
Definition traits.hpp:126
constexpr bool is_matrix_or_view_v
Constexpr variable that is true if type A is a regular matrix or a view of a matrix.
Definition traits.hpp:167
constexpr bool have_same_value_type_v
Constexpr variable that is true if all types in As have the same value type as A0.
Definition traits.hpp:196
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:136
constexpr bool have_same_rank_v
Constexpr variable that is true if all types in As have the same rank as A0.
Definition traits.hpp:200
constexpr bool is_view_v
Constexpr variable that is true if type A is a view, i.e. an nda::basic_array_view.
Definition traits.hpp:154
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:163
constexpr char get_algebra< expr_call< F, As... > >
Get the resulting algebra of a function call expression involving arrays/views.
Definition map.hpp:64
decltype(auto) get_first_element(A const &a)
Get the first element of an array/view or simply return the scalar if a scalar is given.
Definition traits.hpp:177
constexpr bool layout_property_compatible(layout_prop_e from, layout_prop_e to)
Checks if two layout properties are compatible with each other.
Definition traits.hpp:237
constexpr bool has_contiguous(layout_prop_e lp)
Checks if a layout property has the contiguous property.
Definition traits.hpp:282
constexpr bool has_layout_smallest_stride_is_one
Constexpr variable that is true if type A has the smallest_stride_is_one nda::layout_prop_e guarantee...
Definition traits.hpp:338
constexpr bool has_strided_1d(layout_prop_e lp)
Checks if a layout property has the strided_1d property.
Definition traits.hpp:266
constexpr bool has_layout_strided_1d
Constexpr variable that is true if type A has the strided_1d nda::layout_prop_e guarantee.
Definition traits.hpp:334
constexpr layout_prop_e operator&(layout_prop_e lhs, layout_prop_e rhs)
Bitwise AND operator for two layout properties.
Definition traits.hpp:258
constexpr layout_info_t operator&(layout_info_t lhs, layout_info_t rhs)
Bitwise AND operator for layout infos.
Definition traits.hpp:312
constexpr layout_prop_e operator|(layout_prop_e lhs, layout_prop_e rhs)
Bitwise OR operator for two layout properties.
Definition traits.hpp:249
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
Definition traits.hpp:321
constexpr bool has_smallest_stride_is_one(layout_prop_e lp)
Checks if a layout property has the smallest_stride_is_one property.
Definition traits.hpp:274
constexpr bool has_contiguous_layout
Constexpr variable that is true if type A has the contiguous nda::layout_prop_e guarantee.
Definition traits.hpp:330
layout_prop_e
Compile-time guarantees of the memory layout of an array/view.
Definition traits.hpp:222
int get_ld(A const &a)
Get the leading dimension in LAPACK jargon of an nda::MemoryMatrix.
Definition tools.hpp:109
static constexpr bool has_C_layout
Constexpr variable that is true if the given nda::Array type has a C memory layout.
Definition tools.hpp:76
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
Definition tools.hpp:52
int get_ncols(A const &a)
Get the number of columns in LAPACK jargon of an nda::MemoryMatrix.
Definition tools.hpp:121
static constexpr bool is_conj_array_expr< expr_call< conj_f, A > >
Specialization of nda::blas::is_conj_array_expr for the conjugate lazy expressions.
Definition tools.hpp:56
static constexpr bool has_F_layout
Constexpr variable that is true if the given nda::Array type has a Fortran memory layout.
Definition tools.hpp:66
const char get_op
Variable template that determines the BLAS matrix operation tag ('N','T','C') based on the given bool...
Definition tools.hpp:91
AddressSpace
Enum providing identifiers for the different memory address spaces.
constexpr bool is_instantiation_of_v
Constexpr variable that is true if type T is an instantiation of TMPLT (see nda::is_instantiation_of)...
Definition traits.hpp:59
constexpr bool is_complex_v
Constexpr variable that is true if type T is a std::complex type.
Definition traits.hpp:75
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:102
static constexpr bool always_false
Constexpr variable that is always false regardless of the types in Ts (used to trigger static_assert)...
Definition traits.hpp:71
constexpr bool is_scalar_for_v
Constexpr variable used to check requirements when initializing an nda::basic_array or nda::basic_arr...
Definition traits.hpp:93
static constexpr bool is_any_of
Constexpr variable that is true if type T is contained in the parameter pack Ts.
Definition traits.hpp:63
constexpr bool is_double_or_complex_v
Constexpr variable that is true if type T is a std::complex type or a double type.
Definition traits.hpp:98
static constexpr bool always_true
Constexpr variable that is always true regardless of the types in Ts.
Definition traits.hpp:67
constexpr bool is_scalar_v
Constexpr variable that is true if type S is a scalar type, i.e. arithmetic or complex.
Definition traits.hpp:79
constexpr bool is_scalar_or_convertible_v
Constexpr variable that is true if type S is a scalar type (see nda::is_scalar_v) or if a std::comple...
Definition traits.hpp:86
#define EXPECTS(X)
Definition macros.hpp:59
#define AS_STRING(...)
Definition macros.hpp:31
A small wrapper around a single long integer to be used as a linear index.
Definition traits.hpp:343
long value
Linear index.
Definition traits.hpp:345
A lazy function call expression on arrays/views.
Definition map.hpp:90
long size() const
Get the total size of the nda::Array objects.
Definition map.hpp:161
auto shape() const
Get the shape of the nda::Array objects.
Definition map.hpp:155
auto operator()(Args const &...args) const
Function call operator.
Definition map.hpp:129
std::tuple< const As... > a
Tuple containing the nda::Array arguments.
Definition map.hpp:95
F f
Callable object of the expression.
Definition map.hpp:92
auto operator[](Arg const &arg) const
Subscript operator.
Definition map.hpp:146
Check if type T is of type TMPLT<...>.
Definition traits.hpp:51
Stores information about the memory layout and the stride order of an array/view.
Definition traits.hpp:295
uint64_t stride_order
Stride order of the array/view.
Definition traits.hpp:297
layout_prop_e prop
Memory layout properties of the array/view.
Definition traits.hpp:300
Functor that is returned by the nda::map function.
Definition map.hpp:169
expr_call< F, A0, As... > operator()(A0 &&a0, As &&...as) const
Function call operator that returns a lazy function call expression.
Definition map.hpp:183
F f
Callable object.
Definition map.hpp:171
Memory block consisting of a pointer and its size.