TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
algorithms.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Olivier Parcollet, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides various algorithms to be used with nda::Array objects.
20 */
21
22#pragma once
23
24#include "./concepts.hpp"
25#include "./layout/for_each.hpp"
26#include "./traits.hpp"
27
28#include <algorithm>
29#include <cmath>
30#include <cstdlib>
31#include <functional>
32#include <type_traits>
33#include <utility>
34
35namespace nda {
36
37 /**
38 * @addtogroup av_algs
39 * @{
40 */
41
42 // FIXME : CHECK ORDER of the LOOP !
43 /**
44 * @brief Perform a fold operation on the given nda::Array object.
45 *
46 * @details It calculates the following (where r is an initial value);
47 *
48 * @code{.cpp}
49 * auto res = f(...f(f(f(r, a(0,...,0)), a(0,...,1)), a(0,...,2)), ...);
50 * @endcode
51 *
52 * @note The array is always traversed in C-order.
53 *
54 * @tparam A nda::Array type.
55 * @tparam F Callable type.
56 * @tparam R Type of the initial value.
57 * @param f Callable object taking two arguments compatible with the initial value and the array value type.
58 * @param a nda::Array object.
59 * @param r Initial value.
60 * @return Result of the fold operation.
61 */
62 template <Array A, typename F, typename R>
63 auto fold(F f, A const &a, R r) {
64 // cast the initial value to the return type of f to avoid narrowing
65 decltype(f(r, get_value_t<A>{})) r2 = r;
66 nda::for_each(a.shape(), [&a, &r2, &f](auto &&...args) { r2 = f(r2, a(args...)); });
67 return r2;
68 }
69
70 /// The same as nda::fold, except that the initial value is a default constructed value type of the array.
71 template <Array A, typename F>
72 auto fold(F f, A const &a) {
73 return fold(std::move(f), a, get_value_t<A>{});
74 }
75
76 /**
77 * @brief Does any of the elements of the array evaluate to true?
78 *
79 * @details The given nda::Array object can also be some lazy expression that evaluates to a boolean. For example:
80 *
81 * @code{.cpp}
82 * auto A = nda::array<double, 2>::rand(2, 3);
83 * auto greater05 = nda::map([](auto x) { return x > 0.5; })(A);
84 * auto res = nda::any(greater05);
85 * @endcode
86 *
87 * @tparam A nda::Array type.
88 * @param a nda::Array object.
89 * @return True if at least one element of the array evaluates to true, false otherwise.
90 */
91 template <Array A>
92 bool any(A const &a) {
93 static_assert(std::is_same_v<get_value_t<A>, bool>, "Error in nda::any: Value type of the array must be bool");
94 return fold([](bool r, auto const &x) -> bool { return r or bool(x); }, a, false);
95 }
96
97 /**
98 * @brief Do all elements of the array evaluate to true?
99 *
100 * @details The given nda::Array object can also be some lazy expression that evaluates to a boolean. For example:
101 *
102 * @code{.cpp}
103 * auto A = nda::array<double, 2>::rand(2, 3);
104 * auto greater0 = nda::map([](auto x) { return x > 0.0; })(A);
105 * auto res = nda::all(greater0);
106 * @endcode
107 *
108 * @tparam A nda::Array type.
109 * @param a nda::Array object.
110 * @return True if all elements of the array evaluate to true, false otherwise.
111 */
112 template <Array A>
113 bool all(A const &a) {
114 static_assert(std::is_same_v<get_value_t<A>, bool>, "Error in nda::all: Value type of the array must be bool");
115 return fold([](bool r, auto const &x) -> bool { return r and bool(x); }, a, true);
116 }
117
118 /**
119 * @brief Find the maximum element of an array.
120 *
121 * @details It uses nda::fold and `std::max`.
122 *
123 * @tparam A nda::Array type.
124 * @param a nda::Array object.
125 * @return Maximum element of the array.
126 */
127 template <Array A>
128 auto max_element(A const &a) {
129 return fold(
130 [](auto const &x, auto const &y) {
131 using std::max;
132 return max(x, y);
133 },
134 a, get_first_element(a));
135 }
136
137 /**
138 * @brief Find the minimum element of an array.
139 *
140 * @details It uses nda::fold and `std::min`.
141 *
142 * @tparam A nda::Array type.
143 * @param a nda::Array object.
144 * @return Minimum element of the array.
145 */
146 template <Array A>
147 auto min_element(A const &a) {
148 return fold(
149 [](auto const &x, auto const &y) {
150 using std::min;
151 return min(x, y);
152 },
153 a, get_first_element(a));
154 }
155
156 /**
157 * @ingroup av_math
158 * @brief Calculate the Frobenius norm of a 2-dimensional array.
159 *
160 * @tparam A nda::ArrayOfRank<2> type.
161 * @param a Array object.
162 * @return Frobenius norm of the array/matrix.
163 */
164 template <ArrayOfRank<2> A>
165 double frobenius_norm(A const &a) {
166 return std::sqrt(fold(
167 [](double r, auto const &x) -> double {
168 auto ab = std::abs(x);
169 return r + ab * ab;
170 },
171 a, double(0)));
172 }
173
174 /**
175 * @brief Sum all the elements of an nda::Array object.
176 *
177 * @tparam A nda::Array type.
178 * @param a nda::Array object.
179 * @return Sum of all elements.
180 */
181 template <Array A>
182 auto sum(A const &a)
183 requires(nda::is_scalar_v<get_value_t<A>>)
184 {
185 return fold(std::plus<>{}, a);
186 }
187
188 /**
189 * @brief Multiply all the elements of an nda::Array object.
190 *
191 * @tparam A nda::Array type.
192 * @param a nda::Array object.
193 * @return Product of all elements.
194 */
195 template <Array A>
196 auto product(A const &a)
197 requires(nda::is_scalar_v<get_value_t<A>>)
198 {
199 return fold(std::multiplies<>{}, a, get_value_t<A>{1});
200 }
201
202 /** @} */
203
204} // namespace nda
auto max_element(A const &a)
Find the maximum element of an array.
auto fold(F f, A const &a, R r)
Perform a fold operation on the given nda::Array object.
auto product(A const &a)
Multiply all the elements of an nda::Array object.
auto fold(F f, A const &a)
The same as nda::fold, except that the initial value is a default-constructed object of the array's value type.
bool any(A const &a)
Does any of the elements of the array evaluate to true?
auto min_element(A const &a)
Find the minimum element of an array.
auto sum(A const &a)
Sum all the elements of an nda::Array object.
bool all(A const &a)
Do all elements of the array evaluate to true?
double frobenius_norm(A const &a)
Calculate the Frobenius norm of a 2-dimensional array.