TRIQS/TRIQS 4.0.0
Researching Interacting Quantum Systems
Loading...
Searching...
No Matches
tail_fitter.hpp
Go to the documentation of this file.
1// Copyright (c) 2018 Commissariat à l'énergie atomique et aux énergies alternatives (CEA)
2// Copyright (c) 2018 Centre national de la recherche scientifique (CNRS)
3// Copyright (c) 2018-2023 Simons Foundation
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU General Public License for more details.
14//
15// You may obtain a copy of the License at
16// https://www.gnu.org/licenses/gpl-3.0.txt
17//
18// Authors: Olivier Parcollet, Nils Wentzell
19
24
25#pragma once
26
27#include "../arrays.hpp"
28#include "../utility/macros.hpp"
29
30#include <itertools/itertools.hpp>
31#include <nda/nda.hpp>
32#include <nda/lapack/gelss_worker.hpp>
33
34#include <algorithm>
35#include <array>
36#include <cmath>
37#include <complex>
38#include <memory>
39#include <optional>
40#include <type_traits>
41#include <utility>
42#include <vector>
43
44namespace triqs::mesh {
45 // Forward declaration.
46 class imfreq;
47} // namespace triqs::mesh
48
49namespace triqs::mesh::detail {
50
55
79 inline auto vander(std::vector<std::complex<double>> const &z_pts, int q) {
80 nda::matrix<std::complex<double>> V(z_pts.size(), q + 1);
81 for (auto [i, z_i] : itertools::enumerate(z_pts)) {
82 auto z = std::complex<double>{1};
83 for (int n = 0; n <= q; ++n) {
84 V(i, n) = z;
85 z *= z_i;
86 }
87 }
88 return V;
89 }
90
111 template <int R> auto tail_eval(nda::array_const_view<std::complex<double>, R> A, std::complex<double> z_0) {
112 auto compute = [&A, z_0](auto res) {
113 auto z = std::complex<double>{1};
114 auto const q = A.extent(0);
115 for (int n = 0; n < q; ++n, z /= z_0) res += A(n, nda::ellipsis{}) * z;
116 return res;
117 };
118 if constexpr (R > 1) {
119 // return an array of rank R - 1
120 return compute(nda::zeros<std::complex<double>>(nda::stdutil::front_pop(A.shape())));
121 } else {
122 // return a complex scalar
123 return compute(std::complex<double>{0});
124 }
125 }
126
127} // namespace triqs::mesh::detail
128
129namespace triqs::mesh {
130
198 class C2PY_IGNORE tail_fitter {
199 public:
201 static constexpr double default_tail_fraction = 0.2;
202
204 static constexpr int default_n_tail_max = 30;
205
215 tail_fitter(double r, int p_max, std::optional<int> q = {}) : r_(r), p_max_(p_max), adjust_q_(not q.has_value()), q_(adjust_q_ ? q_max_ : *q) {
216 if (r_ <= 0 or r_ > 1) TRIQS_RUNTIME_ERROR << "Error in tail-fitter: Fraction of mesh points must be in (0, 1]";
217 if (p_max_ <= 0) TRIQS_RUNTIME_ERROR << "Error in tail-fitter: Maximum number of mesh points must be > 0";
218 }
219
233 template <typename M> int n_pts_in_tail(M const &m) const { return std::min(static_cast<int>(std::round(r_ * m.size() / 2)), p_max_); }
234
236 double get_tail_fraction() const { return r_; }
237
246 template <typename M> auto get_tail_fit_indices(M const &m) {
247 // total number of points in the fitting window
248 auto const p_r = static_cast<int>(std::round(r_ * m.size() / 2));
249
250 // number of points actually used for the fit
251 auto const p = n_pts_in_tail(m);
252
253 // reserve space for the indices
254 std::vector<long> idx_vec;
255 idx_vec.reserve(2ul * p);
256
257 // initialize the left most and right most indices for both fitting windows
258 double const step = static_cast<double>(p_r) / p;
259 double left_idx = m.first_index();
260 double right_idx = m.last_index();
261
262 for ([[maybe_unused]] auto i : nda::range(p)) {
263 idx_vec.push_back(long(left_idx));
264 idx_vec.push_back(long(right_idx));
265 left_idx += step;
266 right_idx -= step;
267 }
268
269 return idx_vec;
270 }
271
279 template <bool enforce_hermiticity = false> auto &get_lss() {
280 if constexpr (enforce_hermiticity)
281 return lss_hermitian_;
282 else
283 return lss_;
284 }
285
311 template <bool enforce_hermiticity = false, typename M> void setup_lss(M const &m, int n_A) {
312 // least square worker type
313 using worker_t = std::conditional_t<enforce_hermiticity, nda::lapack::gelss_worker_hermitian, nda::lapack::gelss_worker<std::complex<double>>>;
314
315 // indices of the mesh points to use in the tail fit
316 if (fit_idxs_.empty()) fit_idxs_ = get_tail_fit_indices(m);
317
318 // set up Vandermonde matrix (the points are given by z_i = |m.w_max()| / m.to_value(n))
319 double const z_max = std::abs(m.w_max());
320 if (V_.is_empty()) {
321 std::vector<std::complex<double>> z_pts;
322 z_pts.reserve(fit_idxs_.size());
323 for (long n : fit_idxs_) z_pts.push_back(z_max / m.to_value(n));
324 V_ = detail::vander(z_pts, q_);
325 }
326
327 // check if we have enough data points for the least square procedure (p > n_A + 1)
328 if (n_A + 1 > V_.extent(0) / 2) TRIQS_RUNTIME_ERROR << "Error in tail-fitter::setup_lss: Insufficient data points for least square procedure";
329
330 // factory function for least square workers
331 auto worker_factory = [&](int n) { return std::make_unique<const worker_t>(V_(nda::range::all, nda::range(n_A, n + 1))); };
332
333 // get the correct (hermitian vs. non-hermitian) array of least square workers
334 auto &lss = get_lss<enforce_hermiticity>();
335
336 // set up the least square workers
337 if (!adjust_q_) {
338 // use the expansion order given in the constructor
339 lss[n_A] = worker_factory(q_);
340 } else {
341 // find the maximum expansion order such that the smallest singular value of the Vandermonde matrix is > rcond_
342 lss[n_A].reset();
343 // ensure that |z_max|^{1-q} > 10^{-16}
344 long q_max = std::min(static_cast<long>(q_max_), static_cast<long>(1. + 16. / std::log10(1 + std::abs(m.w_max()))));
345 // we try to use at least two times as many data points as we have unknown coefficients
346 q_max = std::min(q_max, V_.extent(0) / 2);
347 for (long q = q_max; q >= n_A; --q) {
348 auto ptr = worker_factory(q);
349 if (ptr->S_vec()[ptr->S_vec().size() - 1] > rcond_) {
350 lss[n_A] = std::move(ptr);
351 break;
352 }
353 }
354 }
355
356 // throw an exception if the Vandermonde matrix is ill-conditioned
357 if (!lss[n_A]) TRIQS_RUNTIME_ERROR << "Error in tail-fitter::setup_lss: Ill-conditioned Vandermonde matrix";
358 }
359
397 template <int P, bool enforce_hermiticity = false, typename M, int R>
398 auto fit(M const &m, nda::array_const_view<std::complex<double>, R> D, bool rescale, nda::array_const_view<std::complex<double>, R> C,
399 std::optional<long> d = {}) {
400 // compile-time and run-time checks
401 static_assert(!enforce_hermiticity || std::is_same_v<M, imfreq>);
402 if (enforce_hermiticity and not d.has_value())
403 TRIQS_RUNTIME_ERROR << "Error in tail-fitter::fit: Enforcing hermiticity requires an inner matrix dimension";
404 if (m.positive_only()) TRIQS_RUNTIME_ERROR << "Error in tail-fitter::fit: Cannot fit on a positive_only mesh";
405
406 // early return if the number of known coefficients is larger than the expansion order
407 int const n_A = C.extent(0);
408 if (n_A > q_) return std::pair<nda::array<std::complex<double>, R>, double>{C, 0.0};
409
410 // set up the least squares worker for the given number of known coefficients if it has not been done already
411 auto &lss = get_lss<enforce_hermiticity>();
412 if (!lss[n_A]) setup_lss<enforce_hermiticity>(m, n_A);
413
414 // permute the indices of D such that the relevant frequency mesh corresponds to the first dimension
415 auto D_rot = nda::rotate_index_view<P>(D);
416
417 // flatten D in the target space and the remaining meshes into the second dimension of a new matrix
418 long const ncols = D_rot.size() / D_rot.shape()[0];
419 auto D_mat = nda::matrix<std::complex<double>>(V_.extent(0), ncols);
420 for (auto [i, n] : itertools::enumerate(fit_idxs_)) {
421 if constexpr (R == 1) {
422 D_mat(i, 0) = D_rot(m.to_data_index(n));
423 } else {
424 for (auto [j, x] : itertools::enumerate(D_rot(m.to_data_index(n), nda::ellipsis{}))) { D_mat(i, j) = x; }
425 }
426 }
427
428 // flatten and prepare the array of known coefficients C
429 double const z_max = std::abs(m.w_max());
430 if (n_A > 0) {
431 // check the shape of C
432 if (ncols != C.size() / C.shape()[0])
433 TRIQS_RUNTIME_ERROR << "Error in tail-fitter::fit: Shape of C array incompatible with the shape of the D array";
434
435 // flatten C and scale its values by |z_max|^{-q}
436 double z = 1.0;
437 auto C_mat = nda::matrix<std::complex<double>>(n_A, ncols);
438 for (long i : nda::range(n_A)) {
439 if constexpr (R == 1) {
440 C_mat(i, 0) = z * C(i, nda::ellipsis{});
441 } else {
442 for (auto [j, x] : itertools::enumerate(C(i, nda::ellipsis{}))) C_mat(i, j) = z * x;
443 }
444 z /= z_max;
445 }
446
447 // subtract the expansion terms corresponding to the known moments from the function values
448 D_mat -= V_(nda::range::all, nda::range(n_A)) * C_mat;
449 }
450
451 // perform the least squares procedure
452 auto [A_mat, err] = (*lss[n_A])(D_mat, d);
453
454 // rescale the coefficients if requested
455 if (rescale) {
456 double z = 1.0;
457 for ([[maybe_unused]] long i : nda::range(n_A)) z *= z_max;
458 for (long i : nda::range(A_mat.extent(0))) {
459 A_mat(i, nda::range::all) *= z;
460 z *= z_max;
461 }
462 }
463
464 // reinterpret the result as an R-dimensional array according to the initial shape
465 auto shape = D_rot.shape();
466 shape[0] = lss[n_A]->n_var() + n_A;
467 auto A = nda::array<std::complex<double>, R>{shape};
468
469 // add the known moments to the result
470 if (n_A) A(nda::range(n_A), nda::ellipsis{}) = C;
471
472 // add the calculated moments to the result
473 shape[0] = lss[n_A]->n_var();
474 auto idxmap = typename nda::array<std::complex<double>, R>::layout_t{shape};
475 auto A_view = A(nda::range(n_A, A.shape()[0]), nda::ellipsis{});
476 A_view = nda::array_view<std::complex<double>, R>{idxmap, A_mat.storage()};
477
478 return std::pair<nda::array<std::complex<double>, R>, double>{std::move(A), err};
479 }
480
497 template <int P, typename M, int R>
498 auto fit_hermitian(M const &m, nda::array_const_view<std::complex<double>, R> D, bool rescale, nda::array_const_view<std::complex<double>, R> C,
499 std::optional<long> d = {}) {
500 return fit<P, true, M, R>(m, D, rescale, C, d);
501 }
502
503 private:
504 static constexpr int q_max_ = 9;
505 static constexpr double rcond_ = 1e-8;
506 double r_;
507 int p_max_;
508 bool adjust_q_;
509 int q_;
510 std::array<std::unique_ptr<const nda::lapack::gelss_worker<dcomplex>>, q_max_ + 1> lss_;
511 std::array<std::unique_ptr<const nda::lapack::gelss_worker_hermitian>, q_max_ + 1> lss_hermitian_;
512 nda::matrix<dcomplex> V_;
513 std::vector<long> fit_idxs_;
514 };
515
520 class C2PY_IGNORE tail_fitter_handle {
521 public:
530 void set_tail_fit_parameters(double tail_fraction, int n_tail_max = tail_fitter::default_n_tail_max,
531 std::optional<int> expansion_order = {}) const {
532 tf_ptr_ = std::make_shared<tail_fitter>(tail_fitter{tail_fraction, n_tail_max, expansion_order});
533 }
534
543 C2PY_IGNORE tail_fitter &get_tail_fitter() const {
544 if (!tf_ptr_) tf_ptr_ = std::make_shared<tail_fitter>(tail_fitter::default_tail_fraction, tail_fitter::default_n_tail_max);
545 return *tf_ptr_;
546 }
547
557 C2PY_IGNORE tail_fitter &get_tail_fitter(double tail_fraction, int n_tail_max = tail_fitter::default_n_tail_max,
558 std::optional<int> expansion_order = {}) const {
559 set_tail_fit_parameters(tail_fraction, n_tail_max, expansion_order);
560 return *tf_ptr_;
561 }
562
563 private:
564 mutable std::shared_ptr<tail_fitter> tf_ptr_;
565 };
566
568
569} // namespace triqs::mesh
Backward-compatibility umbrella header pulling in the nda array library.
Imaginary frequency mesh type.
Definition imfreq.hpp:102
Shared handle for tail fitting.
void set_tail_fit_parameters(double tail_fraction, int n_tail_max=tail_fitter::default_n_tail_max, std::optional< int > expansion_order={}) const
Set the pointer to a new tail-fitter object constructed with the given parameters.
tail_fitter & get_tail_fitter() const
Get the triqs::mesh::tail_fitter object.
tail_fitter & get_tail_fitter(double tail_fraction, int n_tail_max=tail_fitter::default_n_tail_max, std::optional< int > expansion_order={}) const
Construct a new triqs::mesh::tail_fitter object with the given parameters and return it.
Fit the high- and low-frequency tail of a function defined on a triqs::mesh::refreq or a triqs::mesh::imfreq.
auto get_tail_fit_indices(M const &m)
Get a vector containing the indices of all mesh points to use in the tail fit, i.e. the points in the left and right fitting windows.
double get_tail_fraction() const
Get the fraction of the mesh to be considered for the tail fit.
static constexpr double default_tail_fraction
Default fraction of the mesh to consider in the tail fit.
auto fit_hermitian(M const &m, nda::array_const_view< std::complex< double >, R > D, bool rescale, nda::array_const_view< std::complex< double >, R > C, std::optional< long > d={})
Perform a linear least squares fit and return the coefficients of the tail expansion together with the fitting error (hermiticity enforced).
static constexpr int default_n_tail_max
Default maximum number of points to use in the tail fit.
void setup_lss(M const &m, int n_A)
Set up the linear least squares workers for a given mesh and a given number of known coefficient arrays.
tail_fitter(double r, int p_max, std::optional< int > q={})
Construct a tail fitter for a given fraction of the mesh, the maximum number of mesh points to use and an optional expansion order.
auto & get_lss()
Get linear least squares workers.
int n_pts_in_tail(M const &m) const
Get the number of mesh points used in the tail fit for the given mesh.
auto fit(M const &m, nda::array_const_view< std::complex< double >, R > D, bool rescale, nda::array_const_view< std::complex< double >, R > C, std::optional< long > d={})
Perform a linear least squares fit and return the coefficients of the tail expansion together with the fitting error.
auto vander(std::vector< std::complex< double > > const &z_pts, int q)
Construct a Vandermonde matrix.
auto tail_eval(nda::array_const_view< std::complex< double >, R > A, std::complex< double > z_0)
Evaluate the tail expansion of a function, \(\sum_n A_n z_0^{-n}\), at a given point \(z_0\).
#define TRIQS_RUNTIME_ERROR
Throw a triqs::runtime_error with the current source location.
Common macros used in TRIQS.