TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
sym_grp.hpp
Go to the documentation of this file.
1// Copyright (c) 2023--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
#include "./nda.hpp"
#include "./mpi.hpp"

#include <itertools/omp_chunk.hpp>
#include <mpi/mpi.hpp>

#include <array>
#include <cmath>
#include <complex>
#include <concepts>
#include <cstddef>
#include <span>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
25
26namespace nda {
27
32
/**
 * @brief A combination of a sign flip and/or a complex conjugation.
 *
 * @details Both elementary operations are involutions and commute with each other, so composing two operations
 * amounts to XOR-ing their flags.
 */
struct operation {
  /// Whether the operation flips the sign.
  bool sgn = false;

  /// Whether the operation complex conjugates.
  bool cc = false;

  /**
   * @brief Compose this operation with another one.
   *
   * @details The member is const so that const operations can be composed as well.
   *
   * @param rhs Right hand side operation.
   * @return The composed operation (flags combined with XOR).
   */
  [[nodiscard]] constexpr operation operator*(operation const &rhs) const { return operation{sgn != rhs.sgn, cc != rhs.cc}; }

  /**
   * @brief Apply the operation to a value.
   *
   * @details `conj` is found via unqualified lookup/ADL at instantiation (e.g. std::conj for std::complex).
   *
   * @tparam T Value type.
   * @param x Value to transform.
   * @return The value after the optional conjugation and the optional sign flip.
   */
  template <typename T>
  T operator()(T const &x) const {
    if (sgn) return cc ? -conj(x) : -x;
    return cc ? conj(x) : x;
  }
};
67
76 template <Array A>
77 bool is_valid(A const &a, std::array<long, static_cast<std::size_t>(get_rank<A>)> const &idx) {
78 for (auto i = 0; i < get_rank<A>; ++i) {
79 if (not(0 <= idx[i] and idx[i] < a.shape()[i])) { return false; }
80 }
81 return true;
82 }
83
99 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(get_rank<A>)>>
100 concept NdaSymmetry = Array<A> and requires(F f, Idx const &idx) {
101 { f(idx) } -> std::same_as<std::tuple<Idx, operation>>;
102 };
103
114 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(get_rank<A>)>>
115 concept NdaInitFunc = Array<A> and requires(F f, Idx const &idx) {
116 { f(idx) } -> std::same_as<get_value_t<A>>;
117 };
118
132 template <typename F, typename A>
133 requires(Array<A> && NdaSymmetry<F, A>)
134 class sym_grp {
135 public:
137 static constexpr int ndims = get_rank<A>;
138
140 using sym_idx_t = std::pair<long, operation>;
141
143 using sym_class_t = std::span<sym_idx_t>;
144
146 using arr_idx_t = std::array<long, static_cast<std::size_t>(ndims)>;
147
148 private:
149 // std::vector containing the different symmetry classes.
150 std::vector<sym_class_t> sym_classes;
151
152 // std::vector of the size of the input array to store the elements of the symmetry classes.
153 std::vector<sym_idx_t> data;
154
155 public:
160 [[nodiscard]] std::vector<sym_class_t> const &get_sym_classes() const { return sym_classes; }
161
166 [[nodiscard]] long num_classes() const { return sym_classes.size(); }
167
179 template <typename H>
180 requires(NdaInitFunc<H, A>)
181 void init(A &a, H const &init_func, bool parallel = false) const {
182 if (parallel) {
183 // reset input array to allow for mpi reduction
184 a() = 0.0;
185
186#pragma omp parallel
187 for (auto const &sym_class : itertools::omp_chunk(mpi::chunk(sym_classes))) {
188 auto idx = a.indexmap().to_idx(sym_class[0].first);
189 auto ref_val = init_func(idx);
190 std::apply(a, idx) = ref_val;
191 for (auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
192 }
193
194 // distribute data among all ranks
195 a = mpi::all_reduce(a);
196 } else {
197 for (auto const &sym_class : sym_classes) {
198 auto idx = a.indexmap().to_idx(sym_class[0].first);
199 auto ref_val = init_func(idx);
200 std::apply(a, idx) = ref_val;
201 for (auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
202 }
203 }
204 }
205
215 std::pair<double, arr_idx_t> symmetrize(A &a) const {
216 // loop over all symmetry classes
217 double max_diff = 0.0;
218 auto max_idx = arr_idx_t{};
219 for (auto const &sym_class : sym_classes) {
220 // reference value for the symmetry class (arithmetic mean over all its elements)
221 get_value_t<A> ref_val = 0.0;
222 for (auto const &[lin_idx, op] : sym_class) { ref_val += op(std::apply(a, a.indexmap().to_idx(lin_idx))); }
223 ref_val /= sym_class.size();
224
225 // assign the reference value to all elements and calculate the violation
226 for (auto const &[lin_idx, op] : sym_class) {
227 auto mapped_val = op(ref_val);
228 auto mapped_idx = a.indexmap().to_idx(lin_idx);
229 auto current_val = std::apply(a, mapped_idx);
230 auto diff = std::abs(mapped_val - current_val);
231
232 if (diff > max_diff) {
233 max_diff = diff;
234 max_idx = mapped_idx;
235 };
236
237 std::apply(a, mapped_idx) = mapped_val;
238 }
239 }
240
241 return std::pair{max_diff, max_idx};
242 }
243
250 [[nodiscard]] std::vector<get_value_t<A>> get_representative_data(A const &a) const {
251 long len = sym_classes.size();
252 auto vec = std::vector<get_value_t<A>>(len);
253 for (long i : range(len)) vec[i] = std::apply(a, a.indexmap().to_idx(sym_classes[i][0].first));
254 return vec;
255 }
256
263 template <typename V>
264 void init_from_representative_data(A &a, V const &vec) const {
265 static_assert(std::is_same_v<const get_value_t<A> &, decltype(vec[0])>);
266 for (long i : range(vec.size())) {
267 for (auto const &[lin_idx, op] : sym_classes[i]) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(vec[i]); }
268 }
269 }
270
272 sym_grp() = default;
273
284 sym_grp(A const &a, std::vector<F> const &sym_list, long const max_length = 0) {
285 // array to keep track of the indices already in a symmetry class
286 array<bool, ndims> checked(a.shape());
287 checked() = false;
288
289 // initialize the data (we have as many elements as in the original array)
290 data.reserve(a.size());
291
292 // loop over all indices/elements
293 for_each(checked.shape(), [&checked, &sym_list, max_length, this](auto... is) {
294 if (not checked(is...)) {
295 // if the index has not been checked yet, we start a new symmetry class
296 auto class_start = data.end();
297
298 // mark the current index as checked
299 checked(is...) = true;
300
301 // add it to the symmetry class as its representative together with the identity operation
302 operation op;
303 data.emplace_back(checked.indexmap()(is...), op);
304
305 // apply all symmetries to the current index and generate the symmetry class
306 auto idx = std::array{is...};
307 auto class_size = iterate(idx, op, checked, sym_list, max_length) + 1;
308
309 // store the symmetry class
310 sym_classes.emplace_back(class_start, class_size);
311 }
312 });
313 }
314
315 private:
316 // Implementation of the actual symmetry reduction algorithm.
317 long long iterate(std::array<long, static_cast<std::size_t>(get_rank<A>)> const &idx, operation const &op, array<bool, ndims> &checked,
318 std::vector<F> const &sym_list, long const max_length, long excursion_length = 0) {
319 // count the number of new indices found by recursively applying the symmetries to the current index and to the
320 // newly found indices
321 long long segment_length = 0;
322
323 // loop over all symmetries
324 for (auto const &sym : sym_list) {
325 // apply the symmetry to the current index
326 auto [idxp, opp] = sym(idx);
327 opp = opp * op;
328
329 // check if the new index is valid
330 if (is_valid(checked, idxp)) {
331 // if the new index is valid, check if it has been used already
332 if (not std::apply(checked, idxp)) {
333 // if it has not been used, mark it as checked
334 std::apply(checked, idxp) = true;
335
336 // add it to the symmetry class
337 data.emplace_back(std::apply(checked.indexmap(), idxp), opp);
338
339 // increment the segment length for the current index and recursively call the function with the new index
340 // and the excursion length reset to zero
341 segment_length += iterate(idxp, opp, checked, sym_list, max_length) + 1;
342 }
343 } else if (excursion_length < max_length) {
344 // if the index is invalid, recursively call the function with the new index and the excursion length
345 // incremented by one (the segment length is not incremented)
346 segment_length += iterate(idxp, opp, checked, sym_list, max_length, ++excursion_length);
347 }
348 }
349
350 // return the final value of the local segment length, which will be added to the segments length higher up in the
351 // recursive call tree
352 return segment_length;
353 }
354 };
355
357
358} // namespace nda
auto const & shape() const noexcept
Get the shape of the view/array.
std::array< long, static_cast< std::size_t >(ndims)> arr_idx_t
Multi-dimensional index type.
Definition sym_grp.hpp:146
long num_classes() const
Get the number of symmetry classes.
Definition sym_grp.hpp:166
void init(A &a, H const &init_func, bool parallel=false) const
Initialize an nda::Array using an nda::NdaInitFunc.
Definition sym_grp.hpp:181
std::vector< sym_class_t > const & get_sym_classes() const
Get the symmetry classes.
Definition sym_grp.hpp:160
void init_from_representative_data(A &a, V const &vec) const
Initialize a multi-dimensional array from its representative data using symmetries.
Definition sym_grp.hpp:264
std::vector< get_value_t< A > > get_representative_data(A const &a) const
Reduce an nda::Array to its representative data using symmetries.
Definition sym_grp.hpp:250
std::pair< long, operation > sym_idx_t
Element type of a symmetry class.
Definition sym_grp.hpp:140
std::span< sym_idx_t > sym_class_t
Symmetry class type.
Definition sym_grp.hpp:143
std::pair< double, arr_idx_t > symmetrize(A &a) const
Symmetrize an array and return the maximum symmetry violation and its corresponding array index.
Definition sym_grp.hpp:215
sym_grp(A const &a, std::vector< F > const &sym_list, long const max_length=0)
Construct a symmetry group for a given array and a list of its symmetries.
Definition sym_grp.hpp:284
sym_grp()=default
Default constructor for a symmetry group.
static constexpr int ndims
Rank of the input array.
Definition sym_grp.hpp:137
Check if a given type satisfies the array concept.
Definition concepts.hpp:230
Concept defining an initializer function.
Definition sym_grp.hpp:115
Concept defining a symmetry in nda.
Definition sym_grp.hpp:100
decltype(auto) conj(A &&a)
Function conj for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types with a complex value type).
bool is_valid(A const &a, std::array< long, static_cast< std::size_t >(get_rank< A >)> const &idx)
Check if a multi-dimensional index is valid, i.e. not out of bounds, w.r.t. a given nda::Array object.
Definition sym_grp.hpp:77
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:125
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:181
__inline__ void for_each(std::array< Int, R > const &shape, F &&f)
Loop over all possible index values of a given shape and apply a function to them.
Definition for_each.hpp:116
Includes all MPI relevant headers.
Includes all relevant headers for the core nda library.
A structure to capture combinations of complex conjugation and sign flip operations.
Definition sym_grp.hpp:36
bool sgn
Boolean value indicating a sign flip operation.
Definition sym_grp.hpp:38
T operator()(T const &x) const
Function call operator to apply the operation to a value.
Definition sym_grp.hpp:62
operation operator*(operation const &rhs)
Multiplication operator for two operations.
Definition sym_grp.hpp:52
bool cc
Boolean value indicating a complex conjugation operation.
Definition sym_grp.hpp:41