# Source code for amorphgen.pipeline.batch_quench

"""
amorphgen.pipeline.batch_quench
--------------------------------
Run a subset of pipeline stages independently on each of N input
structures, producing a library of amorphous candidates.

Typical use cases:

* **MQ snapshot quench** (default ``stages=[5, 6, 7]``): take N snapshots
  extracted from a Stage 4 high-T equilibration trajectory and quench each
  through stages 5 (cooling) -> 6 (low-T eq) -> 7 (final opt).
* **Hybrid workflow** (``stages=[4, 5, 6, 7]``): take N already-disordered
  inputs (e.g. ``--random-gen`` outputs), anneal at high T, then quench.

Stage numbers follow the canonical 7-stage pipeline:
4 = eq_high, 5 = quench, 6 = eq_low, 7 = final_opt.
"""

from __future__ import annotations

import os
import re
from copy import deepcopy

from ase.io import read, write


def _run_dir_name(snap_file: str, fallback_idx: int) -> str:
    """Derive the per-run directory name for *snap_file*.

    A basename that begins with ``snapshot``, an optional ``_`` or ``-``,
    and a run of digits (e.g. ``snapshot_0007_frame00184.extxyz``) yields
    ``run_`` plus that number (``run_0007``). Any other name yields ``run_``
    plus *fallback_idx* (normally the loop's enumerate index). Both forms
    are zero-padded to at least four digits.
    """
    stem, _ext = os.path.splitext(os.path.basename(snap_file))
    match = re.match(r"snapshot[_-]?(\d+)", stem)
    number = fallback_idx if match is None else int(match.group(1))
    return f"run_{number:04d}"

from ..utils import get_calculator, merge_config
from ..configs import DEFAULT_CONFIG
from . import quench, equilibrate, final_opt


[docs] def run(snapshot_files: list[str], n_runs: int | None = None, select: str = "uniform", cfg_override: dict | None = None, work_dir: str = "batch_quench", stages: list[int] | None = None, calc=None, resume: bool = False): """ Batch quench multiple snapshots. Parameters ---------- snapshot_files : list of str Paths to snapshot structure files. n_runs : int, optional Number of runs (defaults to len(snapshot_files)). select : str How to select snapshots: "uniform" or "last". cfg_override : dict, optional work_dir : str Base output directory. stages : list of int Which stages to run per snapshot. Stage numbers follow the canonical 7-stage pipeline: 4=eq_high, 5=quench, 6=eq_low, 7=final_opt. Default ``[5, 6, 7]`` (quench + eq_low + final_opt — the standard post-Stage-4 batch workflow). Include 4 when starting from random / already-disordered structures that need to be annealed first (the 'hybrid' workflow). calc : ASE calculator, optional resume : bool If True, skip runs whose final output already exists. 
""" if stages is None: stages = [5, 6, 7] global_cfg = merge_config(DEFAULT_CONFIG, cfg_override) os.makedirs(work_dir, exist_ok=True) if n_runs is None: n_runs = len(snapshot_files) # Select subset import numpy as np n_available = len(snapshot_files) if select == "uniform": indices = np.linspace(0, n_available - 1, min(n_runs, n_available), dtype=int) else: indices = list(range(max(0, n_available - n_runs), n_available)) selected = [snapshot_files[i] for i in indices] # Build calculator once if calc is None: device = global_cfg.get("device", "cuda") if device == "auto": import torch device = "cuda" if torch.cuda.is_available() else "cpu" calc = get_calculator( model=global_cfg.get("model", "mace-mpa-0"), device=device, model_path=global_cfg.get("model_path"), ) bar = "=" * 65 print(f"\n{bar}") print(f" Batch quench: {len(selected)} runs, stages {stages}") print(f" Output: {work_dir}/") print(f"{bar}\n") results = [] for i, snap_file in enumerate(selected): run_name = _run_dir_name(snap_file, fallback_idx=i) run_dir = os.path.join(work_dir, run_name) final_output = os.path.join(run_dir, "final_amorphous.xyz") legacy_final = os.path.join(run_dir, "final_amorphous.extxyz") if resume and (os.path.isfile(final_output) or os.path.isfile(legacy_final)): existing = final_output if os.path.isfile(final_output) else legacy_final print(f" [{run_name}] Already complete -- skipping.") results.append(read(existing)) continue os.makedirs(run_dir, exist_ok=True) print(f"\n {'-' * 60}") print(f" Run {i+1:04d} / {len(selected)} <- {os.path.basename(snap_file)} -> {run_name}/") print(f" {'-' * 60}") atoms = read(snap_file) atoms.calc = calc orig_dir = os.getcwd() os.chdir(run_dir) try: for s in stages: if s == 4: atoms = equilibrate.run(atoms, cfg_override=cfg_override, calc=calc, stage="high") elif s == 5: atoms = quench.run(atoms, cfg_override=cfg_override, calc=calc) elif s == 6: atoms = equilibrate.run(atoms, cfg_override=cfg_override, calc=calc, stage="low") elif s == 7: atoms 
= final_opt.run(atoms, cfg_override=cfg_override, calc=calc) else: raise ValueError( f"batch_quench: unknown stage {s}. " f"Allowed: 4 (eq_high), 5 (quench), 6 (eq_low), 7 (final_opt)." ) finally: os.chdir(orig_dir) write(final_output, atoms, format="extxyz") results.append(atoms) from ..utils.common import compute_density_gcm3 d = compute_density_gcm3(atoms) print(f" [{run_name}] Done -> {final_output} density={d:.2f} g/cm3") print(f"\n{bar}") print(f" Batch complete: {len(results)} structures generated") print(f"{bar}\n") return results