TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
sym_grp.hpp
Go to the documentation of this file.
1// Copyright (c) 2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Dominik Kiese
16
17/**
18 * @file
19 * @brief Provides tools to use symmetries with nda objects.
20 */
21
22#pragma once
23
24#include "./nda.hpp"
25#include "./mpi.hpp"
26
27#include <itertools/omp_chunk.hpp>
28#include <mpi/mpi.hpp>
29
30#include <array>
31#include <concepts>
32#include <cstddef>
33#include <tuple>
34#include <type_traits>
35#include <vector>
36
37namespace nda {
38
39 /**
40 * @addtogroup av_sym
41 * @{
42 */
43
/**
 * @brief A structure to capture combinations of complex conjugation and sign flip operations.
 *
 * @details Both flags default to false, i.e. a default constructed operation leaves a value unchanged.
 */
struct operation {
  /// Boolean value indicating a sign flip operation.
  bool sgn = false;

  /// Boolean value indicating a complex conjugation operation.
  bool cc = false;

  /**
   * @brief Multiplication operator for two operations.
   *
   * @details The sign flip (complex conjugation) operation is set to true in the resulting product iff exactly one of
   * the two input operations (exclusive or!) has the sign flip (complex conjugation) operation set to true.
   *
   * The operator is const-qualified so that it can also be used with const operations and temporaries.
   *
   * @param rhs Right hand side operation.
   * @return The resulting operation.
   */
  operation operator*(operation const &rhs) const {
    // for two bools, "not equal" is exactly the exclusive or
    return operation{sgn != rhs.sgn, cc != rhs.cc};
  }

  /**
   * @brief Function call operator to apply the operation to a value.
   *
   * @details The complex conjugation is performed by an unqualified call to `conj`, i.e. a matching `conj` overload
   * has to be visible for the given value type.
   *
   * @tparam T Value type.
   * @param x Value to which the operation is applied.
   * @return The value after the operation has been applied.
   */
  template <typename T>
  T operator()(T const &x) const {
    if (sgn) return cc ? -conj(x) : -x;
    return cc ? conj(x) : x;
  }
};
78
79 /**
80 * @brief Check if a multi-dimensional index is valid, i.e. not out of bounds, w.r.t. to a given nda::Array object.
81 *
82 * @tparam A nda::Array type.
83 * @param a nda::Array object.
84 * @param idx Multi-dimensional index.
85 * @return True if the index is not out of bounds, false otherwise.
86 */
87 template <Array A>
88 bool is_valid(A const &a, std::array<long, static_cast<std::size_t>(get_rank<A>)> const &idx) {
89 for (auto i = 0; i < get_rank<A>; ++i) {
90 if (not(0 <= idx[i] and idx[i] < a.shape()[i])) { return false; }
91 }
92 return true;
93 }
94
95 /**
96 * @brief Concept defining a symmetry in nda.
97 *
98 * @details A symmetry consists of a callable type that can be called with a multi-dimensional index with the same
99 * rank as a given array type and returns a tuple with a new multi-dimensional index and an nda::operation.
100 *
101 * The returned index corresponds to an element which is related to the element at the input index by the symmetry.
102 *
103 * The returned operation describes how the values of the elements are related, i.e. either via a sign flip, a complex
104 * conjugation, or both.
105 *
106 * @tparam F Callable type.
107 * @tparam A nda::Array type.
108 * @tparam Idx Multi-dimensional index type.
109 */
110 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(get_rank<A>)>>
111 concept NdaSymmetry = Array<A> and requires(F f, Idx const &idx) {
112 { f(idx) } -> std::same_as<std::tuple<Idx, operation>>;
113 };
114
115 /**
116 * @brief Concept defining an initializer function.
117 *
118 * @details An initializer function consists of a callable type that can be called with a multi-dimensional index with
119 * the same rank as a given array type and returns a an object of the same type as the value type of the array.
120 *
121 * @tparam F Callable type.
122 * @tparam A nda::Array type.
123 * @tparam Idx Multi-dimensional index type.
124 */
125 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(get_rank<A>)>>
126 concept NdaInitFunc = Array<A> and requires(F f, Idx const &idx) {
127 { f(idx) } -> std::same_as<get_value_t<A>>;
128 };
129
130 /**
131 * @brief Class representing a symmetry group.
132 *
133 * @details A symmetry group detects and stores all different symmetry classes associated with a given nda::Array
134 * object.
135 *
136 * A symmetry class is simply a set of all the elements of the array that are related to each other by some symmetry.
137 * The elements in a symmetry class have all the same values except for a possible sign flip or complex conjugation,
138 * i.e. an nda::operation. The symmetry classes form a partition of all the elements of the array.
139 *
140 * @tparam F nda::NdaSymmetry type.
141 * @tparam A nda::Array type.
142 */
143 template <typename F, typename A>
144 requires(Array<A> && NdaSymmetry<F, A>)
145 class sym_grp {
146 public:
147 /// Rank of the input array.
148 static constexpr int ndims = get_rank<A>;
149
150 /// Element type of a symmetry class.
151 using sym_idx_t = std::pair<long, operation>;
152
153 /// Symmetry class type.
154 using sym_class_t = std::span<sym_idx_t>;
155
156 /// Multi-dimensional index type.
157 using arr_idx_t = std::array<long, static_cast<std::size_t>(ndims)>;
158
159 private:
160 // std::vector containing the different symmetry classes.
161 std::vector<sym_class_t> sym_classes;
162
163 // std::vector of the size of the input array to store the elements of the symmetry classes.
164 std::vector<sym_idx_t> data;
165
166 public:
167 /**
168 * @brief Get the symmetry classes.
169 * @return std::vector containing the individual classes.
170 */
171 [[nodiscard]] std::vector<sym_class_t> const &get_sym_classes() const { return sym_classes; }
172
173 /**
174 * @brief Get the number of symmetry classes.
175 * @return Number of detected symmetry classes.
176 */
177 [[nodiscard]] long num_classes() const { return sym_classes.size(); }
178
179 /**
180 * @brief Initialize an nda::Array using an nda::NdaInitFunc.
181 *
182 * @details The nda::NdaInitFunc object is evaluated only once per symmetry class. The result is then assigned to
183 * all elements in the symmetry class after applying the nda::operation.
184 *
185 * @tparam H Callable type of nda::NdaInitFunc.
186 * @param a nda::Array object to be initialized.
187 * @param init_func Callable that is used to initialize the array.
188 * @param parallel Parallelize using openmp and mpi.
189 */
190 template <typename H>
191 requires(NdaInitFunc<H, A>)
192 void init(A &a, H const &init_func, bool parallel = false) const {
193 if (parallel) {
194 // reset input array to allow for mpi reduction
195 a() = 0.0;
196
197#pragma omp parallel
198 for (auto const &sym_class : itertools::omp_chunk(mpi::chunk(sym_classes))) {
199 auto idx = a.indexmap().to_idx(sym_class[0].first);
200 auto ref_val = init_func(idx);
201 std::apply(a, idx) = ref_val;
202 for (auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
203 }
204
205 // distribute data among all ranks
206 a = mpi::all_reduce(a);
207 } else {
208 for (auto const &sym_class : sym_classes) {
209 auto idx = a.indexmap().to_idx(sym_class[0].first);
210 auto ref_val = init_func(idx);
211 std::apply(a, idx) = ref_val;
212 for (auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
213 }
214 }
215 }
216
217 /**
218 * @brief Symmetrize an array and return the maximum symmetry violation and its corresponding array index.
219 *
220 * @note This actually requires the definition of an inverse operation but with the current implementation, all
221 * operations are self-inverse (sign flip and complex conjugation).
222 *
223 * @param a nda::Array object to be symmetrized.
224 * @return Maximum symmetry violation and corresponding array index.
225 */
226 std::pair<double, arr_idx_t> symmetrize(A &a) const {
227 // loop over all symmetry classes
228 double max_diff = 0.0;
229 auto max_idx = arr_idx_t{};
230 for (auto const &sym_class : sym_classes) {
231 // reference value for the symmetry class (arithmetic mean over all its elements)
232 get_value_t<A> ref_val = 0.0;
233 for (auto const &[lin_idx, op] : sym_class) { ref_val += op(std::apply(a, a.indexmap().to_idx(lin_idx))); }
234 ref_val /= sym_class.size();
235
236 // assign the reference value to all elements and calculate the violation
237 for (auto const &[lin_idx, op] : sym_class) {
238 auto mapped_val = op(ref_val);
239 auto mapped_idx = a.indexmap().to_idx(lin_idx);
240 auto current_val = std::apply(a, mapped_idx);
241 auto diff = std::abs(mapped_val - current_val);
242
243 if (diff > max_diff) {
244 max_diff = diff;
245 max_idx = mapped_idx;
246 };
247
248 std::apply(a, mapped_idx) = mapped_val;
249 }
250 }
251
252 return std::pair{max_diff, max_idx};
253 }
254
255 /**
256 * @brief Reduce an nda::Array to its representative data using symmetries.
257 *
258 * @param a nda::Array object.
259 * @return std::vector of data values for the representative elements of each symmetry class.
260 */
261 [[nodiscard]] std::vector<get_value_t<A>> get_representative_data(A const &a) const {
262 long len = sym_classes.size();
263 auto vec = std::vector<get_value_t<A>>(len);
264 for (long i : range(len)) vec[i] = std::apply(a, a.indexmap().to_idx(sym_classes[i][0].first));
265 return vec;
266 }
267
268 /**
269 * @brief Initialize a multi-dimensional array from its representative data using symmetries.
270 *
271 * @param a nda::Array object to be initialized.
272 * @param vec std::vector of data values for the representative elements of each symmetry class.
273 */
274 template <typename V>
275 void init_from_representative_data(A &a, V const &vec) const {
276 static_assert(std::is_same_v<const get_value_t<A> &, decltype(vec[0])>);
277 for (long i : range(vec.size())) {
278 for (auto const &[lin_idx, op] : sym_classes[i]) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(vec[i]); }
279 }
280 }
281
282 /// Default constructor for a symmetry group.
283 sym_grp() = default;
284
285 /**
286 * @brief Construct a symmetry group for a given array and a list of its symmetries.
287 *
288 * @details It uses nda::for_each to loop over all possible multi-dimensional indices of the given array and applies
289 * the symmetries to each index to generate the different symmetry classes.
290 *
291 * @param a nda::Array object.
292 * @param sym_list List of symmetries containing nda::NdaSymmetry objects.
293 * @param max_length Maximum recursion depth for out-of-bounds projection. Default is 0.
294 */
295 sym_grp(A const &a, std::vector<F> const &sym_list, long const max_length = 0) {
296 // array to keep track of the indices already in a symmetry class
297 array<bool, ndims> checked(a.shape());
298 checked() = false;
299
300 // initialize the data (we have as many elements as in the original array)
301 data.reserve(a.size());
302
303 // loop over all indices/elements
304 for_each(checked.shape(), [&checked, &sym_list, max_length, this](auto... is) {
305 if (not checked(is...)) {
306 // if the index has not been checked yet, we start a new symmetry class
307 auto class_start = data.end();
308
309 // mark the current index as checked
310 checked(is...) = true;
311
312 // add it to the symmetry class as its representative together with the identity operation
313 operation op;
314 data.emplace_back(checked.indexmap()(is...), op);
315
316 // apply all symmetries to the current index and generate the symmetry class
317 auto idx = std::array{is...};
318 auto class_size = iterate(idx, op, checked, sym_list, max_length) + 1;
319
320 // store the symmetry class
321 sym_classes.emplace_back(class_start, class_size);
322 }
323 });
324 }
325
326 private:
327 // Implementation of the actual symmetry reduction algorithm.
328 long long iterate(std::array<long, static_cast<std::size_t>(get_rank<A>)> const &idx, operation const &op, array<bool, ndims> &checked,
329 std::vector<F> const &sym_list, long const max_length, long excursion_length = 0) {
330 // count the number of new indices found by recursively applying the symmetries to the current index and to the
331 // newly found indices
332 long long segment_length = 0;
333
334 // loop over all symmetries
335 for (auto const &sym : sym_list) {
336 // apply the symmetry to the current index
337 auto [idxp, opp] = sym(idx);
338 opp = opp * op;
339
340 // check if the new index is valid
341 if (is_valid(checked, idxp)) {
342 // if the new index is valid, check if it has been used already
343 if (not std::apply(checked, idxp)) {
344 // if it has not been used, mark it as checked
345 std::apply(checked, idxp) = true;
346
347 // add it to the symmetry class
348 data.emplace_back(std::apply(checked.indexmap(), idxp), opp);
349
350 // increment the segment length for the current index and recursively call the function with the new index
351 // and the excursion length reset to zero
352 segment_length += iterate(idxp, opp, checked, sym_list, max_length) + 1;
353 }
354 } else if (excursion_length < max_length) {
355 // if the index is invalid, recursively call the function with the new index and the excursion length
356 // incremented by one (the segment length is not incremented)
357 segment_length += iterate(idxp, opp, checked, sym_list, max_length, ++excursion_length);
358 }
359 }
360
361 // return the final value of the local segment length, which will be added to the segments length higher up in the
362 // recursive call tree
363 return segment_length;
364 }
365 };
366
367 /** @} */
368
369} // namespace nda
Class representing a symmetry group.
Definition sym_grp.hpp:145
long num_classes() const
Get the number of symmetry classes.
Definition sym_grp.hpp:177
void init(A &a, H const &init_func, bool parallel=false) const
Initialize an nda::Array using an nda::NdaInitFunc.
Definition sym_grp.hpp:192
std::vector< sym_class_t > const & get_sym_classes() const
Get the symmetry classes.
Definition sym_grp.hpp:171
void init_from_representative_data(A &a, V const &vec) const
Initialize a multi-dimensional array from its representative data using symmetries.
Definition sym_grp.hpp:275
std::vector< get_value_t< A > > get_representative_data(A const &a) const
Reduce an nda::Array to its representative data using symmetries.
Definition sym_grp.hpp:261
std::pair< double, arr_idx_t > symmetrize(A &a) const
Symmetrize an array and return the maximum symmetry violation and its corresponding array index.
Definition sym_grp.hpp:226
sym_grp(A const &a, std::vector< F > const &sym_list, long const max_length=0)
Construct a symmetry group for a given array and a list of its symmetries.
Definition sym_grp.hpp:295
sym_grp()=default
Default constructor for a symmetry group.
static constexpr int ndims
Rank of the input array.
Definition sym_grp.hpp:148
bool is_valid(A const &a, std::array< long, static_cast< std::size_t >(get_rank< A >)> const &idx)
Check if a multi-dimensional index is valid, i.e. not out of bounds, w.r.t. a given nda::Array object.
Definition sym_grp.hpp:88
A structure to capture combinations of complex conjugation and sign flip operations.
Definition sym_grp.hpp:47
bool sgn
Boolean value indicating a sign flip operation.
Definition sym_grp.hpp:49
T operator()(T const &x) const
Function call operator to apply the operation to a value.
Definition sym_grp.hpp:73
operation operator*(operation const &rhs)
Multiplication operator for two operations.
Definition sym_grp.hpp:63
bool cc
Boolean value indicating a complex conjugation operation.
Definition sym_grp.hpp:52