27#include <itertools/omp_chunk.hpp>
89 for (
auto i = 0; i < get_rank<A>; ++i) {
90 if (not(0 <= idx[i] and idx[i] < a.shape()[i])) {
return false; }
110 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(
get_rank<A>)>>
112 { f(idx) } -> std::same_as<std::tuple<Idx, operation>>;
125 template <typename F, typename A, typename Idx = std::array<long, static_cast<std::size_t>(
get_rank<A>)>>
127 { f(idx) } -> std::same_as<get_value_t<A>>;
143 template <
typename F,
typename A>
157 using arr_idx_t = std::array<long, static_cast<std::size_t>(ndims)>;
161 std::vector<sym_class_t> sym_classes;
164 std::vector<sym_idx_t> data;
171 [[nodiscard]] std::vector<sym_class_t>
const &
get_sym_classes()
const {
return sym_classes; }
177 [[nodiscard]]
long num_classes()
const {
return sym_classes.size(); }
190 template <
typename H>
192 void init(A &a, H
const &init_func,
bool parallel =
false)
const {
198 for (
auto const &sym_class : itertools::omp_chunk(mpi::chunk(sym_classes))) {
199 auto idx = a.indexmap().to_idx(sym_class[0].first);
200 auto ref_val = init_func(idx);
201 std::apply(a, idx) = ref_val;
202 for (
auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
206 a = mpi::all_reduce(a);
208 for (
auto const &sym_class : sym_classes) {
209 auto idx = a.indexmap().to_idx(sym_class[0].first);
210 auto ref_val = init_func(idx);
211 std::apply(a, idx) = ref_val;
212 for (
auto const &[lin_idx, op] : sym_class) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(ref_val); }
228 double max_diff = 0.0;
230 for (
auto const &sym_class : sym_classes) {
233 for (
auto const &[lin_idx, op] : sym_class) { ref_val += op(std::apply(a, a.indexmap().to_idx(lin_idx))); }
234 ref_val /= sym_class.size();
237 for (
auto const &[lin_idx, op] : sym_class) {
238 auto mapped_val = op(ref_val);
239 auto mapped_idx = a.indexmap().to_idx(lin_idx);
240 auto current_val = std::apply(a, mapped_idx);
241 auto diff = std::abs(mapped_val - current_val);
243 if (diff > max_diff) {
245 max_idx = mapped_idx;
248 std::apply(a, mapped_idx) = mapped_val;
252 return std::pair{max_diff, max_idx};
262 long len = sym_classes.size();
263 auto vec = std::vector<get_value_t<A>>(len);
264 for (
long i : range(len)) vec[i] = std::apply(a, a.indexmap().to_idx(sym_classes[i][0].first));
274 template <
typename V>
276 static_assert(std::is_same_v<const get_value_t<A> &,
decltype(vec[0])>);
277 for (
long i : range(vec.size())) {
278 for (
auto const &[lin_idx, op] : sym_classes[i]) { std::apply(a, a.indexmap().to_idx(lin_idx)) = op(vec[i]); }
295 sym_grp(A
const &a, std::vector<F>
const &sym_list,
long const max_length = 0) {
301 data.reserve(a.size());
304 for_each(checked.
shape(), [&checked, &sym_list, max_length,
this](
auto... is) {
305 if (not checked(is...)) {
307 auto class_start = data.end();
310 checked(is...) = true;
314 data.emplace_back(checked.indexmap()(is...), op);
317 auto idx = std::array{is...};
318 auto class_size = iterate(idx, op, checked, sym_list, max_length) + 1;
321 sym_classes.emplace_back(class_start, class_size);
329 std::vector<F>
const &sym_list,
long const max_length,
long excursion_length = 0) {
332 long long segment_length = 0;
335 for (
auto const &sym : sym_list) {
337 auto [idxp, opp] = sym(idx);
343 if (not std::apply(checked, idxp)) {
345 std::apply(checked, idxp) =
true;
348 data.emplace_back(std::apply(checked.indexmap(), idxp), opp);
352 segment_length += iterate(idxp, opp, checked, sym_list, max_length) + 1;
354 }
else if (excursion_length < max_length) {
357 segment_length += iterate(idxp, opp, checked, sym_list, max_length, ++excursion_length);
363 return segment_length;
A generic multi-dimensional array.
auto const & shape() const noexcept
Get the shape of the view/array.
Class representing a symmetry group.
std::array< long, static_cast< std::size_t >(ndims)> arr_idx_t
Multi-dimensional index type.
long num_classes() const
Get the number of symmetry classes.
void init(A &a, H const &init_func, bool parallel=false) const
Initialize an nda::Array using an nda::NdaInitFunc.
std::vector< sym_class_t > const & get_sym_classes() const
Get the symmetry classes.
void init_from_representative_data(A &a, V const &vec) const
Initialize a multi-dimensional array from its representative data using symmetries.
std::vector< get_value_t< A > > get_representative_data(A const &a) const
Reduce an nda::Array to its representative data using symmetries.
std::pair< long, operation > sym_idx_t
Element type of a symmetry class.
std::span< sym_idx_t > sym_class_t
Symmetry class type.
std::pair< double, arr_idx_t > symmetrize(A &a) const
Symmetrize an array and return the maximum symmetry violation and its corresponding array index.
sym_grp(A const &a, std::vector< F > const &sym_list, long const max_length=0)
Construct a symmetry group for a given array and a list of its symmetries.
sym_grp()=default
Default constructor for a symmetry group.
Check if a given type satisfies the array concept.
Concept defining an initializer function.
Concept defining a symmetry in nda.
auto conj(T t)
Get the complex conjugate of a scalar.
bool is_valid(A const &a, std::array< long, static_cast< std::size_t >(get_rank< A >)> const &idx)
Check if a multi-dimensional index is valid, i.e. not out of bounds, with respect to a given nda::Array object.
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
__inline__ void for_each(std::array< Int, R > const &shape, F &&f)
Loop over all possible index values of a given shape and apply a function to them.
Includes all MPI-relevant headers.
Includes all relevant headers for the core nda library.
A structure to capture combinations of complex conjugation and sign flip operations.
bool sgn
Boolean value indicating a sign flip operation.
T operator()(T const &x) const
Function call operator to apply the operation to a value.
operation operator*(operation const &rhs)
Multiplication operator for two operations.
bool cc
Boolean value indicating a complex conjugation operation.