"""
Quantum ESPRESSO driver for TRIQS+DFT workflow automation.

This module provides a driver class for automating Quantum ESPRESSO (QE) and
Wannier90 workflows in the context of DMFT calculations with TRIQS/modest.
"""
import os, shlex, subprocess
from dataclasses import dataclass
from datetime import datetime
import numpy as np
from h5 import HDFArchive
import triqs.utility.mpi as mpi
from ..wannier90.converter import Converter
class DFTWorkflowError(Exception):
"""Exception raised for errors in the DFT workflow."""
pass
# Default environment variables to preserve for subprocess execution
_DEFAULT_ENV_VARS = [
'PATH',
'LD_LIBRARY_PATH',
'SHELL',
'PWD',
'HOME',
'OMP_NUM_THREADS',
'OMPI_MCA_btl_vader_single_copy_mechanism'
]
@dataclass
class MPIHandler:
"""
Handles MPI configuration and execution environment for parallel DFT calculations.
This class encapsulates MPI-related configuration and provides utilities for
running parallel executables with proper environment variable handling.
Attributes
----------
mpi_exec : str
The MPI executor command (e.g., "mpirun", "mpiexec", "srun").
n_cores : int
Number of MPI processes/cores to use for parallel execution.
Examples
--------
>>> mpi_handler = MPIHandler(mpi_exec="mpirun -np 4", n_cores=4)
>>> env = mpi_handler.get_env_vars()
"""
mpi_exec : str = "mpirun"
n_cores : int = 1
def __post_init__(self):
"""Validate MPI configuration."""
if self.n_cores <= 0:
raise ValueError(f"n_cores must be positive, got {self.n_cores}")
if not self.mpi_exec.strip():
raise ValueError("mpi_exec cannot be empty")
def __repr__(self):
return f"MPIHandler(mpi_exec={self.mpi_exec}, n_cores={self.n_cores})"
__str__ = __repr__
def get_env_vars(self):
"""
Retrieve essential environment variables for subprocess execution.
Collects a subset of environment variables that are typically needed
for proper execution of MPI-enabled DFT codes.
Returns
-------
dict[str, str]
Dictionary of environment variable names and their values.
Only includes variables that are set in the current environment.
Note
----
The following variables are collected if present:
- PATH: System executable paths
- LD_LIBRARY_PATH: Shared library paths
- SHELL: User shell
- PWD: Present working directory
- HOME: User home directory
- OMP_NUM_THREADS: OpenMP thread count
- OMPI_MCA_btl_vader_single_copy_mechanism: OpenMPI optimization setting
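
        Examples
        --------
        A minimal sketch; the exact contents depend on the calling
        environment:

        >>> handler = MPIHandler()
        >>> env = handler.get_env_vars()
        >>> 'PATH' in env  # PATH is set in any normal environment
        True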
"""
env_vars = {}
for var_name in _DEFAULT_ENV_VARS:
var = os.getenv(var_name)
if var: env_vars[var_name] = var
return env_vars
def report(self, *args):
"""
Report a message (only from master MPI node).
Parameters
----------
*args : Any
Arguments to pass to mpi.report for printing.
"""
return mpi.report(*args)
def is_master_node(self) -> bool:
"""
Check if this is the master MPI node.
Returns
-------
bool
True if this is the master node, False otherwise.
"""
return mpi.is_master_node()
class Driver:
"""
Driver for orchestrating Quantum ESPRESSO and Wannier90 calculations.
This class automates the execution of a typical DFT workflow consisting of:
1. Self-consistent field (SCF) calculations
2. Non-self-consistent field (NSCF) calculations
3. Wannier90 preprocessing and execution
4. pw2wannier90 interface for projections
The driver is designed to work within DMFT self-consistency loops where
DFT calculations may need to be updated with charge density modifications.
Attributes
----------
seedname : str
Base name for all input/output files (e.g., "system" for "system.scf.in").
mpi_handler : MPIHandler
MPI configuration handler for parallel execution.
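
    Examples
    --------
    A minimal usage sketch, assuming the required input files (e.g.
    ``system.scf.in``, ``system.nscf.in``, ``system.pw2wannier90.in``,
    ``system.win``) already exist in the working directory:

    >>> handler = MPIHandler(mpi_exec="mpirun -np 4", n_cores=4)
    >>> driver = Driver(seedname="system", mpi_handler=handler)
    >>> driver.run_initial_stage()  # doctest: +SKIP
    0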
"""
def __init__(self, seedname : str, mpi_handler : MPIHandler) -> None:
"""
Initialize the Quantum ESPRESSO driver.
Parameters
----------
seedname : str
Base name for all input/output files.
mpi_handler : MPIHandler
MPI configuration for parallel execution.
"""
self.seedname = seedname
self.mpi_handler = mpi_handler
    def __repr__(self):
        return f"Driver(seedname={self.seedname}, mpi_handler={self.mpi_handler!r})"
__str__ = __repr__
def _run_qe_step(self, executable : str, step_name : str, k_parallel : bool = False) -> int:
"""
Execute a Quantum ESPRESSO calculation step.
This is an internal method that handles the execution of QE executables
(pw.x, pw2wannier90.x, etc.) with proper I/O redirection and error handling.
Parameters
----------
executable : str
Name of the QE executable to run (e.g., "pw.x", "pw2wannier90.x").
step_name : str
Calculation step identifier used for file naming (e.g., "scf", "nscf").
        k_parallel : bool, optional
            If True, k-point parallelization via the ``-nk`` flag is requested.
            Note that the flag is currently commented out in the implementation.
            Default is False.
Returns
-------
int
Return code from the subprocess execution (0 for success).
Note
----
Expected file naming convention:
- Input: {seedname}.{step_name}.in
- Output: {seedname}.{step_name}.out
- Error: {seedname}.{step_name}.err
The calculation is only executed on the master MPI node to avoid
duplicate execution in MPI environments.
Raises
------
FileNotFoundError
If the required input file does not exist.
DFTWorkflowError
If the executable returns a non-zero exit code.
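
        Examples
        --------
        A hypothetical SCF step for ``seedname="system"`` (reads
        ``system.scf.in``, writes ``system.scf.out`` and ``system.scf.err``):

        >>> driver._run_qe_step("pw.x", "scf")  # doctest: +SKIP
        0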
"""
        if not self.mpi_handler.is_master_node():
            return 0
        # validate input file existence
        infile = f"{self.seedname}.{step_name}.in"
        if not os.path.exists(infile):
            raise FileNotFoundError(f"Required input file not found: {infile}")
        outfile = f"{self.seedname}.{step_name}.out"
        errfile = f"{self.seedname}.{step_name}.err"
with open(infile, "r") as inp, open(outfile, "w") as out, open(errfile, "w") as err:
command = [*shlex.split(self.mpi_handler.mpi_exec), executable]
            # k-point parallelization is currently disabled; uncomment to pass -nk:
            # if k_parallel: command.extend(["-nk", str(self.mpi_handler.n_cores)])
self.mpi_handler.report(f"[{datetime.now()}] running {' '.join(command)} < {infile} > {outfile}")
result = subprocess.run(
command,
stdin=inp,
stdout=out,
stderr=err,
env=self.mpi_handler.get_env_vars(),
text=True,
)
if result.returncode != 0:
error_msg = f"{executable} ({step_name}) failed with code {result.returncode}"
self.mpi_handler.report(error_msg)
raise DFTWorkflowError(error_msg)
return result.returncode
def _run_w90_step(self, executable : str) -> int:
"""
Execute a Wannier90 calculation step.
This is an internal method that handles the execution of Wannier90
executables with proper command-line argument formatting.
Parameters
----------
executable : str
Wannier90 executable command with flags (e.g., "wannier90.x -pp").
Returns
-------
int
Return code from the subprocess execution (0 for success).
Note
----
The seedname is automatically appended as the final command-line argument.
The calculation is only executed on the master MPI node.
Raises
------
DFTWorkflowError
If the executable returns a non-zero exit code.
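
        Examples
        --------
        A hypothetical preprocessing step for ``seedname="system"``
        (executes ``wannier90.x -pp system``):

        >>> driver._run_w90_step("wannier90.x -pp")  # doctest: +SKIP
        0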
"""
        if not self.mpi_handler.is_master_node():
            return 0
        command = [*shlex.split(executable), self.seedname]
        self.mpi_handler.report(f"[{datetime.now()}] running {' '.join(command)}")
        try:
            subprocess.check_call(command, env=self.mpi_handler.get_env_vars())
except subprocess.CalledProcessError as e:
error_msg = f"Wannier90 step failed: {executable}"
self.mpi_handler.report(error_msg)
raise DFTWorkflowError(error_msg) from e
return 0
def run_scf(self) -> int:
"""
Run a self-consistent field (SCF) calculation.
Returns
-------
int
Return code (0 for success).
"""
return self._run_qe_step("pw.x", "scf", k_parallel=True)
def run_mod_scf(self) -> int:
"""
Run a modified SCF calculation with updated charge density.
This is typically used in DMFT self-consistency loops where the
charge density has been modified based on the lattice self-energy.
Returns
-------
int
Return code (0 for success).
"""
return self._run_qe_step("pw.x", "mod_scf", k_parallel=True)
def run_nscf(self) -> int:
"""
Run a non-self-consistent field (NSCF) calculation.
This step computes band structures on a denser k-point grid using
the converged charge density from SCF.
Returns
-------
int
Return code (0 for success).
"""
return self._run_qe_step("pw.x", "nscf", k_parallel=True)
def run_pw2wannier90(self) -> int:
"""
Run the pw2wannier90 interface.
This step computes overlaps and projections needed for Wannier90
from the QE wavefunction data.
Returns
-------
int
Return code (0 for success).
"""
return self._run_qe_step("pw2wannier90.x", "pw2wannier90")
def run_wannier90_pp(self) -> int:
"""
Run Wannier90 preprocessing step.
This generates the required k-point list and other auxiliary files
for the subsequent QE and Wannier90 calculations.
Returns
-------
int
Return code (0 for success).
"""
return self._run_w90_step("wannier90.x -pp")
def run_wannier90(self) -> int:
"""
Run the Wannier90 wannierization.
This step performs the actual Wannier function construction and
generates the Wannier Hamiltonian.
Returns
-------
int
Return code (0 for success).
"""
return self._run_w90_step("wannier90.x")
    def band_energy_and_write_charge_update(self, N_k):
        """
        Compute the band-energy correction and write the charge-density update.

        Subtracts the DFT (Kohn-Sham) density matrix, built from the Fermi
        weights, from the DMFT density matrix ``N_k``, accumulates the
        band-energy correction as the Brillouin-zone-weighted trace of the
        density-matrix difference with the Hamiltonian H(k), and writes the
        spin-averaged difference ``delta_N`` to the ``dft_update`` group of
        ``{seedname}.h5``.

        Parameters
        ----------
        N_k : numpy.ndarray
            DMFT density matrix of shape (n_spin, n_k, n_max_bands, n_max_bands).

        Returns
        -------
        float
            Band-energy correction in eV.
        """
        n_spin, n_k, n_max_bands, _ = N_k.shape
        # placeholders; the data is read on the master node and broadcast below
        fermi_weights, n_bands_per_k, Hk, bz_weights = None, None, None, None
        spin_to_data_index = [0, 0]
if mpi.is_master_node():
with HDFArchive(f"{self.seedname}.h5", "r") as ar:
fermi_weights = ar['dft_misc_input']['dft_fermi_weights']
n_bands_per_k = ar['dft_input']['n_orbitals']
Hk = ar['dft_input']['hopping']
bz_weights = ar['dft_input']['bz_weights']
spin_to_data_index = [0,0] if ar['dft_input']['SO'] == 0 else [0,1]
fermi_weights = mpi.bcast(fermi_weights)
n_bands_per_k = mpi.bcast(n_bands_per_k)
Hk = mpi.bcast(Hk)
bz_weights = mpi.bcast(bz_weights)
        spin_to_data_index = mpi.bcast(spin_to_data_index)
mpi.barrier()
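        # DFT reference density matrix: per-spin, per-k vectors of Fermi-weight occupations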
        density_matrix_dft = [[fermi_weights[ik, idx, :].astype(complex) for ik in range(n_k)]
                              for idx in spin_to_data_index]
band_energy_correction = 0.0
# subtract off DFT density matrix and compute band energy correction
        for ik in range(n_k):
            for spin, idx in enumerate(spin_to_data_index):
                nb = n_bands_per_k[ik, idx]
                diag_indices = np.diag_indices(nb)
                N_k[spin, ik][diag_indices] -= density_matrix_dft[spin][ik][:nb]
                band_energy_correction += (np.dot(N_k[spin, ik, :nb, :nb],
                                                  Hk[ik, idx, :nb, :nb]).trace().real * bz_weights[ik])
mpi.report(f"\nBand Energy Correction= {band_energy_correction} eV\n")
mpi.report(f"{n_k} -1 ! number of k-points, default number of bands\n")
        delta_Nk = np.zeros((n_k, n_max_bands, n_max_bands), dtype=complex)
        for ik in range(n_k):
            nb = n_bands_per_k[ik, 0]
            delta_Nk[ik, :nb, :nb] = (N_k[0, ik, :nb, :nb] + N_k[1, ik, :nb, :nb]) / 2.0
if mpi.is_master_node():
subgrp = 'dft_update'
with HDFArchive(f"{self.seedname}.h5", "a") as ar:
if subgrp not in ar: ar.create_group(subgrp)
ar[subgrp]["delta_N"] = delta_Nk
return band_energy_correction
    def read_dft_energy(self):
        """
        Read the DFT total energy from the Quantum ESPRESSO output files.

        If ``{seedname}.mod_scf.out`` exists, the total energy is taken from
        the modified SCF output with its band energy replaced by the NSCF
        band energy; otherwise the total energy is read directly from
        ``{seedname}.scf.out``.

        Returns
        -------
        float
            DFT total energy in eV.
        """
        energy_unit = 13.605693123  # Ry -> eV conversion factor
        dft_energy = 0.0
        band_energy_modscf = 0.0
        if mpi.is_master_node():
            read_scf = not os.path.isfile(f"{self.seedname}.mod_scf.out")
if read_scf:
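                # pw.x writes the total energy on the line starting with '!'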
with open(f"{self.seedname}.scf.out", "r") as f: lines = f.readlines()
for line in lines:
if '!' in line:
dft_energy = float(line.split()[-2]) * energy_unit
break
else:
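                # modified SCF run: subtract its band energy here; the NSCF band energy is added below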
with open(f"{self.seedname}.mod_scf.out", "r") as f: lines = f.readlines()
for line in lines:
if "(sum(wg*et))" in line:
print("\nReading band energy from the mod_scf calculation\n")
band_energy_modscf = float(line.split()[-2]) * energy_unit
print(f"The mod_scf band energy is: {band_energy_modscf} eV")
if "total energy" in line:
print("\nReading total energy from the mod_scf calculation\n")
dft_energy = float(line.split()[-2]) * energy_unit
print(f"The uncorrected DFT energy is: {dft_energy} eV")
dft_energy -= band_energy_modscf
print(f"The DFT energy without kinetic part is: {dft_energy} eV")
with open(f"{self.seedname}.nscf.out", "r") as f: lines = f.readlines()
for line in lines:
if "The nscf band energy" in line:
print("\nReading band energy from the nscf calculation\n")
band_energy_nscf = float(line.split()[-2]) * energy_unit
dft_energy += band_energy_nscf
print(f"The nscf band energy is: {band_energy_nscf} eV")
print(f"The corrected DFT energy is: {dft_energy} eV")
break
        dft_energy = mpi.bcast(dft_energy)
return dft_energy
    def run_initial_stage(self, **kwargs) -> int:
        """
        Run the initial DFT workflow stage.

        Executes, in order: SCF, NSCF, Wannier90 preprocessing, pw2wannier90,
        Wannier90 wannierization, and the TRIQS Wannier90 converter.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments forwarded to the Converter.

        Returns
        -------
        int
            Return code of the last executed step (0 for success).
        """
steps = [ ("SCF calculation", self.run_scf),
("NSCF calculation", self.run_nscf),
("Wannier90 preprocessing", self.run_wannier90_pp),
("pw2wannier90 interface", self.run_pw2wannier90),
("Wannier90 wannierization", self.run_wannier90),
("TRIQS QE Converter", Converter(seedname=self.seedname, bloch_basis=True, **kwargs).convert_dft_input),
]
        status = 0
        for i, (step_name, step_func) in enumerate(steps, 1):
            self.mpi_handler.report(f"Step {i}/{len(steps)}: {step_name}")
            # treat a None return (e.g. from the converter) as success
            status = step_func() or 0
        return status
    def run_update_stage(self, N_k, Eint_m_dc, **kwargs):
        """
        Run the charge-update stage of a DMFT self-consistency loop.

        Writes the charge-density update derived from ``N_k``, reports the
        DFT+DMFT total energy, and reruns the workflow starting from the
        modified SCF step.

        Parameters
        ----------
        N_k : numpy.ndarray
            DMFT density matrix of shape (n_spin, n_k, n_max_bands, n_max_bands).
        Eint_m_dc : float
            Interaction energy minus the double-counting correction.
        **kwargs
            Additional keyword arguments forwarded to the Converter.

        Returns
        -------
        int
            Return code of the last executed step (0 for success).
        """
dft_energy = self.read_dft_energy()
band_energy_correction = self.band_energy_and_write_charge_update(N_k)
mpi.report(f"! DFT + DMFT Total Energy: {dft_energy + band_energy_correction + Eint_m_dc}")
steps = [ ("Mod SCF calculation", self.run_mod_scf),
("NSCF calculation", self.run_nscf),
("Wannier90 preprocessing", self.run_wannier90_pp),
("pw2wannier90 interface", self.run_pw2wannier90),
("Wannier90 wannierization", self.run_wannier90),
("TRIQS QE Converter", Converter(seedname=self.seedname, bloch_basis=True, **kwargs).convert_dft_input),
]
        status = 0
        for i, (step_name, step_func) in enumerate(steps, 1):
            self.mpi_handler.report(f"Step {i}/{len(steps)}: {step_name}")
            # treat a None return (e.g. from the converter) as success
            status = step_func() or 0
        return status