TRIQS/TRIQS 4.0.0
Researching Interacting Quantum Systems
Loading...
Searching...
No Matches
mean_error.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2021 Simons Foundation
2//
3// This program is free software: you can redistribute it and/or modify
4// it under the terms of the GNU General Public License as published by
5// the Free Software Foundation, either version 3 of the License, or
6// (at your option) any later version.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You may obtain a copy of the License at
14// https://www.gnu.org/licenses/gpl-3.0.txt
15//
16// Authors: Philipp Dumitrescu, Olivier Parcollet, Nils Wentzell
17
22
23#pragma once
24
25#include "./concepts.hpp"
26#include "./utils.hpp"
27
28#include <itertools/itertools.hpp>
29#include <mpi/mpi.hpp>
30#include <nda/mpi.hpp>
31#include <nda/nda.hpp>
32
33#include <numeric>
34#include <optional>
35#include <ranges>
36#include <utility>
37
38namespace triqs::stat {
39
44
/// @brief Tag to indicate what to calculate when computing the mean of a range of values.
enum class mean_tag {
  sum, ///< The simple sum of all samples (i.e. the arithmetic mean multiplied by the number of samples).
  mean ///< The arithmetic mean of all samples.
};
54
/**
 * @brief Tag to indicate what to calculate when computing the error of a range of values.
 *
 * @details With \f$ S = \sum_i |x_i - \bar{x}|^2 \f$ the sum of squared deviations from the mean and \f$ N \f$ the
 * number of samples, the tags select (see apply_error_tag):
 */
enum class error_tag {
  sum,      ///< \f$ S \f$ itself, unnormalized.
  var_data, ///< \f$ S / (N - 1) \f$.
  var_mean, ///< \f$ S / (N (N - 1)) \f$.
  err_data, ///< \f$ \sqrt{S / (N - 1)} \f$.
  err_mean, ///< \f$ \sqrt{S / (N (N - 1))} \f$.
  jk_err    ///< \f$ \sqrt{S (N - 1) / N} \f$.
};
68
80 template <mean_tag mtag, AccCompatible T> void apply_mean_tag(T &m, [[maybe_unused]] long nsamples) {
81 if constexpr (mtag == mean_tag::sum) m *= nsamples;
82 }
83
96 template <error_tag etag, AccCompatible T> void apply_error_tag(T &sum_sq_devs, [[maybe_unused]] long nsamples) {
97 if constexpr (etag == error_tag::sum) return;
98 auto const nd = static_cast<double>(nsamples);
99 if constexpr (etag == error_tag::err_data || etag == error_tag::var_data)
100 sum_sq_devs /= (nd - 1);
101 else if constexpr (etag == error_tag::err_mean || etag == error_tag::var_mean)
102 sum_sq_devs /= (nd * (nd - 1));
103 else if constexpr (etag == error_tag::jk_err)
104 sum_sq_devs *= (nd - 1) / nd;
105 if constexpr (etag == error_tag::err_data || etag == error_tag::err_mean || etag == error_tag::jk_err) sum_sq_devs = nda::sqrt(sum_sq_devs);
106 }
107
120 template <mean_tag mtag = mean_tag::mean, StatCompatibleRange R> auto mean(R &&rg) { // NOLINT (ranges should not be forwarded)
121 if constexpr (mtag == mean_tag::mean) {
122 // calculate the arithmetic mean
123 auto res = zeroed_sample(*std::ranges::begin(rg));
124 for (auto const &[n, x] : itertools::enumerate(rg)) res += (x - res) / (n + 1);
125 return res;
126 } else {
127 // calculate the simple sum
128 return std::accumulate(std::ranges::begin(rg), std::ranges::end(rg), zeroed_sample(*std::ranges::begin(rg)));
129 }
130 }
131
160 template <mean_tag mtag = mean_tag::mean, StatCompatibleRange R>
161 auto mean_mpi(std::optional<mpi::communicator> c, R &&rg) { // NOLINT (ranges should not be forwarded)
163
164 // early return if no communicator is provided
165 if (!c) return mean<mtag>(rg);
166
167 value_t res = mean<mtag>(rg);
168 if constexpr (mtag == mean_tag::mean) {
169 // for mtag == mean_tag::mean, we need to take care of different sample sizes
170 auto const n_i = std::ranges::size(rg);
171 auto const n = mpi::all_reduce(n_i, *c);
172 res *= static_cast<double>(n_i) / static_cast<double>(n);
173 }
174 mpi::all_reduce_in_place(res, *c);
175 return res;
176 }
177
195 template <error_tag etag = error_tag::err_mean, mean_tag mtag = mean_tag::mean, StatCompatibleRange R>
196 auto mean_and_err(R &&rg) { // NOLINT (ranges should not be forwarded)
197 // calculate the arithmetic mean and the sum of squared deviations from the mean
198 auto res_m = zeroed_sample(*std::ranges::begin(rg));
199 auto res_s = make_real(res_m);
200 for (auto const &[n, x] : itertools::enumerate(rg)) {
201 auto const nd = static_cast<double>(n);
202 res_s += abs_square(x - res_m) * nd / (nd + 1);
203 res_m += (x - res_m) / (nd + 1);
204 }
205 // apply the mean and error tags
206 apply_mean_tag<mtag>(res_m, std::ranges::size(rg));
207 apply_error_tag<etag>(res_s, std::ranges::size(rg));
208 return std::make_pair(res_m, res_s);
209 }
210
235 template <error_tag etag = error_tag::err_mean, mean_tag mtag = mean_tag::mean, StatCompatibleRange R>
236 auto mean_and_err_mpi(std::optional<mpi::communicator> c, R &&rg) { // NOLINT (ranges should not be forwarded)
238
239 // early return if no communicator is provided
240 if (!c) return mean_and_err<etag, mtag>(rg);
241
242 // local mean and sum of squared deviations from the mean
243 auto [m, ssqdev] = mean_and_err<error_tag::sum>(rg);
244
245 // reduce the sample size
246 auto const n = std::ranges::size(rg);
247 auto const n_red = mpi::all_reduce(n, *c);
248
249 // reduce the mean
250 value_t res_m = m * (static_cast<double>(n) / static_cast<double>(n_red));
251 mpi::all_reduce_in_place(res_m, *c);
252
253 // reduce the sum of squared deviations from the mean
254 ssqdev += n * abs_square(m - res_m);
255 mpi::all_reduce_in_place(ssqdev, *c);
256
257 // apply the mean and error tags
258 apply_mean_tag<mtag>(res_m, n_red);
259 apply_error_tag<etag>(ssqdev, n_red);
260 return std::make_pair(res_m, ssqdev);
261 }
262
277 [[nodiscard]] auto tau_estimate_from_vars(auto const &var, auto const &var0) {
278 return nda::make_regular(nda::map([](auto v, auto v0) { return (v0 == 0.0) ? nan_sample(v0) : 0.5 * (v / v0 - 1.0); })(var, var0));
279 }
280
/**
 * @brief Compute an estimate for the integrated auto-correlation time from standard errors.
 *
 * NOTE(review): s_n / s_0 are presumably the standard error at some binning level and of the unbinned data — confirm.
 *
 * NOTE(review): the function body (original line 297) is missing from this listing. As shown, the body is empty and
 * the deduced return type is void — restore the body from the original source.
 *
 * @tparam T StatCompatible type of both arguments.
 * @param s_n First standard error.
 * @param s_0 Reference standard error.
 */
296 template <StatCompatible T> auto tau_estimate_from_errors(T const &s_n, T const &s_0) {
298 }
299
301
302} // namespace triqs::stat
auto mean_and_err_mpi(std::optional< mpi::communicator > c, R &&rg)
Calculate the arithmetic mean or the simple sum as well as a corresponding error estimate of some range of values spread across multiple MPI processes.
error_tag
Tag to indicate what to calculate when computing the error of a range of values.
void apply_error_tag(T &sum_sq_devs, long nsamples)
Given the sum of squared deviations from the mean, $\sum_{i=1}^{N} |x_i - \bar{x}|^2$, and the number of samples $N$, apply a transformation to get the result specified by the error tag.
auto tau_estimate_from_errors(T const &s_n, T const &s_0)
Compute an estimate for the integrated auto-correlation time from standard errors.
auto mean_and_err(R &&rg)
Calculate the arithmetic mean or the simple sum as well as a corresponding error estimate of some range of values.
auto mean_mpi(std::optional< mpi::communicator > c, R &&rg)
Calculate the arithmetic mean or the simple sum of some range of values spread across multiple MPI processes.
mean_tag
Tag to indicate what to calculate when computing the mean of a range of values.
auto tau_estimate_from_vars(auto const &var, auto const &var0)
Compute an estimate for the integrated auto-correlation time from variances.
auto mean(R &&rg)
Calculate the arithmetic mean or the simple sum of some range of values.
void apply_mean_tag(T &m, long nsamples)
Given the mean $\bar{x}$ and the number of samples $N$, apply a transformation to get the result specified by the mean tag.
std::remove_cvref_t< decltype(nda::make_regular(std::declval< T >()))> get_regular_t
Type trait to get the type that would be returned by nda::make_regular.
Definition utils.hpp:57
auto nan_sample(T const &sample)
Get a sample with all elements set to NaN.
Definition utils.hpp:86
auto zeroed_sample(T const &sample)
Get a sample with all elements set to zero.
Definition utils.hpp:66
auto abs_square(auto const &x)
Calculate the (elementwise) absolute square of an array/view/scalar.
Definition utils.hpp:107
auto make_real(T &&t)
Make a given object real and regular.
Definition utils.hpp:51
Provides various concepts for the Utilities.
Provides various utilities for the Utilities.