TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
for_each.hpp
Go to the documentation of this file.
1// Copyright (c) 2018 Commissariat à l'énergie atomique et aux énergies alternatives (CEA)
2// Copyright (c) 2018 Centre national de la recherche scientifique (CNRS)
3// Copyright (c) 2018-2024 Simons Foundation
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0.txt
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell
18
24#pragma once
25
#include "./permutation.hpp"
#include "../stdutil/array.hpp"

#include <array>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <utility>
33
34namespace nda {
35
41 namespace detail {
42
43 // Get the i-th slowest moving dimension from a given encoded stride order.
44 template <int R>
45 constexpr int index_from_stride_order(uint64_t stride_order, int i) {
46 if (stride_order == 0) return i; // default C-order
47 auto stride_order_arr = decode<R>(stride_order); // FIXME C++20
48 return stride_order_arr[i];
49 }
50
51 // Get the extent of an array along its i-th dimension.
52 template <int I, int R, uint64_t StaticExtents, std::integral Int = long>
53 constexpr long get_extent(std::array<Int, R> const &shape) {
54 if constexpr (StaticExtents == 0) {
55 // dynamic extents
56 return shape[I];
57 } else {
58 // full/partial static extents
59 constexpr auto static_extents = decode<R>(StaticExtents); // FIXME C++20
60 if constexpr (static_extents[I] == 0)
61 return shape[I];
62 else
63 return static_extents[I];
64 }
65 }
66
67 // Apply a callable object recursively to all possible index values of a given shape.
68 template <int I, uint64_t StaticExtents, uint64_t StrideOrder, typename F, size_t R, std::integral Int = long>
69 FORCEINLINE void for_each_static_impl(std::array<Int, R> const &shape, std::array<long, R> &idxs, F &f) {
70 if constexpr (I == R) {
71 // end of recursion
72 std::apply(f, idxs);
73 } else {
74 // get the dimension over which to iterate and its extent
75 static constexpr int J = index_from_stride_order<R>(StrideOrder, I);
76 const long imax = get_extent<J, R, StaticExtents>(shape);
77
78 // loop over all indices of the current dimension
79 for (long i = 0; i < imax; ++i) {
80 // recursive call for the next dimension
81 for_each_static_impl<I + 1, StaticExtents, StrideOrder>(shape, idxs, f);
82 ++idxs[J];
83 }
84 idxs[J] = 0;
85 }
86 }
87
88 } // namespace detail
89
107 template <uint64_t StaticExtents, uint64_t StrideOrder, typename F, auto R, std::integral Int = long>
108 FORCEINLINE void for_each_static(std::array<Int, R> const &shape, F &&f) { // NOLINT (we do not want to forward here)
110 detail::for_each_static_impl<0, StaticExtents, StrideOrder>(shape, idxs, f);
111 }
112
128 template <typename F, auto R, std::integral Int = long>
129 FORCEINLINE void for_each(std::array<Int, R> const &shape, F &&f) { // NOLINT (we do not want to forward here)
131 detail::for_each_static_impl<0, 0, 0>(shape, idxs, f);
132 }
133
136} // namespace nda
Provides utility functions for std::array.
constexpr uint64_t static_extents(int i0, Is... is)
Encode the given shape into a single integer using the nda::encode function.
__inline__ void for_each(std::array< Int, R > const &shape, F &&f)
Loop over all possible index values of a given shape and apply a function to them.
Definition for_each.hpp:129
__inline__ void for_each_static(std::array< Int, R > const &shape, F &&f)
Loop over all possible index values of a given shape and apply a function to them.
Definition for_each.hpp:108
constexpr std::array< int, N > decode(uint64_t binary_representation)
Decode a uint64_t into a std::array<int, N>.
constexpr std::array< T, R > make_initialized_array(T v)
Create a new std::array object initialized with a specific value.
Definition array.hpp:168
Provides utilities to work with permutations and to compactly encode/decode std::array objects.