################################################################################
#
# TPRF: Two-Particle Response Function (TPRF) Toolbox for TRIQS
#
# Copyright (C) 2019 by The Simons Foundation
# Author: H. U.R. Strand
#
# TPRF is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# TPRF is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# TPRF. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################

import itertools


import itertools

import numpy as np

import triqs.utility.mpi as mpi
from h5 import HDFArchive

from triqs.gf import Gf, MeshImFreq, Fourier, LegendreToMatsubara, BlockGf, inverse, Idx, MeshProduct
from triqs.gf.tools import fit_legendre

from triqs.operators import n, c, c_dag, Operator
from triqs.operators.util.hamiltonians import h_int_kanamori, h_int_density
from triqs.operators.util.op_struct import set_operator_structure

from triqs_tprf.ParameterCollection import ParameterCollection
from triqs_tprf.ParameterCollection import ParameterCollections

import sys

# -- Read the last entry of 'ps' from the self-consistency archive on the
# -- master rank only, then broadcast it so every MPI rank holds the same `p`.
p = None
if mpi.is_master_node():
    with HDFArchive('data_sc.h5', 'r') as a:
        p = a['ps'][-1]
p = mpi.bcast(p)

print("Measure triangle")

# -- Rewrite the problem with a Green's function that is completely diagonal
# -- in spin-orbital space (off_diag=False). The local interaction operators
# -- are relabelled to match this structure.

spin_names = 'up', 'do'
num_orbitals = 1
orb_names = [*range(num_orbitals)]
p.init.gf_struct = set_operator_structure(
    spin_names, orb_names, off_diag=False)

from pyed.OperatorUtils import relabel_operators

# -- Build matching operator lists: (spin, orb) labels are mapped onto
# -- ('<spin>_<orb>', 0) labels, for both c and c_dag, in the same order.
_op_labels = list(itertools.product(spin_names, orb_names, (c, c_dag)))

from_list = [ladder(spin, orb) for spin, orb, ladder in _op_labels]
to_list = [ladder(f'{spin}_{orb}', 0) for spin, orb, ladder in _op_labels]

p.solve.h_int = relabel_operators(p.solve.h_int, from_list, to_list)

# -- Every 4-tuple of indices in range(2*num_orbitals), each stored as a list
# -- (presumably the spin-orbital index combinations to sample).
_n_index = 2*num_orbitals
WormComponents = list(map(list, itertools.product(range(_n_index), repeat=4)))

print(f'WormComponents =\n{WormComponents}')

# -- QMC configuration dictionary; it is attached to p.solve.cfg_qmc and
# -- thereby forwarded to the solver's solve() call.
cfg_qmc = {
    'PercentageWormInsert': 0.3,
    'PercentageWormReplace': 0.1,
    'WormEta': 1,
    'WormSearchEta': 1,
    'Nwarmups2Plus': 10000,
    'WormMeasP3iwPH': 1,
    'WormPHConvention': 0,  # Important to get correct frequency structure
    'N3iwb': 30,
    'N3iwf': 30,
    'WormComponents': WormComponents,  # Add G2 components to sample here.
}

# -- Solver run parameters: warmup / measurement cycle counts, Legendre
# -- measurement switched off, worm sampling switched on with the QMC
# -- configuration defined above.
p.solve.n_warmup_cycles = 10**7
p.solve.n_cycles = 10**8
p.solve.measure_G_l = False

p.solve.worm = True
p.solve.cfg_qmc = cfg_qmc

from w2dyn_cthyb import Solver

# -- Construct the solver from the stored init parameters.
cthyb = Solver(**p.init.dict())

# -- Fill the [0, 0] element of each relabelled solver block from the
# -- corresponding spin block of p.G0_w (TRIQS `<<` assigns in place).
for _block, _spin in (('up_0', 'up'), ('do_0', 'do')):
    cthyb.G0_iw[_block][0, 0] << p.G0_w[_spin][0, 0]

# -- Run the solver and keep the sampled worm components on `p`.
cthyb.solve(**p.solve.dict())
p.GF_worm_components = cthyb.GF_worm_components

# -- Write the parameter collection (now carrying GF_worm_components) to
# -- disk from the master rank only, so the ranks do not all write the file.
if mpi.is_master_node():
    with HDFArchive('data_triangle.h5', 'w') as a:
        a['p'] = p