TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
slice_static.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell
16
22#pragma once
23
24#include "./range.hpp"
25#include "../macros.hpp"
26#include "../stdutil/array.hpp"
27#include "../traits.hpp"
28
29#include <array>
30#include <cstdint>
31#include <cstddef>
32#include <tuple>
33#include <type_traits>
34#include <utility>
35
36#ifdef NDA_ENFORCE_BOUNDCHECK
38#endif
39
40namespace nda {
41
43 // Forward declarations.
44 template <int Rank, uint64_t StaticExtents, uint64_t StrideOrder, layout_prop_e LayoutProp>
45 class idx_map;
47
48} // namespace nda
49
50namespace nda::slice_static {
51
57 // Notations for this file
58 //
59 // N: rank of the original idx_map
60 // P: rank of the resulting idx_map
61 // Q: number of arguments given when creating the slice
62 //
63 // n: 0, ..., N - 1: indexes the dimensions of the original idx_map
64 // p: 0, ..., P - 1: indexes the dimensions of the resulting idx_map
65 // q: 0, ..., Q - 1: indexes the arguments
66 //
67 // p are the indices of the non-long arguments (after ellipsis expansion)
68 //
69 // N - Q + e is the length of the ellipsis, i.e. the number of the dimensions covered by the
70 // ellipsis, where e = +1 if there is an ellipsis and e = 0 otherwise.
71 //
// Let's assume the original idx_map is of rank N = 6 and the following slice arguments are given:
//
// Given args = long, long, ellipsis, long, range
// q            0     1     2         3     4
//
// Here the position of the ellipsis is 2 and its length is N - Q + 1 = 6 - 5 + 1 = 2.
// After expanding the ellipsis, the arguments are:
//
// Expanded args = long, long, range::all, range::all, long, range
// q               0     1     2           2           3     4
// n               0     1     2           3           4     5
// p               -     -     0           1           -     2
84 //
85 // Now we can define the following (compile-time) maps:
86 //
87 // - q(n): n -> q
88 // - n(p): p -> n
89 // - q(p): p -> q
90
91 namespace detail {
92
93 // Helper function to get the position of an nda::ellipsis object in a given parameter pack.
94 template <typename... Args, size_t... Is>
95 constexpr int ellipsis_position_impl(std::index_sequence<Is...>) {
96 // we know that there is at most one ellipsis
97 int r = ((std::is_same_v<Args, ellipsis> ? int(Is) + 1 : 0) + ...);
98 return (r == 0 ? 128 : r - 1);
99 }
100
101 // Get the position of an nda::ellipsis object in a given parameter pack (returns 128 if there is no nda::ellipsis).
102 template <typename... Args>
103 constexpr int ellipsis_position() {
104 return detail::ellipsis_position_impl<Args...>(std::make_index_sequence<sizeof...(Args)>{});
105 }
106
// Translate a dimension n of the original index map into the index q of the argument that addresses it, taking the
// expansion of the ellipsis into account.
//
// n:     dimension of the original index map.
// e_pos: position of the ellipsis in the argument list (larger than any n if there is no ellipsis).
// e_len: number of dimensions covered by the ellipsis.
constexpr int q_of_n(int n, int e_pos, int e_len) {
  // dimensions in front of the ellipsis (or all dimensions if there is no ellipsis) map one-to-one,
  // dimensions covered by the ellipsis all map to the ellipsis argument itself and
  // dimensions behind the ellipsis are shifted back by the e_len - 1 extra dimensions the ellipsis stands for
  return (n < e_pos) ? n : ((n < (e_pos + e_len)) ? e_pos : n - (e_len - 1));
}
118
119 // Determine how the dimensions of the sliced index map are mapped to the dimensions of the original index map.
120 template <int N, int P, size_t Q>
121 constexpr std::array<int, P> n_of_p_map(std::array<bool, Q> const &args_is_range, int e_pos, int e_len) {
122 auto result = stdutil::make_initialized_array<P>(0);
123 // loop over all dimensions of the original index map
124 for (int n = 0, p = 0; n < N; ++n) {
125 // for each dimension n, determine the corresponding argument index q after ellipsis expansion
126 int q = q_of_n(n, e_pos, e_len);
127 // if q is a range, map the current p to the current n and increment p
128 if (args_is_range[q]) result[p++] = n;
129 }
130 return result;
131 }
132
133 // Determine how the dimensions of the sliced index map are mapped to the arguments after ellipsis expansion.
134 template <int N, int P, size_t Q>
135 constexpr std::array<int, P> q_of_p_map(std::array<bool, Q> const &args_is_range, int e_pos, int e_len) {
136 auto result = stdutil::make_initialized_array<P>(0);
137 // loop over all dimensions of the original index map
138 int p = 0;
139 for (int n = 0; n < N; ++n) {
140 // for each dimension n, determine the corresponding argument index q after ellipsis expansion
141 int q = q_of_n(n, e_pos, e_len);
142 // if q is a range, map the current p to the current q and increment p
143 if (args_is_range[q]) result[p++] = q;
144 }
145 return result;
146 }
147
148 // Determine the pseudo inverse map p(n): n -> p. If an n has no corresponding p, the value is -1.
149 template <size_t N, size_t P>
150 constexpr std::array<int, N> p_of_n_map(std::array<int, P> const &n_of_p) {
151 auto result = stdutil::make_initialized_array<N>(-1);
152 for (size_t p = 0; p < P; ++p) result[n_of_p[p]] = p;
153 return result;
154 }
155
156 // Determine the stride order of the sliced index map.
157 template <size_t P, size_t N>
158 constexpr std::array<int, P> slice_stride_order(std::array<int, N> const &orig_stride_order, std::array<int, P> const &n_of_p) {
159 auto result = stdutil::make_initialized_array<P>(0);
160 auto p_of_n = p_of_n_map<N>(n_of_p);
161 for (int i = 0, ip = 0; i < N; ++i) {
162 // n traverses the original stride order, slowest first
163 int n = orig_stride_order[i];
164 // get the corresponding dimension in the sliced index map
165 int p = p_of_n[n];
166 // if n maps to a p, add p to the stride order of the sliced index map
167 if (p != -1) result[ip++] = p;
168 }
169 return result;
170 }
171
172 // Determine the compile-time layout properties of the sliced index map.
173 template <size_t Q, size_t N>
174 constexpr layout_prop_e slice_layout_prop(int P, bool has_only_rangeall_and_long, std::array<bool, Q> const &args_is_rangeall,
175 std::array<int, N> const &orig_stride_order, layout_prop_e orig_layout_prop, int e_pos, int e_len) {
176 // if there are ranges in the arguments
177 if (not has_only_rangeall_and_long) {
178 if (P == 1)
179 // rank one is always at least strided_1d
180 return layout_prop_e::strided_1d;
181 else
182 // otherwise we don't know
183 return layout_prop_e::none;
184 }
185
186 // count the number of nda::range::all_t blocks in the argument list, e.g. long, range::all, range::all, long,
187 // range::all -> 2 blocks
188 int n_rangeall_blocks = 0;
189 bool previous_arg_is_rangeall = false;
190 for (int i = 0; i < N; ++i) {
191 int q = q_of_n(orig_stride_order[i], e_pos, e_len);
192 bool arg_is_rangeall = args_is_rangeall[q];
193 if (arg_is_rangeall and (not previous_arg_is_rangeall)) ++n_rangeall_blocks;
194 previous_arg_is_rangeall = arg_is_rangeall;
195 }
196 bool rangeall_are_grouped_in_memory = (n_rangeall_blocks <= 1);
197 bool last_is_rangeall = previous_arg_is_rangeall;
198
199 // return the proper layout_prop_e
200 if (has_contiguous(orig_layout_prop) and rangeall_are_grouped_in_memory and last_is_rangeall) return layout_prop_e::contiguous;
201 if (has_strided_1d(orig_layout_prop) and rangeall_are_grouped_in_memory) return layout_prop_e::strided_1d;
202 if (has_smallest_stride_is_one(orig_layout_prop) and last_is_rangeall) return layout_prop_e::smallest_stride_is_one;
203
204 return layout_prop_e::none;
205 }
206
207 // Get the contribution to the flat index of the first element of the slice from a single dimension if the argument
208 // is a long.
209 FORCEINLINE long get_offset(long idx, long stride) { return idx * stride; }
210
211 // Get the contribution to the flat index of the first element of the slice from a single dimension if the argument
212 // is a range.
213 FORCEINLINE long get_offset(range const &rg, long stride) { return rg.first() * stride; }
214
215 // Get the contribution to the flat index of the first element of the slice from a single dimension if the argument
216 // is a range::all_t or covered by an nda::ellipsis.
217 FORCEINLINE long get_offset(range::all_t, long) { return 0; }
218
219 // Get the length of the slice for a single dimension if the argument is a range.
220 FORCEINLINE long get_length(range const &rg, long original_len) {
221 auto last = (rg.last() == -1 and rg.step() > 0) ? original_len : rg.last();
222 return range(rg.first(), last, rg.step()).size();
223 }
224
225 // Get the length of the slice for a single dimension if the argument is a range::all_t or covered by an
226 // nda::ellipsis.
227 FORCEINLINE long get_length(range::all_t, long original_len) { return original_len; }
228
229 // Get the stride of the slice for a single dimension if the argument is a range.
230 FORCEINLINE long get_stride(range const &rg, long original_str) { return original_str * rg.step(); }
231
232 // Get the stride of the slice for a single dimension if the argument is a range::all_t or covered by an
233 // nda::ellipsis..
234 FORCEINLINE long get_stride(range::all_t, long original_str) { return original_str; }
235
236 // Helper function to determine the resulting index map when taking a slice of a given index map.
237 template <size_t... Ps, size_t... Ns, typename IdxMap, typename... Args>
238 FORCEINLINE auto slice_idx_map_impl(std::index_sequence<Ps...>, std::index_sequence<Ns...>, IdxMap const &idxm, Args const &...args) {
239 // optional bounds check
240#ifdef NDA_ENFORCE_BOUNDCHECK
241 nda::assert_in_bounds(idxm.rank(), idxm.lengths().data(), args...);
242#endif
243 // compile time check
244 static_assert(IdxMap::rank() == sizeof...(Ns), "Internal error in slice_idx_map_impl: Rank and length of index sequence do not match");
245
246 // rank of original and resulting idx_map, number of arguments, length of ellipsis, position of ellipsis in the
247 // argument list
248 static constexpr int N = sizeof...(Ns);
249 static constexpr int P = sizeof...(Ps);
250 static constexpr int Q = sizeof...(Args);
251 static constexpr int e_len = N - Q + 1;
252 static constexpr int e_pos = ellipsis_position<Args...>();
253
254 // is i-th argument a range/range::all_t/ellipsis?
255 static constexpr std::array<bool, Q> args_is_range{(std::is_same_v<Args, range> or std::is_base_of_v<range::all_t, Args>)...};
256
257 // is i-th argument a range::all_t/ellipsis?
258 static constexpr std::array<bool, Q> args_is_rangeall{(std::is_base_of_v<range::all_t, Args>)...};
259
260 // mapping between the dimensions of the resulting index map and the dimensions of the original index map
261 static constexpr std::array<int, P> n_of_p = n_of_p_map<N, P>(args_is_range, e_pos, e_len);
262
263 // mapping between the dimensions of the resulting index map and the given arguments
264 static constexpr std::array<int, P> q_of_p = q_of_p_map<N, P>(args_is_range, e_pos, e_len);
265
266 // more compile time checks
267 static_assert(n_of_p.size() == P, "Internal error in slice_idx_map_impl: Size of the mapping n_of_p and P do not match");
268 static_assert(q_of_p.size() == P, "Internal error in slice_idx_map_impl: Size of the mapping q_of_p and P do not match");
269
270 // create tuple of arguments
271 auto argstie = std::tie(args...);
272
273 // flat index of the first element of the slice
274 long offset = (get_offset(std::get<q_of_n(Ns, e_pos, e_len)>(argstie), std::get<Ns>(idxm.strides())) + ... + 0);
275
276 // shape and strides of the resulting index map
277 std::array<long, P> len{get_length(std::get<q_of_p[Ps]>(argstie), std::get<n_of_p[Ps]>(idxm.lengths()))...};
278 std::array<long, P> str{get_stride(std::get<q_of_p[Ps]>(argstie), std::get<n_of_p[Ps]>(idxm.strides()))...};
279
280 // static extents of the resulting index map: 0 (= dynamic extent) if the corresponding argument is not a
281 // range/range::all_t/ellipsis
282 static constexpr std::array<int, P> new_static_extents{(args_is_rangeall[q_of_p[Ps]] ? IdxMap::static_extents[n_of_p[Ps]] : 0)...};
283
284 // stride order of the resulting index map
285 static constexpr std::array<int, P> new_stride_order = slice_stride_order(IdxMap::stride_order, n_of_p);
286
287 // compile-time layout properties of the resulting index map
288 static constexpr bool has_only_rangeall_and_long = ((std::is_constructible_v<long, Args> or std::is_base_of_v<range::all_t, Args>)and...);
289 static constexpr layout_prop_e li =
290 slice_layout_prop(P, has_only_rangeall_and_long, args_is_rangeall, IdxMap::stride_order, IdxMap::layout_prop, e_pos, e_len);
291
292 // return the resulting index map
293 static constexpr uint64_t new_static_extents_encoded = encode(new_static_extents);
294 static constexpr uint64_t new_stride_order_encoded = encode(new_stride_order);
295 return std::make_pair(offset, idx_map<P, new_static_extents_encoded, new_stride_order_encoded, li>{len, str});
296 }
297
298 } // namespace detail
299
324 template <int R, uint64_t SE, uint64_t SO, layout_prop_e LP, typename... Args>
325 FORCEINLINE decltype(auto) slice_idx_map(idx_map<R, SE, SO, LP> const &idxm, Args const &...args) {
326 // number of ellipsis and long arguments
327 static constexpr int n_args_ellipsis = ((std::is_same_v<Args, ellipsis>)+...);
328 static constexpr int n_args_long = (std::is_constructible_v<long, Args> + ...);
329
330 // compile time checks
331 static_assert(n_args_ellipsis <= 1, "Error in nda::slice_static::slice_idx_map: At most one ellipsis argument is allowed");
332 static_assert((sizeof...(Args) <= R + 1), "Error in nda::slice_static::slice_idx_map: Incorrect number of arguments");
333 static_assert((n_args_ellipsis == 1) or (sizeof...(Args) == R), "Error in nda::slice_static::slice_idx_map: Incorrect number of arguments");
334
335 return detail::slice_idx_map_impl(std::make_index_sequence<R - n_args_long>{}, std::make_index_sequence<R>{}, idxm, args...);
336 }
337
340} // namespace nda::slice_static
Provides utility functions for std::array.
Provides a way to check the bounds when accessing elements/slices of an array or a view.
Layout that specifies how to map multi-dimensional indices to a linear/flat index.
Definition idx_map.hpp:103
__inline__ decltype(auto) slice_idx_map(idx_map< R, SE, SO, LP > const &idxm, Args const &...args)
Determine the resulting nda::idx_map when taking a slice of a given nda::idx_map.
constexpr bool has_contiguous(layout_prop_e lp)
Checks if a layout property has the contiguous property.
Definition traits.hpp:282
constexpr bool has_strided_1d(layout_prop_e lp)
Checks if a layout property has the strided_1d property.
Definition traits.hpp:266
void assert_in_bounds(int rank, long const *lengths, Args const &...args)
Check if the given indices/arguments are within the bounds of an array/view.
constexpr bool has_smallest_stride_is_one(layout_prop_e lp)
Checks if a layout property has the smallest_stride_is_one property.
Definition traits.hpp:274
layout_prop_e
Compile-time guarantees of the memory layout of an array/view.
Definition traits.hpp:222
constexpr uint64_t encode(std::array< int, N > const &a)
Encode a std::array<int, N> in a uint64_t.
constexpr std::array< T, R > make_initialized_array(T v)
Create a new std::array object initialized with a specific value.
Definition array.hpp:168
Macros used in the nda library.
Includes the itertools header and provides some additional utilities.
Provides type traits for the nda library.