"""
VASP driver for TRIQS+DFT workflow automation.
This module provides a driver class for automating VASP calculations using
PLO projectors in the context of DMFT calculations with TRIQS/modest.
VASP runs as a persistent forked process communicating via a lock file mechanism,
unlike QE which is called as separate subprocesses for each step.
"""
import os
import shlex
import signal
import time
from dataclasses import dataclass
from datetime import datetime
import numpy as np
from h5 import HDFArchive
import triqs.utility.mpi as mpi
from .converter import Converter
from .plovasp.converter import generate_and_output_as_text
class DFTWorkflowError(Exception):
    """Raised when a step of the DFT workflow fails (e.g. VASP dies or times out)."""
# Default environment variables to preserve for subprocess execution
_DEFAULT_ENV_VARS = ['PATH', 'LD_LIBRARY_PATH', 'SHELL', 'PWD', 'HOME', 'OMP_NUM_THREADS', 'OMPI_MCA_btl_vader_single_copy_mechanism']
@dataclass
class MPIHandler:
    """
    MPI launch settings and rank helpers for the parallel DFT run.

    Attributes
    ----------
    mpi_exec : str
        Launcher command line (e.g. "mpirun -np 16"). It is split with shlex
        when VASP is started and must not be blank.
    """
    mpi_exec: str = "mpirun"

    def __post_init__(self):
        # Reject a blank launcher up front rather than failing at launch time.
        if self.mpi_exec.strip() == "":
            raise ValueError("mpi_exec cannot be empty")

    def __repr__(self):
        return f"MPIHandler(mpi_exec={self.mpi_exec})"

    __str__ = __repr__

    def get_env_vars(self):
        """Return {name: value} for each _DEFAULT_ENV_VARS entry that is set and non-empty."""
        lookups = ((name, os.getenv(name)) for name in _DEFAULT_ENV_VARS)
        return {name: value for name, value in lookups if value}

    def report(self, *args):
        """Forward *args* to ``triqs.utility.mpi.report``."""
        return mpi.report(*args)

    def is_master_node(self) -> bool:
        """Return ``triqs.utility.mpi.is_master_node()``."""
        return mpi.is_master_node()
class Driver:
    """
    Drive VASP DFT calculations that use PLO projectors.

    VASP is launched once as a forked background process and kept alive for
    the whole run. The Python driver and VASP hand control back and forth
    through the file vasp.lock:
    - VASP creates vasp.lock when starting, deletes it when SCF is done
    - Python creates vasp.lock to signal VASP to resume with updated charge
    - VASP reads the charge correction, does another SCF step, deletes lock

    Attributes
    ----------
    seedname : str
        Base name for HDF5 files.
    plo_cfg : str
        Path to the PLO configuration file.
    vasp_command : str
        VASP executable name (e.g., "vasp_std").
    mpi_handler : MPIHandler
        MPI configuration handler.
    """

    def __init__(self, seedname: str, plo_cfg: str, mpi_handler: MPIHandler,
                 vasp_command: str = "vasp_std") -> None:
        self.seedname, self.plo_cfg = seedname, plo_cfg
        self.vasp_command = vasp_command
        self.mpi_handler = mpi_handler
        # PID of the forked launcher; filled in by run_initial_stage.
        self._vasp_process_id = None

    def __repr__(self):
        fields = f"seedname={self.seedname}, plo_cfg={self.plo_cfg}, vasp_command={self.vasp_command}"
        return f"VaspDriver({fields})"

    __str__ = __repr__
