TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
sym_grp.hpp
Go to the documentation of this file.
1// Copyright (c) 2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Dominik Kiese
16
22#pragma once
23
#include "./nda.hpp"
#include "./mpi.hpp"

#include <itertools/omp_chunk.hpp>
#include <mpi/mpi.hpp>

#include <array>
#include <concepts>
#include <cstddef>
#include <span>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
36
37namespace nda {
38
/**
 * @brief A structure to capture combinations of complex conjugation and sign flip operations.
 *
 * @details An operation maps a value `x` to one of `x`, `-x`, `conj(x)` or `-conj(x)`. Sign flip and complex
 * conjugation commute and each is its own inverse, so two operations compose by XOR-ing their flags.
 */
struct operation {
  /// Boolean value indicating a sign flip operation.
  bool sgn = false;

  /// Boolean value indicating a complex conjugation operation.
  bool cc = false;

  /**
   * @brief Multiplication operator for two operations.
   *
   * @details Returns the composition of both operations. This is a const member so that const left-hand operands
   * (e.g. an `operation const &`) can be composed as well.
   *
   * @param rhs Right hand side operation.
   * @return The composed operation (flags XOR-ed).
   */
  [[nodiscard]] operation operator*(operation const &rhs) const noexcept { return operation{sgn != rhs.sgn, cc != rhs.cc}; }

  /**
   * @brief Function call operator to apply the operation to a value.
   *
   * @tparam T Value type (must support unary minus and an unqualified `conj` call).
   * @param x Value to which the operation is applied.
   * @return `x`, `-x`, `conj(x)` or `-conj(x)` depending on the flags.
   */
  template <typename T>
  T operator()(T const &x) const {
    if (sgn) return cc ? -conj(x) : -x;
    return cc ? conj(x) : x;
  }
};
78
87 template <Array A>
88 bool is_valid(A const &a, std::array<long, static_cast<std::size_t>(get_rank<A>)> const &idx) {
89 for (auto i = 0; i < get_rank<A>; ++i) {
90 if (not(0 <= idx[i] and idx[i] < a.shape()[i])) { return false; }
91 }
92 return true;
93 }
94
110 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(get_rank<A>)>>
111 concept NdaSymmetry = Array<A> and requires(F f, Idx const &idx) {
112 { f(idx) } -> std::same_as<std::tuple<Idx, operation>>;
113 };
114
125 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(get_rank<A>)>>
126 concept NdaInitFunc = Array<A> and requires(F f, Idx const &idx) {
127 { f(idx) } -> std::same_as<get_value_t<A>>;
128 };
129
143 template <typename F, typename A>
144 requires(Array<A> && NdaSymmetry<F, A>)
145 class sym_grp {
146 public:
148 static constexpr int ndims = get_rank<A>;
149
151 using sym_idx_t = std::pair<long, operation>;
152
154 using sym_class_t = std::span<sym_idx_t>;
155
157 using arr_idx_t = std::array<long, static_cast<std::size_t>(ndims)>;
158
159 private:
160 // std::vector containing the different symmetry classes.
161 std::vector<sym_class_t> sym_classes;
162
163 // std::vector of the size of the input array to store the elements of the symmetry classes.
164 std::vector<sym_idx_t> data;
165
166 public:
171 [[nodiscard]] std::vector<sym_class_t> const &get_sym_classes() const { return sym_classes; }
172
177 [[nodiscard]] long num_classes() const { return sym_classes.size(); }
178
190 template <typename H>
191 requires(NdaInitFunc<H, A>)
192 void init(A &a, H const &init_func, bool parallel = false) const {
193 if (parallel) {
194 // reset input array to allow for mpi reduction
195 a() = 0.0;
196
197#pragma omp parallel
198 for (auto const &sym_class : itertools::omp_chunk(mpi::chunk(sym_classes))) {
199 auto idx = a.indexmap().to_idx(sym_class[0].first);
200 auto ref_val = init_func(idx);
201 std::apply(a, idx) = ref_val;
202 for (auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
203 }
204
205 // distribute data among all ranks
206 a = mpi::all_reduce(a);
207 } else {
208 for (auto const &sym_class : sym_classes) {
209 auto idx = a.indexmap().to_idx(sym_class[0].first);
210 auto ref_val = init_func(idx);
211 std::apply(a, idx) = ref_val;
212 for (auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
213 }
214 }
215 }
216
226 std::pair<double, arr_idx_t> symmetrize(A &a) const {
227 // loop over all symmetry classes
228 double max_diff = 0.0;
229 auto max_idx = arr_idx_t{};
230 for (auto const &sym_class : sym_classes) {
231 // reference value for the symmetry class (arithmetic mean over all its elements)
232 get_value_t<A> ref_val = 0.0;
233 for (auto const &[lin_idx, op] : sym_class) { ref_val += op(std::apply(a, a.indexmap().to_idx(lin_idx))); }
234 ref_val /= sym_class.size();
235
236 // assign the reference value to all elements and calculate the violation
237 for (auto const &[lin_idx, op] : sym_class) {
238 auto mapped_val = op(ref_val);
239 auto mapped_idx = a.indexmap().to_idx(lin_idx);
240 auto current_val = std::apply(a, mapped_idx);
241 auto diff = std::abs(mapped_val - current_val);
242
243 if (diff > max_diff) {
244 max_diff = diff;
245 max_idx = mapped_idx;
246 };
247
248 std::apply(a, mapped_idx) = mapped_val;
249 }
250 }
251
252 return std::pair{max_diff, max_idx};
253 }
254
261 [[nodiscard]] std::vector<get_value_t<A>> get_representative_data(A const &a) const {
262 long len = sym_classes.size();
263 auto vec = std::vector<get_value_t<A>>(len);
264 for (long i : range(len)) vec[i] = std::apply(a, a.indexmap().to_idx(sym_classes[i][0].first));
265 return vec;
266 }
267
274 template <typename V>
275 void init_from_representative_data(A &a, V const &vec) const {
276 static_assert(std::is_same_v<const get_value_t<A> &, decltype(vec[0])>);
277 for (long i : range(vec.size())) {
278 for (auto const &[lin_idx, op] : sym_classes[i]) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(vec[i]); }
279 }
280 }
281
283 sym_grp() = default;
284
295 sym_grp(A const &a, std::vector<F> const &sym_list, long const max_length = 0) {
296 // array to keep track of the indices already in a symmetry class
297 array<bool, ndims> checked(a.shape());
298 checked() = false;
299
300 // initialize the data (we have as many elements as in the original array)
301 data.reserve(a.size());
302
303 // loop over all indices/elements
304 for_each(checked.shape(), [&checked, &sym_list, max_length, this](auto... is) {
305 if (not checked(is...)) {
306 // if the index has not been checked yet, we start a new symmetry class
307 auto class_start = data.end();
308
309 // mark the current index as checked
310 checked(is...) = true;
311
312 // add it to the symmetry class as its representative together with the identity operation
313 operation op;
314 data.emplace_back(checked.indexmap()(is...), op);
315
316 // apply all symmetries to the current index and generate the symmetry class
317 auto idx = std::array{is...};
318 auto class_size = iterate(idx, op, checked, sym_list, max_length) + 1;
319
320 // store the symmetry class
321 sym_classes.emplace_back(class_start, class_size);
322 }
323 });
324 }
325
326 private:
327 // Implementation of the actual symmetry reduction algorithm.
328 long long iterate(std::array<long, static_cast<std::size_t>(get_rank<A>)> const &idx, operation const &op, array<bool, ndims> &checked,
329 std::vector<F> const &sym_list, long const max_length, long excursion_length = 0) {
330 // count the number of new indices found by recursively applying the symmetries to the current index and to the
331 // newly found indices
332 long long segment_length = 0;
333
334 // loop over all symmetries
335 for (auto const &sym : sym_list) {
336 // apply the symmetry to the current index
337 auto [idxp, opp] = sym(idx);
338 opp = opp * op;
339
340 // check if the new index is valid
341 if (is_valid(checked, idxp)) {
342 // if the new index is valid, check if it has been used already
343 if (not std::apply(checked, idxp)) {
344 // if it has not been used, mark it as checked
345 std::apply(checked, idxp) = true;
346
347 // add it to the symmetry class
348 data.emplace_back(std::apply(checked.indexmap(), idxp), opp);
349
350 // increment the segment length for the current index and recursively call the function with the new index
351 // and the excursion length reset to zero
352 segment_length += iterate(idxp, opp, checked, sym_list, max_length) + 1;
353 }
354 } else if (excursion_length < max_length) {
355 // if the index is invalid, recursively call the function with the new index and the excursion length
356 // incremented by one (the segment length is not incremented)
357 segment_length += iterate(idxp, opp, checked, sym_list, max_length, ++excursion_length);
358 }
359 }
360
361 // return the final value of the local segment length, which will be added to the segments length higher up in the
362 // recursive call tree
363 return segment_length;
364 }
365 };
366
369} // namespace nda
A generic multi-dimensional array.
auto const & shape() const noexcept
Get the shape of the view/array.
Class representing a symmetry group.
Definition sym_grp.hpp:145
std::array< long, static_cast< std::size_t >(ndims)> arr_idx_t
Multi-dimensional index type.
Definition sym_grp.hpp:157
long num_classes() const
Get the number of symmetry classes.
Definition sym_grp.hpp:177
void init(A &a, H const &init_func, bool parallel=false) const
Initialize an nda::Array using an nda::NdaInitFunc.
Definition sym_grp.hpp:192
std::vector< sym_class_t > const & get_sym_classes() const
Get the symmetry classes.
Definition sym_grp.hpp:171
void init_from_representative_data(A &a, V const &vec) const
Initialize a multi-dimensional array from its representative data using symmetries.
Definition sym_grp.hpp:275
std::vector< get_value_t< A > > get_representative_data(A const &a) const
Reduce an nda::Array to its representative data using symmetries.
Definition sym_grp.hpp:261
std::pair< long, operation > sym_idx_t
Element type of a symmetry class.
Definition sym_grp.hpp:151
std::span< sym_idx_t > sym_class_t
Symmetry class type.
Definition sym_grp.hpp:154
std::pair< double, arr_idx_t > symmetrize(A &a) const
Symmetrize an array and return the maximum symmetry violation and its corresponding array index.
Definition sym_grp.hpp:226
sym_grp(A const &a, std::vector< F > const &sym_list, long const max_length=0)
Construct a symmetry group for a given array and a list of its symmetries.
Definition sym_grp.hpp:295
sym_grp()=default
Default constructor for a symmetry group.
Check if a given type satisfies the array concept.
Definition concepts.hpp:230
Concept defining an initializer function.
Definition sym_grp.hpp:126
Concept defining a symmetry in nda.
Definition sym_grp.hpp:111
auto conj(T t)
Get the complex conjugate of a scalar.
bool is_valid(A const &a, std::array< long, static_cast< std::size_t >(get_rank< A >)> const &idx)
Check if a multi-dimensional index is valid, i.e. not out of bounds, w.r.t. a given nda::Array object.
Definition sym_grp.hpp:88
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:136
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:192
__inline__ void for_each(std::array< Int, R > const &shape, F &&f)
Loop over all possible index values of a given shape and apply a function to them.
Definition for_each.hpp:129
Includes all MPI relevant headers.
Includes all relevant headers for the core nda library.
A structure to capture combinations of complex conjugation and sign flip operations.
Definition sym_grp.hpp:47
bool sgn
Boolean value indicating a sign flip operation.
Definition sym_grp.hpp:49
T operator()(T const &x) const
Function call operator to apply the operation to a value.
Definition sym_grp.hpp:73
operation operator*(operation const &rhs)
Multiplication operator for two operations.
Definition sym_grp.hpp:63
bool cc
Boolean value indicating a complex conjugation operation.
Definition sym_grp.hpp:52