TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
idx_map.hpp
Go to the documentation of this file.
1// Copyright (c) 2018 Commissariat à l'énergie atomique et aux énergies alternatives (CEA)
2// Copyright (c) 2018 Centre national de la recherche scientifique (CNRS)
3// Copyright (c) 2018-2024 Simons Foundation
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0.txt
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17// Authors: Thomas Hahn, Dominik Kiese, Henri Menke, Olivier Parcollet, Nils Wentzell
18
24#pragma once
25
26#include "./permutation.hpp"
27#include "./range.hpp"
28#include "./slice_static.hpp"
29#include "../macros.hpp"
30#include "../traits.hpp"
31
32#include <algorithm>
33#include <array>
34#include <concepts>
35#include <cstdint>
36#include <cstdlib>
37#include <functional>
38#include <numeric>
39#include <stdexcept>
40#include <type_traits>
41#include <utility>
42#include <vector>
43
44namespace nda {
45
56 template <int Rank>
58
64 template <int Rank>
66
102 template <int Rank, uint64_t StaticExtents, uint64_t StrideOrder, layout_prop_e LayoutProp>
103 class idx_map {
// Compile-time sanity checks on the template parameters: the rank has to be smaller than 16 and
// an encoded stride order of 0 is accepted only for rank-1 maps.
104 static_assert(Rank < 16, "Error in nda::idx_map: Rank must be < 16");
105 static_assert((StrideOrder != 0) or (Rank == 1), "Error in nda::idx_map: StrideOrder can only be zero for 1D arrays");
106
107 // Extents of all dimensions (the shape of the map).
// Value-initialized, i.e. all entries are zero until a constructor sets them.
108 std::array<long, Rank> len{};
109
110 // Strides of all dimensions.
// Value-initialized, i.e. all entries are zero until a constructor sets them.
111 std::array<long, Rank> str{};
112
113 public:
// Encoded static extents (the StaticExtents template parameter as given).
115 static constexpr uint64_t static_extents_encoded = StaticExtents;
116
// Decoded static extents. An entry equal to 0 is counted as a dynamic extent by n_dynamic_extents.
118 static constexpr std::array<int, Rank> static_extents = decode<Rank>(StaticExtents);
119
// Decoded stride order. An encoded value of 0 selects the identity permutation instead of being decoded.
121 static constexpr std::array<int, Rank> stride_order = (StrideOrder == 0 ? permutations::identity<Rank>() : decode<Rank>(StrideOrder));
122
// Encoded stride order, obtained by re-encoding the decoded stride_order array.
124 static constexpr uint64_t stride_order_encoded = encode(stride_order);
125
// Compile-time memory layout properties (the LayoutProp template parameter).
127 static constexpr layout_prop_e layout_prop = LayoutProp;
128
131
// Variable template to check if type T can be used to access a single element.
// Evaluates to 1 if a long can be constructed from a T and to 0 otherwise.
template <typename T>
static constexpr int argument_is_allowed_for_call = std::is_constructible<long, T>::value ? 1 : 0;
135
137 template <typename T>
139 std::is_same_v<range, T> or std::is_same_v<range::all_t, T> or std::is_same_v<ellipsis, T> or std::is_constructible_v<long, T>;
140
141 protected:
143 static constexpr int n_dynamic_extents = []() {
144 int r = 0;
145 for (int u = 0; u < Rank; ++u) r += (static_extents[u] == 0 ? 1 : 0);
146 return r;
147 }();
148
149 public:
// Get the rank of the map, i.e. the Rank template parameter.
154 static constexpr int rank() noexcept { return Rank; }
155
160 [[nodiscard]] long size() const noexcept { return std::accumulate(len.cbegin(), len.cend(), 1L, std::multiplies<>{}); }
161
166 static constexpr long ce_size() noexcept {
167 if constexpr (n_dynamic_extents != 0) {
168 return 0;
169 } else {
170 return std::accumulate(static_extents.cbegin(), static_extents.cend(), 1L, std::multiplies<>{});
171 }
172 }
173
// Get the extents of all dimensions. The returned reference refers to a member of this map.
178 [[nodiscard]] std::array<long, Rank> const &lengths() const noexcept { return len; }
179
// Get the strides of all dimensions. The returned reference refers to a member of this map.
184 [[nodiscard]] std::array<long, Rank> const &strides() const noexcept { return str; }
185
// Get the value of the smallest stride (positive or negative).
// This is the stride of the dimension stored last in stride_order; no search over all strides is done.
190 [[nodiscard]] long min_stride() const noexcept { return str[stride_order[Rank - 1]]; }
191
201 [[nodiscard]] bool is_contiguous() const noexcept {
202 auto s = size();
203 if (s == 0) return true;
204 return (std::abs(str[stride_order[0]] * len[stride_order[0]]) == s);
205 }
206
211 [[nodiscard]] bool has_positive_strides() const noexcept { return (*std::min_element(str.cbegin(), str.cend()) >= 0); }
212
223 [[nodiscard]] bool is_strided_1d() const noexcept {
224 auto s = size();
225 if (s == 0) return true;
226 int i = Rank - 1;
227 while (len[stride_order[i]] == 1 and i > 0) --i;
228 return (std::abs(str[stride_order[0]] * len[stride_order[0]]) == s * std::abs(str[stride_order[i]]));
229 }
230
// Is the stride order equal to C-order, i.e. is stride_order the identity permutation?
235 static constexpr bool is_stride_order_C() { return (stride_order == permutations::identity<Rank>()); }
236
242
251 template <std::integral Int>
252 [[nodiscard]] static bool is_stride_order_valid(Int *lenptr, Int *strptr) {
253 auto dims_to_check = std::vector<int>{};
254 dims_to_check.reserve(Rank);
255 for (auto dim : stride_order)
256 if (lenptr[dim] > 1) dims_to_check.push_back(dim);
257 for (int n = 1; n < dims_to_check.size(); ++n)
258 if (std::abs(strptr[dims_to_check[n - 1]]) < std::abs(strptr[dims_to_check[n]])) return false;
259 return true;
260 }
261
// Check if the shape and strides of the current map are compatible with its stride order.
// Forwards the stored extents and strides to the static overload taking two pointers.
267 [[nodiscard]] bool is_stride_order_valid() const { return is_stride_order_valid(len.data(), str.data()); }
268
269 private:
270 // Compute contiguous strides from the shape.
271 void compute_strides_contiguous() {
272 long s = 1;
273 for (int v = rank() - 1; v >= 0; --v) {
274 int u = stride_order[v];
275 str[u] = s;
276 s *= len[u];
277 }
278 ENSURES(s == size());
279 }
280
281 // Check that the static extents and the shape are compatible.
// The check is only compiled in if NDA_ENFORCE_BOUNDCHECK is defined and NDEBUG is not defined, and only if
// at least one extent is static. Every non-zero static extent is then required (via EXPECTS) to be equal to
// the corresponding runtime extent. In all other configurations the function body is empty.
282 void assert_static_extents_and_len_are_compatible() const {
283#ifdef NDA_ENFORCE_BOUNDCHECK
284 if constexpr (n_dynamic_extents != Rank) {
285#ifndef NDEBUG
286 for (int u = 0; u < Rank; ++u)
287 if (static_extents[u] != 0) EXPECTS(static_extents[u] == len[u]);
288#endif
289 }
290#endif
291 }
292
293 // Should we check the stride order when constructing an idx_map from a shape and strides?
// True only if NDA_DEBUG is defined. The flag also determines the noexcept specification of the constructor
// taking a shape and strides, which throws if the check fails.
294#ifdef NDA_DEBUG
295 static constexpr bool check_stride_order = true;
296#else
297 static constexpr bool check_stride_order = false;
298#endif
299
300 // Combine static and dynamic extents.
301 static std::array<long, Rank> merge_static_and_dynamic_extents(std::array<long, n_dynamic_extents> const &dynamic_extents) {
302 std::array<long, Rank> extents;
303 for (int u = 0, v = 0; u < Rank; ++u) extents[u] = (static_extents[u] == 0 ? dynamic_extents[v++] : static_extents[u]);
304 return extents;
305 }
306
307 // FIXME ADD A CHECK layout_prop_e ... compare to stride and
308
309 public:
319 if constexpr (n_dynamic_extents == 0) {
320 for (int u = 0; u < Rank; ++u) len[u] = static_extents[u];
321 compute_strides_contiguous();
322 }
323 }
324
// Construct a new map from an existing map with different layout properties.
//
// Extents and strides are copied from idxm, which has the same rank, static extents and stride order but a
// layout property LP. The stride order of the new map is checked with EXPECTS. If LP is not compatible with
// the layout property of this map, the source map is additionally required to be contiguous and/or
// strided_1d, depending on what the layout property of this map demands.
331 template <layout_prop_e LP>
332 idx_map(idx_map<Rank, StaticExtents, StrideOrder, LP> const &idxm) noexcept : len(idxm.lengths()), str(idxm.strides()) {
333 // check strides and stride order of the constructed map
334 EXPECTS(is_stride_order_valid());
335
336 // check that the layout properties are compatible
337 if constexpr (not layout_property_compatible(LP, layout_prop)) {
338 if constexpr (has_contiguous(layout_prop)) {
339 EXPECTS_WITH_MESSAGE(idxm.is_contiguous(), "Error in nda::idx_map: Constructing a contiguous from a non-contiguous layout");
340 }
341 if constexpr (has_strided_1d(layout_prop)) {
342 EXPECTS_WITH_MESSAGE(idxm.is_strided_1d(), "Error in nda::idx_map: Constructing a strided_1d from a non-strided_1d layout");
343 }
344 }
345 }
346
// Construct a new map from an existing map with different layout properties and different static extents.
//
// Extents and strides are copied from idxm, which has the same rank and stride order. The same stride order
// and layout property checks as for a source map with identical static extents are performed. In addition,
// the copied shape is checked against the static extents of this map. Declared noexcept(false).
354 template <uint64_t SE, layout_prop_e LP>
355 idx_map(idx_map<Rank, SE, StrideOrder, LP> const &idxm) noexcept(false) : len(idxm.lengths()), str(idxm.strides()) {
356 // check strides and stride order
357 EXPECTS(is_stride_order_valid());
358
359 // check that the layout properties are compatible
360 if constexpr (not layout_property_compatible(LP, LayoutProp)) {
361 if constexpr (has_contiguous(LayoutProp)) {
362 EXPECTS_WITH_MESSAGE(idxm.is_contiguous(), "Error in nda::idx_map: Constructing a contiguous from a non-contiguous layout");
363 }
364 if constexpr (has_strided_1d(LayoutProp)) {
365 EXPECTS_WITH_MESSAGE(idxm.is_strided_1d(), "Error in nda::idx_map: Constructing a strided_1d from a non-strided_1d layout");
366 }
367 }
368
369 // check that the static extents and the shape are compatible
370 assert_static_extents_and_len_are_compatible();
371 }
372
// Construct a new map from a given shape and strides.
//
// Both arrays are stored as given. All extents are required (via EXPECTS) to be non-negative. If
// check_stride_order is true, a std::runtime_error is thrown when the strides are not compatible with the
// shape and the stride order; otherwise the constructor is noexcept and the strides are not checked.
379 idx_map(std::array<long, Rank> const &shape, // NOLINT (only throws if check_stride_order is true)
380 std::array<long, Rank> const &strides) noexcept(!check_stride_order)
381 : len(shape), str(strides) {
382 EXPECTS(std::all_of(shape.cbegin(), shape.cend(), [](auto const &i) { return i >= 0; }));
383 if constexpr (check_stride_order) {
384 if (not is_stride_order_valid()) throw std::runtime_error("Error in nda::idx_map: Incompatible strides, shape and stride order");
385 }
386 }
387
// Construct a new map from a given shape and with contiguous strides.
//
// The shape is converted to an array of long. All extents are required (via EXPECTS) to be non-negative, the
// shape is checked against the static extents and the strides are then computed from the shape.
394 template <std::integral Int = long>
395 idx_map(std::array<Int, Rank> const &shape) noexcept : len(stdutil::make_std_array<long>(shape)) {
396 EXPECTS(std::all_of(shape.cbegin(), shape.cend(), [](auto const &i) { return i >= 0; }));
397 assert_static_extents_and_len_are_compatible();
398 compute_strides_contiguous();
399 }
400
// Construct a new map from an array with its dynamic extents.
//
// Only available if the map has both static and dynamic extents. The given dynamic extents are merged with
// the static ones and the result is forwarded to the constructor taking a full shape.
409 idx_map(std::array<long, n_dynamic_extents> const &shape) noexcept
410 requires((n_dynamic_extents != Rank) and (n_dynamic_extents != 0))
411 : idx_map(merge_static_and_dynamic_extents(shape)) {}
412
422 template <uint64_t SE, uint64_t SO, layout_prop_e LP>
424 requires(stride_order_encoded != SO)
425 {
426 static_assert((stride_order_encoded == SO), "Error in nda::idx_map: Incompatible stride orders");
427 }
428
// Construct a new map with a shape of a different rank.
//
// This overload is only selected if R != Rank and its body consists of a static_assert that fails in that
// case, so that using it produces the compile-time error message below.
434 template <int R>
435 idx_map(std::array<long, R> const &)
436 requires(R != Rank)
437 {
438 static_assert(R == Rank, "Error in nda::idx_map: Incompatible ranks");
439 }
440
// Default copy constructor.
442 idx_map(idx_map const &) = default;
443
// Default move constructor.
445 idx_map(idx_map &&) = default;
446
// Default copy assignment operator.
448 idx_map &operator=(idx_map const &) = default;
449
// Default move assignment operator.
451 idx_map &operator=(idx_map &&) = default;
452
453 private:
454 // Get the contribution to the linear index in case of an nda::ellipsis argument.
// The template parameters are not used by this overload; they only have to match the overload taking a long
// so that both can be called in the same way.
455 template <bool skip_stride, auto I>
456 [[nodiscard]] FORCEINLINE long myget(ellipsis) const noexcept {
457 // nda::ellipsis are skipped and do not contribute to the linear index
458 return 0;
459 }
460
461 // Get the contribution to the linear index in case of a long argument.
462 template <bool skip_stride, auto I>
463 [[nodiscard]] FORCEINLINE long myget(long arg) const noexcept {
464 if constexpr (skip_stride and (I == stride_order[Rank - 1])) {
465 // optimize for the case when the fastest varying dimension is contiguous in memory
466 return arg;
467 } else {
468 // otherwise multiply the argument by the stride of the current dimension
469 return arg * std::get<I>(str);
470 }
471 }
472
473 // Is the smallest stride equal to one, i.e. is the fastest varying dimension contiguous in memory?
// Derived at compile-time from the layout property; used by call_impl to decide whether the multiplication by
// the stride can be skipped for the last dimension in stride_order.
474 static constexpr bool smallest_stride_is_one = has_smallest_stride_is_one(layout_prop);
475
476 // Implementation of the function call operator that takes a multi-dimensional index and maps it to a linear index.
// Is...: Positions 0, 1, ... of the arguments in the argument list.
// args: One argument per position; any of them may be an nda::ellipsis.
// Returns the sum of the contributions of all arguments to the linear index.
477 template <typename... Args, size_t... Is>
478 [[nodiscard]] FORCEINLINE long call_impl(std::index_sequence<Is...>, Args... args) const noexcept {
479 // set e_pos to the position of the ellipsis, otherwise to -1
// (every ellipsis adds its position + 1 to the sum, every other argument adds 0, and 1 is subtracted at the
// end; NOTE(review): this only yields a position if at most one ellipsis is passed - confirm with callers)
480 static constexpr int e_pos = ((std::is_same_v<Args, ellipsis> ? int(Is) + 1 : 0) + ...) - 1;
481
482 if constexpr (e_pos == -1) {
483 // no ellipsis present
484 if constexpr (smallest_stride_is_one) {
485 // optimize for the case that the fastest varying dimension is contiguous in memory
486 return (myget<true, Is>(static_cast<long>(args)) + ...);
487 } else {
488 // arbitrary layouts
489 return ((args * std::get<Is>(str)) + ...);
490 }
491 } else {
492 // empty ellipsis is present and needs to be skipped
// (arguments after the ellipsis are shifted down by one so that they address the correct dimension)
493 return (myget<smallest_stride_is_one, (Is < e_pos ? Is : Is - 1)>(args) + ...);
494 }
495 }
496
497 public:
// Function call operator to map a given multi-dimensional index to a linear index.
//
// If NDA_ENFORCE_BOUNDCHECK is defined, the arguments are first passed to assert_in_bounds together with the
// rank and the extents, and the operator is declared noexcept(false). Otherwise no check is done and the
// operator is noexcept(true). The linear index itself is computed by call_impl.
516 template <typename... Args>
517 FORCEINLINE long operator()(Args const &...args) const
518#ifdef NDA_ENFORCE_BOUNDCHECK
519 noexcept(false) {
520 assert_in_bounds(rank(), len.data(), args...);
521#else
522 noexcept(true) {
523#endif
524 return call_impl(std::make_index_sequence<sizeof...(Args)>{}, args...);
525 }
526
548 std::array<long, Rank> to_idx(long lin_idx) const {
549 // compute residues starting from slowest index
550 std::array<long, Rank> residues;
551 residues[0] = lin_idx;
552 for (auto i : range(1, Rank)) { residues[i] = residues[i - 1] % str[stride_order[i - 1]]; }
553
554 // convert residues to indices, ordered from slowest to fastest
555 std::array<long, Rank> idx;
556 idx[Rank - 1] = residues[Rank - 1] / str[stride_order[Rank - 1]];
557 for (auto i : range(Rank - 2, -1, -1)) { idx[i] = (residues[i] - residues[i + 1]) / str[stride_order[i]]; }
558
559 // reorder indices according to stride order
561 }
562
// Get a new nda::idx_map by taking a slice of the current one.
// All arguments are forwarded, together with this map, to slice_static::slice_idx_map, whose result is returned.
574 template <typename... Args>
575 auto slice(Args const &...args) const {
576 return slice_static::slice_idx_map(*this, args...);
577 }
578
589 template <int R, uint64_t SE, uint64_t SO, layout_prop_e LP>
590 bool operator==(idx_map<R, SE, SO, LP> const &rhs) const {
591 return (Rank == R and len == rhs.lengths() and str == rhs.strides());
592 }
593
608 template <uint64_t Permutation>
609 auto transpose() const {
610 // Makes a new transposed idx_map with permutation P such that
611 // denoting here A = this, A' = P A = returned_value
612 // A'(i_k) = A(i_{P[k]})
613 //
614 // Note that this convention is the correct one to have a (left) action of the symmetric group on
615 // an array and it may not be completely obvious.
616 // Proof
617 // let's operate with P then Q, and denote A'' = Q A'. We want to show that A'' = (QP) A
618 // A'(i_k) = A(i_{P[k]})
619 // A''(j_k) = A'(j_{Q[k]})
620 // then i_k = j_{Q[k]} and A''(j_k) = A(i_{P[k]}) = A(j_{Q[P[k]]}) = A(j_{(QP)[k]}), q.e.d
621 //
622 // NB test will test this composition
623 // Denoting this as A, an indexmap, calling it returns the linear index given by
624 //
625 // A(i_k) = sum_k i_k * S[k] (1)
626 //
627 // where S[k] denotes the strides.
628 //
629 // 1- S' : strides of A'
630 // A'(i_k) = sum_k i_{P[k]} * S[k] = sum_k i_k * S[P{^-1}[k]]
631 // so
632 // S'[k] = S[P{^-1}[k]] (2)
633 // i.e. apply (inverse(P), S) or apply_inverse directly.
634 //
635 // 2- L' : lengths of A'
636 // if L[k] is the k-th length, then because of the definition of A', i.e. A'(i_k) = A(i_{P[k]})
637 // i_q in the lhs A is at position q' such that P[q'] = q (A'(i0 i1 i2...) = A( i_P0 i_P1 i_P2....)
638 // hence L'[q] = L[q'] = L[P^{-1}[q]]
639 // same for static length
640 //
641 // 3- stride_order: denoted in this paragraph as Q (and Q' for A').
642 // by definition Q is a permutation such that Q[0] is the slowest index, Q[Rank -1] the fastest
643 // hence S[Q[k]] is a strictly decreasing sequence (as checked by is_stride_order_valid)
644 // we want therefore Q' the permutation that will sort the S', i.e.
645 // S'[Q'[k]] = S[Q[k]]
646 // using (2), we have S[P{^-1}[Q'[k]]] = S[Q[k]]
647 // so the permutation Q' is such that P{^-1}Q' = Q or Q' = PQ (as permutation product/composition).
648 // NB : Q and P are permutations, so the operation must be a composition, not an apply (apply applies
649 // a P to any set, like L, S, not only a permutation) even though they are all std::array in the code ...
650 //
651 static constexpr std::array<int, Rank> permu = decode<Rank>(Permutation);
652 static constexpr std::array<int, Rank> new_stride_order = permutations::compose(permu, stride_order);
653 static constexpr std::array<int, Rank> new_static_extents = permutations::apply_inverse(permu, static_extents);
654
655 return idx_map<Rank, encode(new_static_extents), encode(new_stride_order), LayoutProp>{permutations::apply_inverse(permu, lengths()),
657 }
658 };
659
662} // namespace nda
Layout that specifies how to map multi-dimensional indices to a linear/flat index.
Definition idx_map.hpp:103
__inline__ long operator()(Args const &...args) const noexcept(true)
Function call operator to map a given multi-dimensional index to a linear index.
Definition idx_map.hpp:517
static constexpr std::array< int, Rank > stride_order
Decoded stride order.
Definition idx_map.hpp:121
idx_map(idx_map< Rank, StaticExtents, StrideOrder, LP > const &idxm) noexcept
Construct a new map from an existing map with different layout properties.
Definition idx_map.hpp:332
static bool is_stride_order_valid(Int *lenptr, Int *strptr)
Check if a given shape and strides are compatible with the stride order.
Definition idx_map.hpp:252
static constexpr layout_info_t layout_info
Compile-time information about the layout (stride order and layout properties).
Definition idx_map.hpp:130
idx_map(idx_map &&)=default
Default move constructor.
long size() const noexcept
Get the total number of elements.
Definition idx_map.hpp:160
long min_stride() const noexcept
Get the value of the smallest stride (positive or negative).
Definition idx_map.hpp:190
idx_map(std::array< long, n_dynamic_extents > const &shape) noexcept
Construct a new map from an array with its dynamic extents.
Definition idx_map.hpp:409
auto transpose() const
Create a new map by permuting the indices/dimensions of the current map with a given permutation.
Definition idx_map.hpp:609
std::array< long, Rank > to_idx(long lin_idx) const
Calculate the multi-dimensional index from a given linear index.
Definition idx_map.hpp:548
bool is_stride_order_valid() const
Check if the shape and strides of the current map are compatible with its stride order.
Definition idx_map.hpp:267
idx_map(idx_map< Rank, SE, SO, LP > const &)
Construct a new map from an existing map with a different stride order.
Definition idx_map.hpp:423
bool has_positive_strides() const noexcept
Are all strides non-negative (i.e. >= 0)?
Definition idx_map.hpp:211
bool operator==(idx_map< R, SE, SO, LP > const &rhs) const
Equal-to operator for two nda::idx_map objects.
Definition idx_map.hpp:590
static constexpr bool is_stride_order_C()
Is the stride order equal to C-order?
Definition idx_map.hpp:235
static constexpr int n_dynamic_extents
Number of dynamic dimensions/extents.
Definition idx_map.hpp:143
static constexpr uint64_t stride_order_encoded
Encoded stride order.
Definition idx_map.hpp:124
static constexpr int rank() noexcept
Get the rank of the map.
Definition idx_map.hpp:154
static constexpr std::array< int, Rank > static_extents
Decoded static extents.
Definition idx_map.hpp:118
bool is_strided_1d() const noexcept
Is the data strided in memory with a constant stride?
Definition idx_map.hpp:223
static constexpr int argument_is_allowed_for_call_or_slice
Variable template to check if type T can be used to either access a single element or a slice of elements.
Definition idx_map.hpp:138
idx_map(std::array< Int, Rank > const &shape) noexcept
Construct a new map from a given shape and with contiguous strides.
Definition idx_map.hpp:395
std::array< long, Rank > const & lengths() const noexcept
Get the extents of all dimensions.
Definition idx_map.hpp:178
auto slice(Args const &...args) const
Get a new nda::idx_map by taking a slice of the current one.
Definition idx_map.hpp:575
static constexpr bool is_stride_order_Fortran()
Is the stride order equal to Fortran-order?
Definition idx_map.hpp:241
idx_map(idx_map const &)=default
Default copy constructor.
static constexpr int argument_is_allowed_for_call
Variable template to check if type T can be used to access a single element.
Definition idx_map.hpp:134
static constexpr long ce_size() noexcept
Get the size known at compile-time.
Definition idx_map.hpp:166
bool is_contiguous() const noexcept
Is the data contiguous in memory?
Definition idx_map.hpp:201
idx_map(idx_map< Rank, SE, StrideOrder, LP > const &idxm) noexcept(false)
Construct a new map from an existing map with different layout properties and different static extents.
Definition idx_map.hpp:355
idx_map(std::array< long, Rank > const &shape, std::array< long, Rank > const &strides) noexcept(!check_stride_order)
Construct a new map from a given shape and strides.
Definition idx_map.hpp:379
idx_map()
Default constructor.
Definition idx_map.hpp:318
std::array< long, Rank > const & strides() const noexcept
Get the strides of all dimensions.
Definition idx_map.hpp:184
idx_map & operator=(idx_map const &)=default
Default copy assignment operator.
static constexpr layout_prop_e layout_prop
Compile-time memory layout properties.
Definition idx_map.hpp:127
static constexpr uint64_t static_extents_encoded
Encoded static extents.
Definition idx_map.hpp:115
idx_map & operator=(idx_map &&)=default
Default move assignment operator.
idx_map(std::array< long, R > const &)
Construct a new map with a shape of a different rank.
Definition idx_map.hpp:435
constexpr uint64_t C_stride_order
C/Row-major stride order.
Definition idx_map.hpp:65
constexpr uint64_t Fortran_stride_order
Fortran/Column-major stride order.
Definition idx_map.hpp:57
__inline__ decltype(auto) slice_idx_map(idx_map< R, SE, SO, LP > const &idxm, Args const &...args)
Determine the resulting nda::idx_map when taking a slice of a given nda::idx_map.
constexpr bool layout_property_compatible(layout_prop_e from, layout_prop_e to)
Checks if two layout properties are compatible with each other.
Definition traits.hpp:237
constexpr bool has_contiguous(layout_prop_e lp)
Checks if a layout property has the contiguous property.
Definition traits.hpp:282
constexpr bool has_strided_1d(layout_prop_e lp)
Checks if a layout property has the strided_1d property.
Definition traits.hpp:266
void assert_in_bounds(int rank, long const *lengths, Args const &...args)
Check if the given indices/arguments are within the bounds of an array/view.
constexpr bool has_smallest_stride_is_one(layout_prop_e lp)
Checks if a layout property has the smallest_stride_is_one property.
Definition traits.hpp:274
layout_prop_e
Compile-time guarantees of the memory layout of an array/view.
Definition traits.hpp:222
constexpr std::array< int, N > decode(uint64_t binary_representation)
Decode a uint64_t into a std::array<int, N>.
constexpr std::array< T, N > apply_inverse(std::array< Int, N > const &p, std::array< T, N > const &a)
Apply the inverse of a permutation to a std::array.
constexpr std::array< int, N > identity()
Get the identity permutation.
constexpr uint64_t encode(std::array< int, N > const &a)
Encode a std::array<int, N> in a uint64_t.
constexpr std::array< int, N > reverse_identity()
Get the reverse identity permutation.
constexpr std::array< Int, N > compose(std::array< Int, N > const &p1, std::array< Int, N > const &p2)
Composition of two permutations.
constexpr std::array< T, R > make_std_array(std::array< U, R > const &a)
Convert a std::array with value type U to a std::array with value type T.
Definition array.hpp:184
Macros used in the nda library.
Provides utilities to work with permutations and to compactly encode/decode std::array objects.
Includes the itertools header and provides some additional utilities.
Provides utilities that determine the resulting nda::idx_map when taking a slice of an nda::idx_map.
Mimics Python's ... syntax.
Definition range.hpp:49
Stores information about the memory layout and the stride order of an array/view.
Definition traits.hpp:295
Provides type traits for the nda library.