18
19
20
27#include <itertools/omp_chunk.hpp>
40
41
42
45
46
55
56
57
58
59
60
61
62
66
67
68
69
70
71
/**
 * @brief A structure to capture combinations of complex conjugation and sign flip operations.
 *
 * NOTE(review): only the body of operator() survived in this listing. The member list
 * (sgn, cc, operator*, operator()) and the signatures are taken from the member
 * summaries; the default values and the body of operator* are reconstructed and should
 * be confirmed against the authoritative header.
 */
struct operation {
  /// Boolean value indicating a sign flip operation.
  bool sgn = false;

  /// Boolean value indicating a complex conjugation operation.
  bool cc = false;

  /**
   * @brief Multiplication operator for two operations.
   *
   * @param rhs Right hand side operation.
   * @return Operation whose flags are set where exactly one of the two operands has the flag set
   * (applying a sign flip or a conjugation twice cancels).
   */
  operation operator*(operation const &rhs) { return operation{sgn != rhs.sgn, cc != rhs.cc}; }

  /**
   * @brief Function call operator to apply the operation to a value.
   *
   * @tparam T Value type.
   * @param x Value to which the operation is applied.
   * @return `x`, conjugated if `cc` is set and negated if `sgn` is set.
   */
  template <typename T> T operator()(T const &x) const {
    if (sgn) { return cc ? -conj(x) : -x; }
    return cc ? conj(x) : x;
  }
};
80
81
82
83
84
85
86
88 bool is_valid(A
const &a, std::array<
long,
static_cast<std::size_t>(get_rank<A>)>
const &idx) {
89 for (
auto i = 0; i < get_rank<A>; ++i) {
90 if (
not(0 <= idx[i]
and idx[i] < a.shape()[i])) {
return false; }
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110 template <
typename F,
typename A,
typename Idx = std::array<
long,
static_cast<std::size_t>(get_rank<A>)>>
111 concept NdaSymmetry = Array<A>
and requires(F f, Idx
const &idx) {
112 { f(idx) } -> std::same_as<std::tuple<Idx,
operation>>;
116
117
118
119
120
121
122
123
124
125 template <
typename F,
typename A,
typename Idx = std::array<
long,
static_cast<std::size_t>(get_rank<A>)>>
126 concept NdaInitFunc = Array<A>
and requires(F f, Idx
const &idx) {
127 { f(idx) } -> std::same_as<get_value_t<A>>;
131
132
133
134
135
136
137
138
139
140
141
142
143 template <
typename F,
typename A>
144 requires(Array<A> && NdaSymmetry<F, A>)
148 static constexpr int ndims = get_rank<A>;
154 using sym_class_t = std::span<sym_idx_t>;
157 using arr_idx_t = std::array<
long,
static_cast<std::size_t>(
ndims)>;
161 std::vector<sym_class_t> sym_classes;
164 std::vector<sym_idx_t> data;
168
169
170
174
175
176
180
181
182
183
184
185
186
187
188
189
190 template <
typename H>
191 requires(NdaInitFunc<H, A>)
192 void init(A &a, H
const &init_func,
bool parallel =
false)
const {
198 for (
auto const &sym_class : itertools::omp_chunk(mpi::chunk(sym_classes))) {
199 auto idx = a.indexmap().to_idx(sym_class[0].first);
200 auto ref_val = init_func(idx);
201 std::apply(a, idx) = ref_val;
202 for (
auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
206 a = mpi::all_reduce(a);
208 for (
auto const &sym_class : sym_classes) {
209 auto idx = a.indexmap().to_idx(sym_class[0].first);
210 auto ref_val = init_func(idx);
211 std::apply(a, idx) = ref_val;
212 for (
auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
218
219
220
221
222
223
224
225
228 double max_diff = 0.0;
229 auto max_idx = arr_idx_t{};
230 for (
auto const &sym_class : sym_classes) {
232 get_value_t<A> ref_val = 0.0;
233 for (
auto const &[lin_idx, op] : sym_class) { ref_val += op(std::apply(a, a.indexmap().to_idx(lin_idx))); }
234 ref_val /= sym_class.size();
237 for (
auto const &[lin_idx, op] : sym_class) {
238 auto mapped_val = op(ref_val);
239 auto mapped_idx = a.indexmap().to_idx(lin_idx);
240 auto current_val = std::apply(a, mapped_idx);
241 auto diff = std::abs(mapped_val - current_val);
243 if (diff > max_diff) {
245 max_idx = mapped_idx;
248 std::apply(a, mapped_idx) = mapped_val;
252 return std::pair{max_diff, max_idx};
256
257
258
259
260
262 long len = sym_classes.size();
263 auto vec = std::vector<get_value_t<A>>(len);
264 for (
long i : range(len)) vec[i] = std::apply(a, a.indexmap().to_idx(sym_classes[i][0].first));
269
270
271
272
273
274 template <
typename V>
276 static_assert(std::is_same_v<
const get_value_t<A> &,
decltype(vec[0])>);
277 for (
long i : range(vec.size())) {
278 for (
auto const &[lin_idx, op] : sym_classes[i]) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(vec[i]); }
286
287
288
289
290
291
292
293
294
295 sym_grp(A
const &a, std::vector<F>
const &sym_list,
long const max_length = 0) {
297 array<
bool,
ndims> checked(a.shape());
301 data.reserve(a.size());
304 for_each(checked.shape(), [&checked, &sym_list, max_length,
this](
auto... is) {
305 if (
not checked(is...)) {
307 auto class_start = data.end();
310 checked(is...) =
true;
314 data.emplace_back(checked.indexmap()(is...), op);
317 auto idx = std::array{is...};
318 auto class_size = iterate(idx, op, checked, sym_list, max_length) + 1;
321 sym_classes.emplace_back(class_start, class_size);
328 long long iterate(std::array<
long,
static_cast<std::size_t>(get_rank<A>)>
const &idx,
operation const &op, array<
bool,
ndims> &checked,
329 std::vector<F>
const &sym_list,
long const max_length,
long excursion_length = 0) {
332 long long segment_length = 0;
335 for (
auto const &sym : sym_list) {
337 auto [idxp, opp] = sym(idx);
341 if (is_valid(checked, idxp)) {
343 if (
not std::apply(checked, idxp)) {
345 std::apply(checked, idxp) =
true;
348 data.emplace_back(std::apply(checked.indexmap(), idxp), opp);
352 segment_length += iterate(idxp, opp, checked, sym_list, max_length) + 1;
354 }
else if (excursion_length < max_length) {
357 segment_length += iterate(idxp, opp, checked, sym_list, max_length, ++excursion_length);
363 return segment_length;
Class representing a symmetry group.
long num_classes() const
Get the number of symmetry classes.
void init(A &a, H const &init_func, bool parallel=false) const
Initialize an nda::Array using an nda::NdaInitFunc.
std::vector< sym_class_t > const & get_sym_classes() const
Get the symmetry classes.
void init_from_representative_data(A &a, V const &vec) const
Initialize a multi-dimensional array from its representative data using symmetries.
std::vector< get_value_t< A > > get_representative_data(A const &a) const
Reduce an nda::Array to its representative data using symmetries.
std::pair< double, arr_idx_t > symmetrize(A &a) const
Symmetrize an array and return the maximum symmetry violation and its corresponding array index.
sym_grp(A const &a, std::vector< F > const &sym_list, long const max_length=0)
Construct a symmetry group for a given array and a list of its symmetries.
sym_grp()=default
Default constructor for a symmetry group.
static constexpr int ndims
Rank of the input array.
bool is_valid(A const &a, std::array< long, static_cast< std::size_t >(get_rank< A >)> const &idx)
Check if a multi-dimensional index is valid, i.e. not out of bounds, w.r.t. a given nda::Array object.
A structure to capture combinations of complex conjugation and sign flip operations.
bool sgn
Boolean value indicating a sign flip operation.
T operator()(T const &x) const
Function call operator to apply the operation to a value.
operation operator*(operation const &rhs)
Multiplication operator for two operations.
bool cc
Boolean value indicating a complex conjugation operation.