TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
layout_transforms.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Olivier Parcollet, Nils Wentzell
16
22#pragma once
23
24#include "./concepts.hpp"
25#include "./declarations.hpp"
26#include "./group_indices.hpp"
27#include "./layout/idx_map.hpp"
29#include "./layout/policies.hpp"
30#include "./map.hpp"
31#include "./stdutil/array.hpp"
32#include "./traits.hpp"
33
34#include <array>
35#include <concepts>
36#include <type_traits>
37#include <utility>
38
39#ifndef NDEBUG
40#include <numeric>
41#endif // NDEBUG
42
43namespace nda {
44
69 template <MemoryArray A, typename NewLayoutType>
70 auto map_layout_transform([[maybe_unused]] A &&a, [[maybe_unused]] NewLayoutType const &new_layout) {
71 using A_t = std::remove_reference_t<A>;
72
73 // layout policy of transformed array/view
74 using layout_policy = typename detail::layout_to_policy<NewLayoutType>::type;
75
76 // algebra of transformed array/view
77 static constexpr auto algebra = (NewLayoutType::rank() == get_rank<A> ? get_algebra<A> : 'A');
78
79 if constexpr (is_regular_v<A> and !std::is_reference_v<A>) {
80 // return a transformed array if the given type is an rvalue array
81 using array_t = basic_array<typename A_t::value_type, NewLayoutType::rank(), layout_policy, algebra, typename A_t::container_policy_t>;
82 return array_t{new_layout, std::forward<A>(a).storage()};
83 } else {
84 // otherwise return a transformed view
85 using value_t = std::conditional_t<std::is_const_v<A_t>, const typename A_t::value_type, typename A_t::value_type>;
86 using accessor_policy = typename get_view_t<A>::accessor_policy_t;
87 using owning_policy = typename get_view_t<A>::owning_policy_t;
88 return basic_array_view<value_t, NewLayoutType::rank(), layout_policy, algebra, accessor_policy, owning_policy>{new_layout, a.storage()};
89 }
90 }
91
107 template <MemoryArray A, std::integral Int, auto R>
108 auto reshape(A &&a, std::array<Int, R> const &new_shape) {
109 // check size and contiguity of new shape
110 EXPECTS_WITH_MESSAGE(a.size() == (std::accumulate(new_shape.cbegin(), new_shape.cend(), Int{1}, std::multiplies<>{})),
111 "Error in nda::reshape: New shape has an incorrect number of elements");
112 EXPECTS_WITH_MESSAGE(a.indexmap().is_contiguous(), "Error in nda::reshape: Only contiguous arrays/views are supported");
113
114 // restrict supported layouts (why?)
115 using A_t = std::remove_cvref_t<A>;
116 static_assert(A_t::is_stride_order_C() or A_t::is_stride_order_Fortran() or R == 1,
117 "Error in nda::reshape: Only C or Fortran layouts are supported");
118
119 // prepare new idx_map
120 using layout_t = typename std::decay_t<A>::layout_policy_t::template mapping<R>;
121 return map_layout_transform(std::forward<A>(a), layout_t{stdutil::make_std_array<long>(new_shape)});
122 }
123
135 template <MemoryArray A, std::integral... Ints>
136 auto reshape(A &&a, Ints... is) {
137 return reshape(std::forward<A>(a), std::array<long, sizeof...(Ints)>{static_cast<long>(is)...});
138 }
139
141 template <MemoryArray A, std::integral Int, auto newRank>
142 [[deprecated("Please use reshape(arr, shape) instead")]] auto reshaped_view(A &&a, std::array<Int, newRank> const &new_shape) {
143 return reshape(std::forward<A>(a), new_shape);
144 }
145
155 template <MemoryArray A>
156 auto flatten(A &&a) {
157 return reshape(std::forward<A>(a), std::array{a.size()});
158 }
159
171 template <uint64_t Permutation, MemoryArray A>
173 return map_layout_transform(std::forward<A>(a), a.indexmap().template transpose<Permutation>());
174 }
175
187 template <typename A>
188 auto transpose(A &&a)
190 {
191 if constexpr (MemoryArray<A>) {
193 } else { // expr_call
194 static_assert(std::tuple_size_v<decltype(a.a)> == 1, "Error in nda::transpose: Cannot transpose expr_call with more than one array argument");
195 return map(a.f)(transpose(std::get<0>(std::forward<A>(a).a)));
196 }
197 }
198
211 template <int I, int J, MemoryArray A>
212 auto transposed_view(A &&a)
214 {
215 return permuted_indices_view<encode(permutations::transposition<get_rank<A>>(I, J))>(std::forward<A>(a));
216 }
217
218 // FIXME : use "magnetic" placeholder
242 template <MemoryArray A, typename... IdxGrps>
243 auto group_indices_view(A &&a, IdxGrps...) {
244 return map_layout_transform(std::forward<A>(a), group_indices_layout(a.indexmap(), IdxGrps{}...));
245 }
246
namespace detail {

