28#include <triqs/gfs.hpp>
29#include <triqs/mesh.hpp>
30#include <triqs/utility/tuple_tools.hpp>
35 namespace experimental {
37 using nda::stdutil::sum;
39 using namespace triqs::gfs;
40 using namespace triqs::mesh;
42 using dcomplex = std::complex<double>;
45 template <
int Rank>
class nfft_buf_t {
47 template <
typename = std::make_index_sequence<Rank>>
struct imfreq_product;
48 template <std::size_t... Is>
struct imfreq_product<std::index_sequence<Is...>> {
using type = prod<
decltype(Is, imfreq{})...>; };
51 using freq_mesh_t =
typename imfreq_product<>::type;
// Constructor (fragment: several original lines, including closing braces,
// are not visible in this excerpt).
// Stores the mesh, the output view, the buffer size and the do_checks flag,
// allocates the nfft_plan object and starts with an empty buffer
// (buf_counter = 0), then initializes the plan with nfft_init_guru.
//
// fiw_mesh  : product frequency mesh; its components are visited below.
// fiw_arr   : array view kept in the fiw_arr member.
// buf_size  : stored, and forwarded to nfft_init_guru.
// do_checks : stored; its use is not visible in this excerpt.
53 nfft_buf_t(freq_mesh_t
const &fiw_mesh, array_view<dcomplex, Rank> fiw_arr,
int buf_size,
bool do_checks =
false)
54 : fiw_mesh(fiw_mesh), fiw_arr(fiw_arr), buf_size(buf_size), do_checks(do_checks), plan_ptr(std::make_unique<nfft_plan>()), buf_counter(0) {
// Visit every mesh component m together with its position r.
59 triqs::tuple::for_each_enumerate(fiw_mesh.components(), [
this](
int r, mesh::imfreq
const &m) {
60 if (m.statistic() == Fermion) {
// Fermionic component: flip the global sign when size/2 is odd.
62 common_factor *= (m.size() / 2) % 2 ? -1 : 1;
// Other statistic (branch header not visible): flip the sign when (size-1)/2 is odd.
66 common_factor *= ((m.size() - 1) / 2) % 2 ? -1 : 1;
// Diagnostic written to stderr; the condition guarding it is not visible here.
68 std::cerr <<
" ERROR: nfft_buf_t needs more bosonic frequencies.\n";
// Extents of the NFFT output: mesh component sizes plus the per-dimension shifts.
73 std::array<long, Rank> buf_extents = fiw_mesh.size_of_components() + index_shifts;
// The index map is only (re)built when not every component is fermionic.
75 if (!all_fermion) nfft_indexmap = idx_map_t(buf_extents);
// Local helper; its body is not visible in this excerpt.
81 auto next_power_of_two = [](
unsigned int v) {
// FFTW grid extents: twice next_power_of_two of each buffer extent.
93 std::array<long, Rank> extents_fftw;
94 for (
int i = 0; i < Rank; i++) extents_fftw[i] = 2 * next_power_of_two(buf_extents[i]);
// Option flags handed to the plan initialization.
96 unsigned nfft_flags = PRE_PHI_HUT | PRE_PSI | MALLOC_X | MALLOC_F_HAT | MALLOC_F | FFTW_INIT | FFT_OUT_OF_PLACE | NFFT_SORT_NODES;
97 unsigned fftw_flags = FFTW_ESTIMATE | FFTW_DESTROY_INPUT;
// nfft_init_guru receives int pointers here, so the long extents are copied into int vectors.
100 std::vector<int> buf_extents_int(buf_extents.begin(), buf_extents.end());
101 std::vector<int> extents_fftw_int(extents_fftw.begin(), extents_fftw.end());
// NOTE(review): `m` is not declared in the visible part of this scope (the
// lambda parameter named m above is local to that lambda); presumably it is
// defined on a line missing from this excerpt -- confirm.
102 nfft_init_guru(plan_ptr.get(), Rank, buf_extents_int.data(), buf_size, extents_fftw_int.data(), m, nfft_flags, fftw_flags);
// Cleanup statements; presumably the destructor body (its header line is not
// visible in this excerpt -- confirm).
// Warn on stderr when points are still counted in the buffer.
106 if (buf_counter != 0) std::cerr <<
" WARNING: Points in NFFT Buffer lost \n";
// Finalize the NFFT plan; skipped when plan_ptr is null.
107 if (plan_ptr) nfft_finalize(plan_ptr.get());
111 nfft_buf_t(nfft_buf_t
const &) =
delete;
112 nfft_buf_t(nfft_buf_t &&) =
default;
113 nfft_buf_t &operator=(nfft_buf_t
const &) =
delete;
114 nfft_buf_t &operator=(nfft_buf_t &&) =
default;
// Rebind the output view fiw_arr to new_fiw_arr.
// (The closing brace of this function is not visible in this excerpt.)
117 void rebind(array_view<dcomplex, Rank> new_fiw_arr) {
// Only checked when assert is active: the new view must have the same shape
// as the current one, unless the current view has the shape of a
// default-constructed view.
119 assert(get_shape(new_fiw_arr) == get_shape(fiw_arr) || get_shape(fiw_arr) == get_shape(array_view<dcomplex, Rank>{}));
120 fiw_arr.rebind(new_fiw_arr);
// Store one sample in the buffer slot indexed by buf_counter.
// (Fragment: several original lines of this function are not visible here.)
//
// tau_arr : one tau value per mesh component.
// ftau    : sample value belonging to that tau point.
125 void push_back(std::array<double, Rank>
const &tau_arr, dcomplex ftau) {
// Accumulates tau/beta over the fermionic components only.
131 double sum_tau_beta = 0;
132 triqs::tuple::for_each_enumerate(fiw_mesh.components(), [&tau_arr, &sum_tau_beta,
this](
int r, mesh::imfreq
const &m) {
133 double tau = tau_arr[r];
134 double beta = m.beta();
// Node coordinate for dimension r of this sample: tau/beta shifted by -0.5
// (presumably to match the node interval expected by NFFT -- confirm).
137 x_arr()[buf_counter * Rank + r] = tau_arr[r] / beta - 0.5;
139 if (m.statistic() == Fermion) sum_tau_beta += tau / beta;
// Sample value multiplied by the phase exp(i * pi * sum_tau_beta).
144 fx_arr()[buf_counter] = std::exp(1i * M_PI * sum_tau_beta) * ftau;
// Fragment of a member function whose header and remaining lines are not
// visible in this excerpt.
// Nothing to do when no point is buffered.
160 if (is_empty())
return;
// For every unused slot i in [buf_counter, buf_size), set all Rank node
// coordinates to -0.5 + i / buf_size (presumably padding so the plan always
// sees buf_size nodes -- confirm against the missing lines).
163 for (
int i = buf_counter; i < buf_size; ++i) {
165 for (
int r = 0; r < Rank; ++r) x_arr()[i * Rank + r] = -0.5 + double(i) / buf_size;
172 bool is_empty()
const {
return buf_counter == 0; }
175 bool is_full()
const {
return buf_counter >= buf_size; }
// Frequency mesh; its components are queried for statistic(), size() and beta().
179 freq_mesh_t fiw_mesh;
// View of the array that the transformed values are added to.
182 array_view<dcomplex, Rank> fiw_arr;
// Rank-dimensional index map type (declared with C_stride_order); it is
// applied to shifted multi-indices to obtain a position in the NFFT output.
188 using idx_map_t = nda::idx_map<Rank, 0, C_stride_order<Rank>, layout_prop_e::none>;
189 idx_map_t nfft_indexmap;
// Per-dimension offsets added to the mesh component sizes and to array indices.
193 std::array<long, Rank> index_shifts;
// Owning pointer to the NFFT plan structure.
205 std::unique_ptr<nfft_plan> plan_ptr;
211 double *x_arr() {
return plan_ptr->x; }
214 dcomplex *fx_arr() {
return reinterpret_cast<dcomplex *
>(plan_ptr->f); }
217 const dcomplex *fk_arr()
const {
return reinterpret_cast<dcomplex *
>(plan_ptr->f_hat); }
// Fragment of the member function that runs the adjoint transform and adds
// the result to fiw_arr. Its header, branch structure and closing braces are
// not visible in this excerpt.
// Smallest extent among the plan's d dimensions.
228 auto N_min = *std::min_element(plan_ptr->N, plan_ptr->N + plan_ptr->d);
// When the smallest extent does not exceed the plan's m parameter, the
// direct adjoint routine is used.
229 if (N_min <= plan_ptr->m) {
232 nfft_adjoint_direct(plan_ptr.get());
// NOTE(review): the next two statements test different plan members
// (nfft_flags vs. flags); presumably they are the two arms of a preprocessor
// conditional whose directive lines are missing from this excerpt -- confirm.
238 if (plan_ptr->nfft_flags & PRE_ONE_PSI) nfft_precompute_one_psi(plan_ptr.get());
240 if (plan_ptr->flags & PRE_ONE_PSI) nfft_precompute_one_psi(plan_ptr.get());
// A non-null string returned by nfft_check is reported through TRIQS_RUNTIME_ERROR.
243 const char *error_str = nfft_check(plan_ptr.get());
244 if (error_str != 0) TRIQS_RUNTIME_ERROR <<
"Error in NFFT module: " << error_str <<
"\n";
248 nfft_adjoint(plan_ptr.get());
// Accumulation using a running position `count` (declared on a line not
// visible here): each element receives the matching output value times a
// sign given by the parity of its index sum and times common_factor.
254 for (
auto it = fiw_arr.begin(); it != fiw_arr.end(); ++it, ++count) {
255 int factor = (sum(it.indices()) % 2 ? -1 : 1);
256 *it += fk_arr()[count] * factor * common_factor;
// Accumulation where the output position is obtained by applying
// nfft_indexmap to the indices shifted by index_shifts (the condition
// selecting between the two loops is not visible -- presumably all_fermion).
259 for (
auto it = fiw_arr.begin(); it != fiw_arr.end(); ++it) {
260 int count = std::apply(nfft_indexmap, it.indices() + index_shifts);
261 int factor = (sum(it.indices()) % 2 ? -1 : 1);
262 *it += fk_arr()[count] * factor * common_factor;