# Source code for triqs_modest.utils.checkpoint

"""
Checkpointer — an MPI-aware checkpoint manager for DMFT calculations.

Wraps CheckpointBase (the c2py-exposed C++ class) and adds:
- Smart create-or-open constructor, safe on all MPI ranks
- Optional initial data stored via HDFArchive (master only); it is written
  only while the checkpoint holds no iterations yet
- Keyword-extras written alongside required iteration data (master only)
- restart() broadcasts last iteration to all ranks (collective)
- Full Python iteration, indexing, and len() support. len() and iteration
  are collective (all ranks must call them together); indexing returns data
  on the master rank only and None elsewhere.
"""

import os
import triqs.utility.mpi as mpi
from h5 import HDFArchive
from .file_io import CheckpointBase, IterationData

# Public names: the checkpoint manager plus IterationData, re-exported from
# .file_io so callers can build iteration records (see the Checkpointer
# example) without importing the private module themselves.
__all__ = ["Checkpointer", "IterationData"]


class Checkpointer:
    """MPI-aware checkpoint manager for DMFT calculations.

    Enforces saving the minimal restart set (mu, Sigma_imp_list,
    Sigma_hartree_list) per iteration and optionally stores additional
    quantities (Green's functions, density matrices, etc.) in the same HDF5
    group via keyword arguments.

    All methods may be called on every MPI rank — no ``is_master_node()``
    guards are required in user code. Only the master process performs HDF5
    file I/O.

    ``restart()``, ``len(chkpt)`` and iterating over the checkpoint are
    *collective*: master reads from disk and broadcasts, so every rank must
    make the call together. Calling them from master alone will block.
    Because the class defines no ``__bool__``, truthiness checks such as
    ``if chkpt:`` go through ``__len__`` and are collective as well.

    Parameters
    ----------
    dirname : str
        Path to the checkpoint directory.
    initial_data : dict, optional
        Data to store in ``initial_data.h5`` (e.g.
        ``{"obe": obe, "embedding": E}``). Written only when the checkpoint
        holds no iterations yet; ignored when reopening a checkpoint that
        already contains iterations.

    Examples
    --------
    Create or resume, then run the DMFT loop (all ranks):

    .. code-block:: python

        chkpt = Checkpointer("my_calc", initial_data={"obe": obe, "embedding": E})
        last = chkpt.restart()  # collective — data on all ranks
        if last is not None:
            Sigma_dyn = last.Sigma_imp_list[0]
            Sigma_static = last.Sigma_hartree_list[0]
        else:
            Sigma_dyn, Sigma_static = make_zero_self_energies(mesh)
        it_shift = len(chkpt)  # collective — same value on all ranks
        for n in range(it_shift, n_loops):
            ...
            chkpt.append(
                IterationData(mu=mu, Sigma_imp_list=[Sigma_dyn],
                              Sigma_hartree_list=[Sigma_static]),
                Delta_iw=Delta, Gimp_iw=Gimp, dm=dm,
            )
    """

    def __init__(self, dirname, initial_data=None) -> None:
        self._dirname = dirname
        # Only master owns the C++ handle; other ranks keep None and never
        # touch the file system.
        self._cpp = None
        if mpi.is_master_node():
            self._cpp = CheckpointBase(dirname)
            # Write initial data only while no iterations are stored, so that
            # reopening a calculation in progress never overwrites it.
            if initial_data is not None and len(self._cpp) == 0:
                path = os.path.join(dirname, "initial_data.h5")
                with HDFArchive(path, "w") as ar:
                    for key, value in initial_data.items():
                        ar[key] = value

    def append(self, iteration_data, **extras) -> None:
        """Append an iteration to the checkpoint.

        May be called on all MPI ranks; non-master ranks return immediately
        without waiting for master (this is not a collective operation).

        Parameters
        ----------
        iteration_data : IterationData
            Required restart data: mu, Sigma_imp_list, Sigma_hartree_list.
        **extras
            Additional named quantities written to the same HDF5 group via
            HDFArchive. Supports any TRIQS-serializable type (block_gf,
            matrix, scalar, etc.).
        """
        if not mpi.is_master_node():
            return
        # Index of the group the C++ layer is about to create. Extras are
        # written into the group named str(index) in iterations.h5.
        # NOTE(review): assumes CheckpointBase names groups by their 0-based
        # index; an extra whose key matches a required field (e.g. "mu")
        # may overwrite it — TODO confirm against CheckpointBase.
        current_iter = len(self._cpp)
        self._cpp.append(iteration_data)
        if not extras:
            return
        path = os.path.join(self._dirname, "iterations.h5")
        with HDFArchive(path, "a") as ar:
            for key, value in extras.items():
                ar[str(current_iter)][key] = value

    def restart(self) -> IterationData | None:
        """Return the last iteration's data on all MPI ranks, or None if empty.

        Collective operation — all ranks must call this together. Master
        reads from disk. The components (mu, Sigma_imp_list,
        Sigma_hartree_list) are broadcast individually, and every rank then
        builds its own IterationData from them.

        Returns
        -------
        IterationData or None
            Last stored iteration on all ranks, None if no iterations exist.
        """
        n_iter = len(self._cpp) if mpi.is_master_node() else None
        n_iter = mpi.bcast(n_iter)
        if n_iter == 0:
            return None

        mu = Sigma_imp_list = Sigma_hartree_list = None
        if mpi.is_master_node():
            # Read the last record once rather than once per field.
            last = self._cpp[-1]
            mu = last.mu
            Sigma_imp_list = last.Sigma_imp_list
            Sigma_hartree_list = last.Sigma_hartree_list

        # Broadcast plain components rather than the c2py-wrapped record, so
        # only mu and the self-energy lists need to be transferable over MPI.
        mu = mpi.bcast(mu)
        Sigma_imp_list = mpi.bcast(Sigma_imp_list)
        Sigma_hartree_list = mpi.bcast(Sigma_hartree_list)
        return IterationData(
            mu=mu,
            Sigma_imp_list=Sigma_imp_list,
            Sigma_hartree_list=Sigma_hartree_list,
        )

    @property
    def dirname(self) -> str:
        """Path to the checkpoint directory."""
        return self._dirname

    def __getitem__(self, i) -> IterationData | None:
        """Index access (master only).

        Returns the stored iteration at index ``i`` on master and None on
        non-master ranks. This is not collective; no data is broadcast.
        """
        if mpi.is_master_node():
            return self._cpp[i]
        return None

    def __len__(self) -> int:
        """Number of stored iterations, identical on all ranks.

        Collective operation: master reads the count and broadcasts it, so
        all ranks must call ``len()`` together.
        """
        n_iter = len(self._cpp) if mpi.is_master_node() else None
        return mpi.bcast(n_iter)

    def __iter__(self):
        """Yield the stored iterations in order.

        Collective: the count comes from ``len(self)``, so all ranks must
        iterate together. Non-master ranks receive None for every item.
        """
        for i in range(len(self)):
            yield self[i]

    def __repr__(self) -> str:
        """Debug representation; not collective, so it is safe on any rank.

        The iteration count is included only on master, the one rank that
        holds the C++ handle.
        """
        if mpi.is_master_node():
            return f"Checkpointer('{self._dirname}', n_iter={len(self._cpp)})"
        return f"Checkpointer('{self._dirname}')"