TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
sym_grp.hpp
Go to the documentation of this file.
1// Copyright (c) 2023--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./nda.hpp"
14#ifdef NDA_HAVE_MPI
15#include "./mpi.hpp"
16#endif
17
18#include <itertools/omp_chunk.hpp>
19
20#include <array>
21#include <concepts>
22#include <cstddef>
23#include <tuple>
24#include <type_traits>
25#include <vector>
26
27namespace nda {
28
33
/**
 * @brief A structure to capture combinations of complex conjugation and sign flip operations.
 *
 * @details Both a sign flip and a complex conjugation are self-inverse and commute with each other, so an operation is
 * fully described by two flags, and applying an operation twice yields the identity.
 */
struct operation {
  /// Boolean value indicating a sign flip operation.
  bool sgn = false;

  /// Boolean value indicating a complex conjugation operation.
  bool cc = false;

  /**
   * @brief Multiplication operator for two operations.
   *
   * @details Since both flags describe self-inverse, commuting operations, the composition simply XORs the flags.
   * The operator is `const` so that it can also be used on `const` operations.
   *
   * @param rhs Right hand side operation.
   * @return The combined operation.
   */
  operation operator*(operation const &rhs) const { return operation{sgn != rhs.sgn, cc != rhs.cc}; }

  /**
   * @brief Function call operator to apply the operation to a value.
   *
   * @tparam T Value type (`conj` is looked up unqualified, i.e. via ADL or in the enclosing namespace).
   * @param x Value to which the operation is applied.
   * @return The value after an optional complex conjugation and an optional sign flip.
   */
  template <typename T>
  T operator()(T const &x) const {
    if (sgn) return cc ? -conj(x) : -x;
    return cc ? conj(x) : x;
  }
};
68
77 template <Array A>
78 bool is_valid(A const &a, std::array<long, static_cast<std::size_t>(get_rank<A>)> const &idx) {
79 for (auto i = 0; i < get_rank<A>; ++i) {
80 if (not(0 <= idx[i] and idx[i] < a.shape()[i])) { return false; }
81 }
82 return true;
83 }
84
100 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(get_rank<A>)>>
101 concept NdaSymmetry = Array<A> and requires(F f, Idx const &idx) {
102 { f(idx) } -> std::same_as<std::tuple<Idx, operation>>;
103 };
104
115 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(get_rank<A>)>>
116 concept NdaInitFunc = Array<A> and requires(F f, Idx const &idx) {
117 { f(idx) } -> std::same_as<get_value_t<A>>;
118 };
119
133 template <typename F, typename A>
134 requires(Array<A> && NdaSymmetry<F, A>)
135 class sym_grp {
136 public:
138 static constexpr int ndims = get_rank<A>;
139
141 using sym_idx_t = std::pair<long, operation>;
142
144 using sym_class_t = std::span<sym_idx_t>;
145
147 using arr_idx_t = std::array<long, static_cast<std::size_t>(ndims)>;
148
149 private:
150 // std::vector containing the different symmetry classes.
151 std::vector<sym_class_t> sym_classes;
152
153 // std::vector of the size of the input array to store the elements of the symmetry classes.
154 std::vector<sym_idx_t> data;
155
156 public:
161 [[nodiscard]] std::vector<sym_class_t> const &get_sym_classes() const { return sym_classes; }
162
167 [[nodiscard]] long num_classes() const { return sym_classes.size(); }
168
180 template <typename H>
181 requires(NdaInitFunc<H, A>)
182 void init(A &a, H const &init_func, bool parallel = false) const {
183 auto init_with_sym = [&](sym_class_t const &sym_class) {
184 auto idx = a.indexmap().to_idx(sym_class[0].first);
185 auto ref_val = init_func(idx);
186 std::apply(a, idx) = ref_val;
187 for (auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
188 };
189
190 if (parallel) {
191#ifdef NDA_HAVE_MPI
192 a() = 0.0; // Required for MPI reduce below
193#endif
194#ifdef NDA_HAVE_OPENMP
195#pragma omp parallel for
196#endif // NDA_HAVE_OPENMP
197#ifdef NDA_HAVE_MPI
198 for (auto const &sym_class : mpi::chunk(sym_classes)) init_with_sym(sym_class);
199 mpi::all_reduce_in_place(a);
200#else
201 for (auto const &sym_class : sym_classes) init_with_sym(sym_class);
202#endif // NDA_HAVE_MPI
203 } else {
204 for (auto const &sym_class : sym_classes) init_with_sym(sym_class);
205 }
206 }
207
217 std::pair<double, arr_idx_t> symmetrize(A &a) const {
218 // loop over all symmetry classes
219 double max_diff = 0.0;
220 auto max_idx = arr_idx_t{};
221 for (auto const &sym_class : sym_classes) {
222 // reference value for the symmetry class (arithmetic mean over all its elements)
223 get_value_t<A> ref_val = 0.0;
224 for (auto const &[lin_idx, op] : sym_class) { ref_val += op(std::apply(a, a.indexmap().to_idx(lin_idx))); }
225 ref_val /= sym_class.size();
226
227 // assign the reference value to all elements and calculate the violation
228 for (auto const &[lin_idx, op] : sym_class) {
229 auto mapped_val = op(ref_val);
230 auto mapped_idx = a.indexmap().to_idx(lin_idx);
231 auto current_val = std::apply(a, mapped_idx);
232 auto diff = std::abs(mapped_val - current_val);
233
234 if (diff > max_diff) {
235 max_diff = diff;
236 max_idx = mapped_idx;
237 };
238
239 std::apply(a, mapped_idx) = mapped_val;
240 }
241 }
242
243 return std::pair{max_diff, max_idx};
244 }
245
252 [[nodiscard]] std::vector<get_value_t<A>> get_representative_data(A const &a) const {
253 long len = sym_classes.size();
254 auto vec = std::vector<get_value_t<A>>(len);
255 for (long i : range(len)) vec[i] = std::apply(a, a.indexmap().to_idx(sym_classes[i][0].first));
256 return vec;
257 }
258
265 template <typename V>
266 void init_from_representative_data(A &a, V const &vec) const {
267 static_assert(std::is_same_v<const get_value_t<A> &, decltype(vec[0])>);
268 for (long i : range(vec.size())) {
269 for (auto const &[lin_idx, op] : sym_classes[i]) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(vec[i]); }
270 }
271 }
272
274 sym_grp() = default;
275
286 sym_grp(A const &a, std::vector<F> const &sym_list, long const max_length = 0) {
287 // array to keep track of the indices already in a symmetry class
288 array<bool, ndims> checked(a.shape());
289 checked() = false;
290
291 // initialize the data (we have as many elements as in the original array)
292 data.reserve(a.size());
293
294 // loop over all indices/elements
295 for_each(checked.shape(), [&checked, &sym_list, max_length, this](auto... is) {
296 if (not checked(is...)) {
297 // if the index has not been checked yet, we start a new symmetry class
298 auto class_start = data.end();
299
300 // mark the current index as checked
301 checked(is...) = true;
302
303 // add it to the symmetry class as its representative together with the identity operation
304 operation op;
305 data.emplace_back(checked.indexmap()(is...), op);
306
307 // apply all symmetries to the current index and generate the symmetry class
308 auto idx = std::array{is...};
309 auto class_size = iterate(idx, op, checked, sym_list, max_length) + 1;
310
311 // store the symmetry class
312 sym_classes.emplace_back(class_start, class_size);
313 }
314 });
315 }
316
317 private:
318 // Implementation of the actual symmetry reduction algorithm.
319 long long iterate(std::array<long, static_cast<std::size_t>(get_rank<A>)> const &idx, operation const &op, array<bool, ndims> &checked,
320 std::vector<F> const &sym_list, long const max_length, long excursion_length = 0) {
321 // count the number of new indices found by recursively applying the symmetries to the current index and to the
322 // newly found indices
323 long long segment_length = 0;
324
325 // loop over all symmetries
326 for (auto const &sym : sym_list) {
327 // apply the symmetry to the current index
328 auto [idxp, opp] = sym(idx);
329 opp = opp * op;
330
331 // check if the new index is valid
332 if (is_valid(checked, idxp)) {
333 // if the new index is valid, check if it has been used already
334 if (not std::apply(checked, idxp)) {
335 // if it has not been used, mark it as checked
336 std::apply(checked, idxp) = true;
337
338 // add it to the symmetry class
339 data.emplace_back(std::apply(checked.indexmap(), idxp), opp);
340
341 // increment the segment length for the current index and recursively call the function with the new index
342 // and the excursion length reset to zero
343 segment_length += iterate(idxp, opp, checked, sym_list, max_length) + 1;
344 }
345 } else if (excursion_length < max_length) {
346 // if the index is invalid, recursively call the function with the new index and the excursion length
347 // incremented by one (the segment length is not incremented)
348 segment_length += iterate(idxp, opp, checked, sym_list, max_length, ++excursion_length);
349 }
350 }
351
352 // return the final value of the local segment length, which will be added to the segments length higher up in the
353 // recursive call tree
354 return segment_length;
355 }
356 };
357
359
360} // namespace nda
auto const & shape() const noexcept
Get the shape of the view/array.
std::array< long, static_cast< std::size_t >(ndims)> arr_idx_t
Multi-dimensional index type.
Definition sym_grp.hpp:147
long num_classes() const
Get the number of symmetry classes.
Definition sym_grp.hpp:167
void init(A &a, H const &init_func, bool parallel=false) const
Initialize an nda::Array using an nda::NdaInitFunc.
Definition sym_grp.hpp:182
std::vector< sym_class_t > const & get_sym_classes() const
Get the symmetry classes.
Definition sym_grp.hpp:161
void init_from_representative_data(A &a, V const &vec) const
Initialize a multi-dimensional array from its representative data using symmetries.
Definition sym_grp.hpp:266
std::vector< get_value_t< A > > get_representative_data(A const &a) const
Reduce an nda::Array to its representative data using symmetries.
Definition sym_grp.hpp:252
std::pair< long, operation > sym_idx_t
Element type of a symmetry class.
Definition sym_grp.hpp:141
std::span< sym_idx_t > sym_class_t
Symmetry class type.
Definition sym_grp.hpp:144
std::pair< double, arr_idx_t > symmetrize(A &a) const
Symmetrize an array and return the maximum symmetry violation and its corresponding array index.
Definition sym_grp.hpp:217
sym_grp(A const &a, std::vector< F > const &sym_list, long const max_length=0)
Construct a symmetry group for a given array and a list of its symmetries.
Definition sym_grp.hpp:286
sym_grp()=default
Default constructor for a symmetry group.
static constexpr int ndims
Rank of the input array.
Definition sym_grp.hpp:138
Check if a given type satisfies the array concept.
Definition concepts.hpp:230
Concept defining an initializer function.
Definition sym_grp.hpp:116
Concept defining a symmetry in nda.
Definition sym_grp.hpp:101
decltype(auto) conj(A &&a)
Function conj for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types with a com...
bool is_valid(A const &a, std::array< long, static_cast< std::size_t >(get_rank< A >)> const &idx)
Check if a multi-dimensional index is valid, i.e. not out of bounds, with respect to a given nda::Array object.
Definition sym_grp.hpp:78
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:126
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:182
__inline__ void for_each(std::array< Int, R > const &shape, F &&f)
Loop over all possible index values of a given shape and apply a function to them.
Definition for_each.hpp:116
Includes all MPI relevant headers.
Includes all relevant headers for the core nda library.
A structure to capture combinations of complex conjugation and sign flip operations.
Definition sym_grp.hpp:37
bool sgn
Boolean value indicating a sign flip operation.
Definition sym_grp.hpp:39
T operator()(T const &x) const
Function call operator to apply the operation to a value.
Definition sym_grp.hpp:63
operation operator*(operation const &rhs)
Multiplication operator for two operations.
Definition sym_grp.hpp:53
bool cc
Boolean value indicating a complex conjugation operation.
Definition sym_grp.hpp:42