def _fork_and_start_vasp(self):
"""
Fork a child process that runs VASP via MPI.
The child process running VASP never returns from this function.
The parent process returns the child's PID and continues.
Only called on the master MPI node.
Returns
-------
int
Process ID of the VASP child process.
"""
env_vars = self.mpi_handler.get_env_vars()
mpi_parts = shlex.split(self.mpi_handler.mpi_exec)
mpi_exe = mpi_parts[0]
# Resolve full path to MPI executable
for path_dir in env_vars.get('PATH', '').split(':'):
candidate = os.path.join(path_dir, mpi_exe)
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
mpi_exe = candidate
break
arguments = [mpi_exe] + mpi_parts[1:] + [self.vasp_command]
self.mpi_handler.report(f"[{datetime.now()}] Starting VASP: {' '.join(arguments)}")
vasp_process_id = os.fork()
if vasp_process_id == 0:
# Child process: close inherited file descriptors to avoid h5 handle leaks
for fd in range(3, 256):
try:
os.close(fd)
except OSError:
pass
print('\n Starting VASP now\n', flush=True)
os.execve(mpi_exe, arguments, env_vars)
print('\n VASP exec failed\n', flush=True)
os._exit(127)
return vasp_process_id
def _is_lock_file_present(self):
"""Check if vasp.lock exists (VASP is still running its SCF step)."""
res_bool = False
if mpi.is_master_node():
res_bool = os.path.isfile('./vasp.lock')
res_bool = mpi.bcast(res_bool)
return res_bool
def _is_vasp_running(self):
"""Check if the VASP child process is still alive."""
if self._vasp_process_id is None:
return False
try:
os.kill(self._vasp_process_id, 0)
return True
except OSError:
return False
def _wait_for_lock_to_appear(self, timeout=600):
"""Wait for VASP to create the lock file (indicating it has started)."""
start = time.time()
while not self._is_lock_file_present():
if time.time() - start > timeout:
raise DFTWorkflowError(f"VASP did not create vasp.lock within {timeout}s")
time.sleep(1)
mpi.barrier(poll_msec=100)
def _wait_for_vasp(self, timeout=3600):
"""Wait for VASP to finish (lock file disappears)."""
start = time.time()
while self._is_lock_file_present():
if time.time() - start > timeout:
raise DFTWorkflowError(f"VASP did not finish within {timeout}s")
if mpi.is_master_node() and not self._is_vasp_running():
raise DFTWorkflowError("VASP process has died unexpectedly")
time.sleep(1)
mpi.barrier(poll_msec=100)
def _run_plo_converter(self):
"""Run the PLO converter to generate H(k) and projectors in the HDF5 file."""
if mpi.is_master_node():
if not os.path.exists(self.plo_cfg):
raise FileNotFoundError(f"PLO config file not found: {self.plo_cfg}")
self.mpi_handler.report(f"[{datetime.now()}] Running PLO converter")
generate_and_output_as_text(self.plo_cfg, vasp_dir='./')
self.mpi_handler.report(f"[{datetime.now()}] Running VASP HDF5 converter")
Converter(filename=self.seedname).convert_dft_input()
mpi.barrier(poll_msec=100)
[docs]
def run_initial_stage(self, **kwargs) -> int:
"""
Start VASP, run initial SCF, and convert output to HDF5.
This forks VASP as a persistent background process, waits for the
initial SCF to complete, then runs the PLO converter.
"""
# Remove stale control files left over from a previous (crashed) run.
# Only done here, on the initial call: a leftover vasp.lock would make
# the driver think VASP is already running, and a leftover STOPCAR would
# abort VASP immediately.
if mpi.is_master_node():
for stale in ('STOPCAR', 'vasp.lock'):
if os.path.isfile(stale):
os.remove(stale)
mpi.barrier(poll_msec=100)
# Fork VASP on master node
vasp_process_id = 0
if mpi.is_master_node():
vasp_process_id = self._fork_and_start_vasp()
vasp_process_id = mpi.bcast(vasp_process_id)
self._vasp_process_id = vasp_process_id
# Wait for VASP to start (lock appears) then finish SCF (lock disappears)
self._wait_for_lock_to_appear()
self._wait_for_vasp()
# Convert VASP output to HDF5
self._run_plo_converter()
return 0
[docs]
def run_update_stage(self, N_k, Eint_m_dc, **kwargs):
"""
Run a single VASP charge-update SCF step and reconvert.
Steps:
1. Read DFT energy and compute band energy correction
2. Write the charge density correction to vaspgamma.h5
3. Create vasp.lock to trigger the VASP charge update
4. Wait for VASP to finish
5. Re-run the PLO converter with the updated projectors
Running several VASP steps per DMFT iteration (with the self-energy held
fixed) is left to the caller: invoke this once per step and recompute the
charge density correction from the updated one-body elements in between.
"""
dft_energy = self.read_dft_energy()
band_energy_correction = self.band_energy_and_write_charge_update(N_k)
mpi.report(f"DFT + DMFT Total Energy: {dft_energy + band_energy_correction + Eint_m_dc}")
# Trigger VASP to resume with updated charge
if mpi.is_master_node():
open('./vasp.lock', 'a').close()
mpi.barrier(poll_msec=100)
# Wait for VASP to finish charge update
self._wait_for_vasp()
# Re-convert with updated projectors
self._run_plo_converter()
return 0
[docs]
def band_energy_and_write_charge_update(self, N_k):
"""
Compute band energy correction and write charge density update to HDF5.
Parameters
----------
N_k : np.ndarray
DMFT density matrix in band basis, shape (n_k, n_spin, n_max_bands, n_max_bands).
Returns
-------
float
Band energy correction in eV.
"""
n_k, n_spin, n_max_bands, _ = N_k.shape
fermi_weights, n_bands_per_k, Hk, bz_weights = None, None, None, None
spin_to_data_index, SO = [0, 0], False
if mpi.is_master_node():
with HDFArchive(f"{self.seedname}.h5", "r") as ar:
fermi_weights = ar['dft_misc_input']['dft_fermi_weights']
n_bands_per_k = ar['dft_input']['n_orbitals']
Hk = ar['dft_input']['hopping']
bz_weights = ar['dft_input']['bz_weights']
SO = bool(ar['dft_input']['SO'])
# spin-orbit: one spinor channel; collinear spin-pol: up/down; else: index 0 twice
spin_to_data_index = [0] if SO else ([0, 1] if ar['dft_input']['SP'] else [0, 0])
fermi_weights = mpi.bcast(fermi_weights)
n_bands_per_k = mpi.bcast(n_bands_per_k)
Hk = mpi.bcast(Hk)
bz_weights = mpi.bcast(bz_weights)
spin_to_data_index = mpi.bcast(spin_to_data_index)
SO = mpi.bcast(SO)
mpi.barrier(poll_msec=100)
density_matrix_dft = [[fermi_weights[ik, idx, :].astype(complex) for ik in range(n_k)] for idx in spin_to_data_index]
band_energy_correction = 0.0
for ik in range(n_k):
for spin, idx in enumerate(spin_to_data_index):
nb = n_bands_per_k[ik, idx]
diag_indices = np.diag_indices(nb)
N_k[ik, spin][diag_indices] -= density_matrix_dft[spin][ik][:nb]
band_energy_correction += np.dot(N_k[ik][spin][:nb, :nb], Hk[ik, idx, :nb, :nb]).trace().real * bz_weights[ik]
mpi.report(f"\nBand Energy Correction= {band_energy_correction} eV\n")
# Build the per-k correction block: a single spinor channel for spin-orbit,
# otherwise the spin-average of the two DMFT spin channels (non-magnetic / SP=0).
delta_Nk = np.zeros((n_k, n_max_bands, n_max_bands), dtype=complex)
for ik in range(n_k):
nb0 = n_bands_per_k[ik, 0]
if SO:
delta_Nk[ik, :nb0, :nb0] = N_k[ik, 0, :nb0, :nb0]
else:
delta_Nk[ik, :nb0, :nb0] = 0.5 * (N_k[ik, 0, :nb0, :nb0] + N_k[ik, 1, :nb0, :nb0])
if mpi.is_master_node():
# Write the charge density correction in the format VASP's ADD_GAMMA_FROM_FILE
# reads (vaspgamma.h5): a top-level band_window plus a deltaN group with one
# (nb x nb) block per IBZ k-point, nb = number of bands in that k's window.
# Mirrors triqs_dft_tools SumkDFT.calc_density_correction(dm_type='vasp'):
# - spin-orbit (SO): a single spinor-band channel -> deltaN/ud
# - otherwise (non-magnetic / spin-averaged): up == down -> deltaN/up, deltaN/down
with HDFArchive(f"{self.seedname}.h5", "r") as ar:
band_window = ar['dft_misc_input']['band_window']
n_k_ibz = ar['dft_misc_input']['n_k_ibz']
deltaN_ibz = [delta_Nk[ik, :n_bands_per_k[ik, 0], :n_bands_per_k[ik, 0]]
for ik in range(n_k_ibz)]
with HDFArchive('vaspgamma.h5', 'w') as vasp_h5:
vasp_h5['band_window'] = [band_window[0][:n_k_ibz, :]]
vasp_h5.create_group('deltaN')
if SO:
vasp_h5['deltaN']['ud'] = deltaN_ibz
else:
vasp_h5['deltaN']['up'] = deltaN_ibz
vasp_h5['deltaN']['down'] = deltaN_ibz
# also keep the raw (full-grid) correction in the seedname h5 for the record
subgrp = 'dft_update'
with HDFArchive(f"{self.seedname}.h5", "a") as ar:
if subgrp not in ar: ar.create_group(subgrp)
ar[subgrp]["delta_N"] = delta_Nk
return band_energy_correction
[docs]
def read_dft_energy(self):
"""
Read DFT total energy from VASP output.
Tries vaspout.h5 first, falls back to OSZICAR.
VASP energies are already in eV.
Returns
-------
float
DFT total energy in eV.
"""
dft_energy = 0.0
if mpi.is_master_node():
if os.path.isfile('vaspout.h5'):
with HDFArchive('vaspout.h5', 'r') as h5:
if 'oszicar' in h5['intermediate/ion_dynamics']:
dft_energy = float(np.asarray(h5['intermediate/ion_dynamics/oszicar'][-1, 1]).item())
dft_energy = mpi.bcast(dft_energy)
return dft_energy
# Fallback: read from OSZICAR
last_nonempty_line = None
with open('OSZICAR', 'r') as f:
for line in f:
if line.strip():
last_nonempty_line = line
if last_nonempty_line is None:
raise DFTWorkflowError('OSZICAR is empty (cannot read DFT energy)')
parts = last_nonempty_line.split()
try:
dft_energy = float(parts[2])
except (IndexError, ValueError) as err:
raise DFTWorkflowError(f'Failed to parse DFT energy from OSZICAR: {last_nonempty_line!r}') from err
dft_energy = mpi.bcast(dft_energy)
return dft_energy
[docs]
def kill(self):
"""Terminate the VASP process cleanly."""
if self._vasp_process_id is not None and mpi.is_master_node():
self.mpi_handler.report(f"[{datetime.now()}] Stopping VASP (PID {self._vasp_process_id})")
with open('STOPCAR', 'wt') as f:
f.write('LABORT = .TRUE.\n')
try:
os.kill(self._vasp_process_id, signal.SIGTERM)
except OSError:
pass
self._vasp_process_id = None