  // Append N fast dimensions to a given stride order.
  //
  // The stride order lists the dimensions from the slowest to the fastest varying one. The new dimensions get the
  // indices R, R + 1, ..., R + N - 1 and are appended at the end, i.e. they become the fastest varying ones.
  // Example: N = 2, order = {2, 0, 1} -> {2, 0, 1, 3, 4}.
  template <int N, auto R>
  constexpr std::array<int, R + N> complete_stride_order_with_fast(std::array<int, R> const &order) {
    static_assert(N >= 0, "Error in nda::detail::complete_stride_order_with_fast: N must be non-negative");
    // value-initialized so that the function is a valid constexpr function also before C++20
    std::array<int, R + N> r{};
    for (std::size_t i = 0; i < R; ++i) r[i] = order[i];
    for (std::size_t i = 0; i < static_cast<std::size_t>(N); ++i) r[R + i] = static_cast<int>(R + i);
    return r;
  }

} // namespace detail
259
268 template <int N, typename A>
271 {
272 auto const &lay = a.indexmap();
273 using lay_t = std::decay_t<decltype(lay)>;
274
275 static constexpr uint64_t new_stride_order_encoded = encode(detail::complete_stride_order_with_fast<N>(lay_t::stride_order));
276 static constexpr uint64_t new_static_extents_encoded = encode(stdutil::join(lay_t::static_extents, stdutil::make_initialized_array<N>(0)));
277 using new_lay_t = idx_map<get_rank<A> + N, new_static_extents_encoded, new_stride_order_encoded, lay_t::layout_prop>;
278
279 auto ones_n = stdutil::make_initialized_array<N>(1l);
280 return map_layout_transform(std::forward<A>(a), new_lay_t{stdutil::join(lay.lengths(), ones_n), stdutil::join(lay.strides(), ones_n)});
281 }
282
285} // namespace nda
Provides utility functions for std::array.
A generic view of a multi-dimensional array.
A generic multi-dimensional array.
Layout that specifies how to map multi-dimensional indices to a linear/flat index.
Definition idx_map.hpp:103
Check if a given type satisfies the memory array concept.
Definition concepts.hpp:248
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
auto transposed_view(A &&a)
Transpose two indices/dimensions of an nda::basic_array or nda::basic_array_view.
auto reinterpret_add_fast_dims_of_size_one(A &&a)
Add N fast varying dimensions of size 1 to a given nda::basic_array or nda::basic_array_view.
auto permuted_indices_view(A &&a)
Permute the indices/dimensions of an nda::basic_array or nda::basic_array_view.
auto map_layout_transform(A &&a, NewLayoutType const &new_layout)
Transform the memory layout of an nda::basic_array or nda::basic_array_view.
auto flatten(A &&a)
Flatten an nda::basic_array or nda::basic_array_view to a 1-dimensional array/view by reshaping it.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
auto reshape(A &&a, std::array< Int, R > const &new_shape)
Reshape an nda::basic_array or nda::basic_array_view.
auto group_indices_view(A &&a, IdxGrps...)
Create a new nda::basic_array or nda::basic_array_view by grouping together the indices of a given array/view.
auto reshaped_view(A &&a, std::array< Int, newRank > const &new_shape)
mapped< F > map(F f)
Create a lazy function call expression on arrays/views.
Definition map.hpp:199
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:136
std::remove_reference_t< decltype(basic_array_view{std::declval< T >()})> get_view_t
Get the type of the nda::basic_array_view that would be obtained by constructing a view from a given type.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:163
Provides functions used in nda::group_indices_view.
auto group_indices_layout(idx_map< Rank, StaticExtents, StrideOrder, LayoutProp > const &idxm, IdxGrps...)
Given an nda::idx_map and a partition of its indices, return a new nda::idx_map with the grouped indices.
constexpr std::array< int, N > transposition(int i, int j)
Get the permutation representing a single given transposition.
constexpr uint64_t encode(std::array< int, N > const &a)
Encode a std::array<int, N> in a uint64_t.
constexpr std::array< int, N > reverse_identity()
Get the reverse identity permutation.
constexpr std::array< T, R1+R2 > join(std::array< T, R1 > const &a1, std::array< T, R2 > const &a2)
Make a new std::array by joining two existing std::array objects.
Definition array.hpp:309
constexpr std::array< T, R > make_initialized_array(T v)
Create a new std::array object initialized with a specific value.
Definition array.hpp:165
constexpr std::array< T, R > make_std_array(std::array< U, R > const &a)
Convert a std::array with value type U to a std::array with value type T.
Definition array.hpp:181
constexpr bool is_instantiation_of_v
Constexpr variable that is true if type T is an instantiation of TMPLT (see nda::is_instantiation_of).
Definition traits.hpp:59
Provides a class that maps multi-dimensional indices to a linear index and vice versa.
Provides definitions of various layout policies.
Provides lazy function calls on arrays/views.
Provides utilities to work with permutations and to compactly encode/decode std::array objects.
Provides type traits for the nda library.