# Source code for amorphgen.utils.calculators

"""
amorphgen.utils.calculators
----------------------------
Calculator factory for universal machine learning interatomic
potentials (MLIPs).

All backends return a standard ASE calculator. The user picks a
model by short name
(``--model mace-mpa-0``, ``--model chgnet``, …)
and the factory handles the import, initialisation, and device
placement transparently.

Supported backends
~~~~~~~~~~~~~~~~~~
* **MACE**      — ``mace-mp-*``, ``mace-mpa-*``, ``mace-omat-*``,
                   ``mace-mh-*``, ``mace-matpes-*``, ``mace-omol``
* **CHGNet**    — ``chgnet``  (latest pretrained CHGNet)
* **SevenNet**  — ``sevennet``, ``7net-mf-ompa``, ``7net-l3i5``, ...
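* **Classical** — ``lennard-jones`` / ``lj``, ``buckingham`` / ``buck``
                   (parameters supplied via ``classical_params``)
* **Custom**    — any local ``.model`` file via ``--model-path``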
* **External**  — pass your own ASE calculator object directly
"""

from __future__ import annotations

import os
import warnings
from typing import Any


# ═════════════════════════════════════════════════════════════════════════════
# MACE model registry
# Source: https://github.com/ACEsuit/mace-foundations
# ═════════════════════════════════════════════════════════════════════════════

# Registry of MACE foundation models: user-facing short name → load target.
#
# Keys are all lower-case; backend routing compares against ``model.lower()``.
# Each value is one of two things:
#   * a plain checkpoint name (e.g. "medium-mpa-0"), which the loader passes to
#     ``mace.calculators.mace_mp(model=...)``, or
#   * an ``https://`` URL, which the loader passes to
#     ``MACECalculator(model_paths=...)``.
# Several short names may resolve to the same target (aliases).
# NOTE(review): which plain names ``mace_mp`` accepts depends on the installed
# mace-torch version — TODO confirm against its foundation-model table.
MACE_FOUNDATION_MODELS: dict[str, str] = {
    # ── MACE-MP-0a (initial release, MPTrj, PBE+U) ──────────────────────────
    "mace-mp-0a-small":   "small",
    "mace-mp-0a-medium":  "medium",
    "mace-mp-0a-large":   "large",
    # ── MACE-MP-0b (improved pair repulsion) ─────────────────────────────────
    "mace-mp-0b-small":   "small-0b",
    "mace-mp-0b-medium":  "medium-0b",
    # NOTE(review): confirm a "large-0b" checkpoint is published upstream.
    "mace-mp-0b-large":   "large-0b",
    # ── MACE-MP-0b2 (high-pressure stability) ────────────────────────────────
    "mace-mp-0b2-small":  "small-0b2",
    "mace-mp-0b2-medium": "medium-0b2",
    "mace-mp-0b2-large":  "large-0b2",
    # ── MACE-MP-0b3 (fixed phonons vs 0b2) ───────────────────────────────────
    # NOTE(review): confirm "small-0b3" / "large-0b3" exist upstream; only
    # some sizes may have been released for this revision.
    "mace-mp-0b3-small":  "small-0b3",
    "mace-mp-0b3-medium": "medium-0b3",
    "mace-mp-0b3-large":  "large-0b3",
    # ── MACE-MPA-0 (MPTrj + sAlex — recommended default) ────────────────────
    # "mace-mpa-0" is an alias of the medium model (same target).
    "mace-mpa-0":         "medium-mpa-0",
    "mace-mpa-0-medium":  "medium-mpa-0",
    # ── MACE-OMAT-0 (Open Materials — excellent phonons, ASL license) ────────
    # Loaded by URL; "mace-omat-0" is an alias of the medium file.
    "mace-omat-0-small":  "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-small.model",
    "mace-omat-0-medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
    "mace-omat-0":        "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
    # ── MACE-MATPES (PBE / r2SCAN, ASL license) ─────────────────────────────
    "mace-matpes-pbe":    "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
    "mace-matpes-r2scan": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
    # ── MACE-MH (multi-domain: bulk + surface + molecule) ────────────────────
    # Both files are fetched from the "mace_mh_1" release tag.
    "mace-mh-0":          "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mh_1/mace-mh-0.model",
    "mace-mh-1":          "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mh_1/mace-mh-1.model",
    # ── MACE-OMOL (molecules) ────────────────────────────────────────────────
    "mace-omol":          "https://github.com/ACEsuit/mace-foundations/releases/download/mace_omol_0/mace-omol-0-medium.model",
}

# ── CHGNet identifiers ───────────────────────────────────────────────────────
# A single pretrained CHGNet checkpoint is exposed, under one name.
CHGNET_MODELS: set[str] = {"chgnet"}

# ── SevenNet identifiers ────────────────────────────────────────────────────
# Short name → checkpoint name handed to ``SevenNetCalculator(model=...)``.
# Two convenience aliases point at the recommended multi-fidelity checkpoint;
# every remaining entry maps a checkpoint name to itself so it can be given
# verbatim.
SEVENNET_MODELS: dict[str, str] = {
    "sevennet": "7net-mf-ompa",      # alias → recommended default
    "sevennet-mf": "7net-mf-ompa",
    **{
        checkpoint: checkpoint
        for checkpoint in (
            "7net-mf-ompa",
            "7net-mf-0",
            "7net-omat",
            "7net-l3i5",
            "7net-0",
            "7net-omni",
        )
    },
}

# ── Classical potential identifiers ──────────────────────────────────────────
# Long names plus short aliases for the built-in pair potentials.
CLASSICAL_MODELS: set[str] = {
    *("lennard-jones", "lj"),       # Lennard-Jones
    *("buckingham", "buck"),        # Buckingham (+ optional Coulomb)
}


# ═════════════════════════════════════════════════════════════════════════════
# Human-readable model descriptions (for --list-models)
# ═════════════════════════════════════════════════════════════════════════════

# Short name → one-line description printed by ``list_models()``.
# Insertion order is the display order within each backend section.  Only
# names that pass ``list_models``' section filters are shown, so aliases such
# as "lj", "buck" and "sevennet-mf" are deliberately not listed here.  The
# MACE aliases "mace-mpa-0" (the ``get_calculator`` default) and "mace-omat-0"
# are listed so users can discover the names they are most likely to type.
MODEL_DESCRIPTIONS: dict[str, str] = {
    # MACE
    "mace-mp-0a-small":   "MACE-MP-0a  small   | MPTrj | DFT PBE+U | initial release",
    "mace-mp-0a-medium":  "MACE-MP-0a  medium  | MPTrj | DFT PBE+U | initial release",
    "mace-mp-0a-large":   "MACE-MP-0a  large   | MPTrj | DFT PBE+U | initial release",
    "mace-mp-0b-small":   "MACE-MP-0b  small   | MPTrj | improved pair repulsion",
    "mace-mp-0b-medium":  "MACE-MP-0b  medium  | MPTrj | improved pair repulsion",
    "mace-mp-0b-large":   "MACE-MP-0b  large   | MPTrj | improved pair repulsion",
    "mace-mp-0b2-small":  "MACE-MP-0b2 small   | MPTrj | improved high-pressure stability",
    "mace-mp-0b2-medium": "MACE-MP-0b2 medium  | MPTrj | improved high-pressure stability",
    "mace-mp-0b2-large":  "MACE-MP-0b2 large   | MPTrj | improved high-pressure stability",
    "mace-mp-0b3-small":  "MACE-MP-0b3 small   | MPTrj | fixed phonons vs 0b2",
    "mace-mp-0b3-medium": "MACE-MP-0b3 medium  | MPTrj | fixed phonons vs 0b2",
    "mace-mp-0b3-large":  "MACE-MP-0b3 large   | MPTrj | fixed phonons vs 0b2",
    "mace-mpa-0":         "MACE-MPA-0  medium  | MPTrj+sAlex | alias of mace-mpa-0-medium",
    "mace-mpa-0-medium":  "MACE-MPA-0  medium  | MPTrj+sAlex | ★ recommended default",
    "mace-omat-0-small":  "MACE-OMAT-0 small   | OMAT | excellent phonons | ASL license",
    "mace-omat-0-medium": "MACE-OMAT-0 medium  | OMAT | excellent phonons | ASL license",
    "mace-omat-0":        "MACE-OMAT-0 medium  | OMAT | alias of mace-omat-0-medium | ASL license",
    "mace-matpes-pbe":    "MACE-MATPES-PBE     | MATPES-PBE | DFT PBE, no +U | ASL",
    "mace-matpes-r2scan": "MACE-MATPES-r2SCAN  | MATPES-r2SCAN | better functional | ASL",
    "mace-mh-0":          "MACE-MH-0           | multi-domain bulk/surface/molecule",
    "mace-mh-1":          "MACE-MH-1           | multi-domain | ★ best cross-domain",
    "mace-omol":          "MACE-OMOL-0         | OMOL | optimised for molecules",
    # CHGNet
    "chgnet":             "CHGNet              | MPTrj | charge-informed | magnetic moments",
    # SevenNet
    "sevennet":           "SevenNet-MF-OMPA    | OMat+MPtrj+Alexandria | ★ recommended SevenNet",
    "7net-mf-ompa":       "SevenNet-MF-OMPA    | OMat+MPtrj+Alexandria | foundation",
    "7net-mf-0":          "SevenNet-MF-0       | multi-fidelity baseline",
    "7net-omat":          "SevenNet-OMat       | OMat-only training",
    "7net-l3i5":          "SevenNet-l3i5       | improved equivariant features",
    "7net-0":             "SevenNet-0          | original release (Jul 2024)",
    "7net-omni":          "SevenNet-omni       | omni model (multi-task)",
    # Classical
    "lennard-jones":      "Lennard-Jones       | pair potential | no GPU needed",
    "buckingham":         "Buckingham+Coulomb  | rigid-ion | Wolf summation | no GPU needed",
}


# ═════════════════════════════════════════════════════════════════════════════
# Backend loaders (lazy imports — each backend is only imported when needed)
# ═════════════════════════════════════════════════════════════════════════════

def _load_mace(model: str, device: str, model_path: str | None = None,
               **kwargs) -> Any:
    """Load a MACE calculator."""
    try:
        from mace.calculators import mace_mp, MACECalculator
    except ImportError:
        raise ImportError(
            "MACE is not installed. Install it with:\n"
            "  pip install amorphgen[mace]\n"
            "Or: pip install mace-torch\n"
            "See: https://github.com/ACEsuit/mace"
        )

    # Custom / local model file
    if model_path is not None:
        if not os.path.isfile(model_path):
            raise FileNotFoundError(
                f"Custom MACE model file not found: {model_path}\n"
                "Please provide a valid path to a .model file."
            )
        print(f"[MACE] Loading custom model: {model_path}")
        return MACECalculator(model_paths=model_path, device=device, **kwargs)

    # Resolve short-name → internal string / URL
    resolved = MACE_FOUNDATION_MODELS.get(model, model)

    if resolved.startswith("https://") or os.path.isfile(resolved):
        print(f"[MACE] Loading from URL/path: {resolved[:80]}")
        return MACECalculator(model_paths=resolved, device=device, **kwargs)
    else:
        print(f"[MACE] Loading foundation model '{model}' → mace_mp(model='{resolved}')")
        return mace_mp(model=resolved, device=device, **kwargs)

def _load_chgnet(device: str, default_dtype: str | None = None,
                 **kwargs) -> Any:
    """Load the pretrained CHGNet calculator.

    CHGNet's published MD benchmarks all use float32 and its checkpoint is
    stored at float32 natively.  Pre-2026-05-11 this loader accepted a
    ``default_dtype`` kwarg, silently forwarded it to ``CHGNetCalculator``,
    and the latter ignored it — so YAML configs with ``default_dtype: float64``
    were a no-op.  This loader makes the YAML field meaningful by setting
    torch's default dtype before importing chgnet (so its module-level
    ``TORCH_DTYPE`` constant captures the right precision) and overwriting
    ``chgnet.model.model.TORCH_DTYPE`` explicitly in case chgnet was already
    imported by an earlier call.

    **float64 is currently not supported** — CHGNet's ``composition_model``
    submodule builds its input feature vectors via a code path that bypasses
    ``TORCH_DTYPE``, so even after casting the main model to float64 the
    forward pass crashes with a dtype mismatch.  Use MACE (e.g. ``mace-mpa-0``)
    if you need a float64 backend.

    Parameters
    ----------
    device : str
        ``"cpu"``, ``"cuda"``, or ``"mps"``.
    default_dtype : {"float32", None}, optional
        ``None`` (default) and ``"float32"`` both give native float32.
        ``"float64"`` raises ``NotImplementedError``.

    Raises
    ------
    NotImplementedError
        If ``default_dtype="float64"`` — CHGNet's composition_model
        sub-module cannot be cleanly cast to float64.  Use MACE if you
        need float64 precision.
    ValueError
        If ``default_dtype`` is some other unrecognised string.
    """
    import torch

    if default_dtype is None:
        default_dtype = "float32"
    if default_dtype == "float64":
        raise NotImplementedError(
            "CHGNet does not support default_dtype='float64': its "
            "composition_model submodule constructs input tensors at "
            "float32 regardless of torch.get_default_dtype(), so the "
            "forward pass crashes with a dtype mismatch.  CHGNet is "
            "trained and benchmarked at float32 — keep default_dtype "
            "as 'float32' (or omit it) for CHGNet, or switch to MACE "
            "(model='mace-mpa-0', default_dtype='float64') if you need "
            "float64 precision."
        )
    if default_dtype != "float32":
        raise ValueError(
            f"default_dtype must be 'float32' or None for CHGNet; "
            f"got {default_dtype!r}"
        )

    # Set torch's default dtype *before* importing chgnet so the module-level
    # TORCH_DTYPE constant is captured at float32.  This is the standard
    # state for PyTorch anyway; we set it explicitly to override any earlier
    # caller that set float64 (e.g. a MACE loader earlier in the process).
    torch.set_default_dtype(torch.float32)

    try:
        from chgnet.model.model import CHGNet
        from chgnet.model.dynamics import CHGNetCalculator
        import chgnet.model.model as _chgnet_model_mod
    except ImportError:
        raise ImportError(
            "CHGNet is not installed. Install it with:\n"
            "  pip install chgnet\n"
            "See: https://chgnet.lbl.gov/"
        )

    # If chgnet was already imported, its TORCH_DTYPE is frozen — overwrite.
    if getattr(_chgnet_model_mod, "TORCH_DTYPE", None) is not torch.float32:
        _chgnet_model_mod.TORCH_DTYPE = torch.float32

    print(f"[CHGNet] Loading pretrained model on {device} (dtype=float32)")
    # Load on CPU first to avoid MPS float64 crash, then cast and move.
    model = CHGNet.load(use_device="cpu")

    if device == "mps":
        model = model.float()   # MPS requires float32
        model = model.to("mps")
        print("[CHGNet] Moved to MPS (float32)")

    # CHGNetCalculator silently ignores unknown kwargs; strip default_dtype
    # so it doesn't sit in **kwargs and confuse future signature checks.
    kwargs.pop("default_dtype", None)
    return CHGNetCalculator(model=model, use_device=device, **kwargs)




def _load_sevennet(model: str, device: str, **kwargs) -> Any:
    """Load a SevenNet calculator (KAIST equivariant MLIP).

    Multi-fidelity models (``7net-mf-*``) require a ``modal`` argument
    selecting the DFT functional/dataset; we default to ``'mpa'``
    (MPtrj+Alexandria, PBE) which is the closest match to most PBE-trained
    benchmarks. Override with ``modal='omat24'`` for OMat-style PBE+U.

    Parameters
    ----------
    model : str
        A short name from ``SEVENNET_MODELS`` (matched case-insensitively,
        mirroring the lower-cased routing in ``_detect_backend``), or any
        other checkpoint name, which is passed through unchanged.
    device : str
        Device string forwarded to ``SevenNetCalculator``.
    **kwargs
        Forwarded to ``SevenNetCalculator``.  ``modal`` is filled in with
        ``'mpa'`` only when the checkpoint looks multi-fidelity and the
        caller did not supply one.

    Raises
    ------
    ImportError
        If the ``sevenn`` package is not installed.
    """
    try:
        from sevenn.calculator import SevenNetCalculator
    except ImportError as err:
        raise ImportError(
            "SevenNet is not installed. Install it with:\n"
            "  pip install sevenn\n"
            "See: https://github.com/MDIL-SNU/SevenNet"
        ) from err

    # Registry keys are lower-case; fall back to the original string so
    # unknown names / paths reach SevenNetCalculator exactly as given.
    checkpoint = SEVENNET_MODELS.get(model.lower(), model)
    # Multi-fidelity (mf) models need a modal selection.  This is a plain
    # substring test on the checkpoint name; pass ``modal`` explicitly to
    # override the default.
    if "mf" in checkpoint and "modal" not in kwargs:
        kwargs["modal"] = "mpa"
        print(f"[SevenNet] '{checkpoint}' is multi-fidelity; defaulting "
              f"modal='mpa' (MPtrj+Alexandria, PBE)")
    print(f"[SevenNet] Loading pretrained '{checkpoint}' on {device}")
    return SevenNetCalculator(model=checkpoint, device=device, **kwargs)


def _load_classical(model: str, device: str = "cpu", **kwargs) -> Any:
    """
    Load a classical pair potential calculator.

    Parameters are passed via ``classical_params`` in kwargs (from YAML
    config or Python API).  The caller's ``classical_params`` mapping is
    only read, never modified.

    Parameters
    ----------
    model : str
        "lennard-jones" / "lj" or "buckingham" / "buck" (case-insensitive).
    device : str
        Forwarded to the calculator constructor.
    **kwargs
        Must contain ``classical_params`` dict with:

        For LJ::

            {"params": {("Ar","Ar"): {"epsilon": 0.0104, "sigma": 3.40}},
             "cutoff": 10.0}

        For Buckingham::

            {"params": {("Si","O"): {"A": 18003.76, "rho": 0.2052, "C": 133.54}},
             "charges": {"Si": 2.4, "O": -1.2},
             "cutoff": 10.0}

        Optional keys: ``"cutoff"`` (default ``10.0``) and, for Buckingham,
        ``"charges"`` (default ``{}``), ``"coulomb"`` (default ``True``) and
        ``"alpha"`` (default ``0.2``).  String pair keys such as ``"Si-O"``
        (as produced by YAML) are converted to ``("Si", "O")``, with
        surrounding whitespace stripped from each symbol.  Any other kwargs
        are ignored.

    Raises
    ------
    ValueError
        If ``classical_params`` is missing, or *model* is not one of the
        recognised classical potential names.
    """
    cp = kwargs.pop("classical_params", None)
    if cp is None:
        raise ValueError(
            f"Model '{model}' requires 'classical_params' with potential "
            f"parameters. Provide via YAML config or Python API.\n"
            f"See: amorphgen.utils.classical for parameter format."
        )

    lower = model.lower()
    is_lj = lower in ("lennard-jones", "lj")
    # Reject unknown names explicitly instead of silently treating them as
    # Buckingham.
    if not is_lj and lower not in ("buckingham", "buck"):
        raise ValueError(
            f"Unknown classical potential '{model}'; expected "
            f"'lennard-jones'/'lj' or 'buckingham'/'buck'."
        )

    # Convert string-keyed pair dicts to tuple keys
    # YAML gives {"Si-O": {...}} but calculators expect {("Si","O"): {...}}
    # Build a fresh dict: writing it back into ``cp`` would mutate the
    # caller's configuration.
    converted = {}
    for key, val in cp.get("params", {}).items():
        if isinstance(key, str) and "-" in key:
            s1, s2 = key.split("-", 1)
            converted[(s1.strip(), s2.strip())] = val
        else:
            converted[key] = val

    # Deferred until validation has passed.
    from .classical import LennardJonesCalculator, BuckinghamCalculator

    cutoff = cp.get("cutoff", 10.0)
    dev_str = f", device={device}" if device != "cpu" else ""
    if is_lj:
        print(f"[Classical] Lennard-Jones, {len(converted)} pair(s), "
              f"cutoff={cutoff} A{dev_str}")
        return LennardJonesCalculator(
            params=converted,
            cutoff=cutoff,
            device=device,
        )

    charges = cp.get("charges", {})
    coulomb = cp.get("coulomb", True)
    alpha = cp.get("alpha", 0.2)
    print(f"[Classical] Buckingham+Coulomb, {len(converted)} pair(s), "
          f"cutoff={cutoff} A, "
          f"coulomb={'on' if coulomb else 'off'}{dev_str}")
    return BuckinghamCalculator(
        params=converted,
        charges=charges,
        cutoff=cutoff,
        alpha=alpha,
        coulomb=coulomb,
        device=device,
    )


# ═════════════════════════════════════════════════════════════════════════════
# Backend detection
# ═════════════════════════════════════════════════════════════════════════════

def _detect_backend(model: str) -> str:
    """Map a model name to the backend that should load it.

    Matching is case-insensitive.  Precedence, highest first:

    1. classical potential names (``CLASSICAL_MODELS``);
    2. names in the MACE registry;
    3. any other ``mace-`` prefixed name (accepted, with a warning);
    4. CHGNet names;
    5. SevenNet registry names, or anything starting with ``7net`` /
       ``sevennet``.

    Returns
    -------
    str
        One of ``"classical"``, ``"mace"``, ``"chgnet"``, ``"sevennet"``.

    Raises
    ------
    ValueError
        If no backend claims the name.
    """
    name = model.lower()

    if name in CLASSICAL_MODELS:
        return "classical"

    known_mace = name in MACE_FOUNDATION_MODELS
    if known_mace or name.startswith("mace-"):
        if not known_mace:
            # Unregistered mace-* names are passed on anyway, but flagged.
            warnings.warn(
                f"Model '{model}' not in MACE registry. "
                f"Will attempt to load it. Use --list-models to see known models.",
                stacklevel=3,
            )
        return "mace"

    if name in CHGNET_MODELS:
        return "chgnet"

    looks_like_sevennet = name.startswith(("7net", "sevennet"))
    if looks_like_sevennet or name in SEVENNET_MODELS:
        return "sevennet"

    raise ValueError(
        f"Unrecognised model '{model}'. Use --list-models to see available "
        f"options, or pass --model-path for a custom model file."
    )

# ═════════════════════════════════════════════════════════════════════════════
# Public API
# ═════════════════════════════════════════════════════════════════════════════

def get_calculator(
    model: str = "mace-mpa-0",
    device: str = "auto",
    model_path: str | None = None,
    **kwargs,
) -> Any:
    """
    Build and return an ASE calculator for the given foundation model.

    This is the **unified entry point** for all supported MLFF backends.
    The returned object is always a standard ASE calculator that can be
    attached to any ``ase.Atoms`` object.

    Parameters
    ----------
    model : str
        Short name identifying the model. Examples:

        * MACE: ``"mace-mpa-0"``, ``"mace-mh-1"``, ``"mace-omat-0"``
        * CHGNet: ``"chgnet"``
        * SevenNet: ``"sevennet"``, ``"7net-mf-ompa"``, ``"7net-l3i5"``
        * Classical: ``"lennard-jones"``, ``"buckingham"``

        Use :func:`list_models` or ``--list-models`` to see all options.
        Ignored if *model_path* is provided (defaults to MACE backend).
    device : str
        ``"auto"`` (default), ``"cuda"``, ``"mps"`` or ``"cpu"``.
        ``"auto"`` picks CUDA if available, else Apple MPS if available,
        else CPU.  It also falls back to CPU when PyTorch cannot be
        imported.  Any other value is passed to the backend unchanged.
    model_path : str, optional
        Path to a local ``.model`` file (e.g. a fine-tuned MACE model).
        Takes priority over *model*. Currently only MACE ``.model`` files
        are supported for custom paths.
    **kwargs
        Extra keyword arguments forwarded to the backend-specific
        calculator constructor (e.g. ``classical_params`` for the
        classical potentials).

    Returns
    -------
    ase.calculators.calculator.Calculator
        A ready-to-use ASE calculator.

    Raises
    ------
    ValueError
        If the model name is not recognised by any backend.
    ImportError
        If the required backend package is not installed.
    FileNotFoundError
        If *model_path* points to a non-existent file.

    Examples
    --------
    >>> calc = get_calculator("mace-mpa-0", device="cuda")
    >>> calc = get_calculator("chgnet", device="cpu")
    >>> calc = get_calculator("7net-mf-ompa", device="cuda")
    >>> calc = get_calculator(model_path="/data/my_finetuned.model")
    >>> calc = get_calculator("buckingham", classical_params={...})
    """
    # ── Resolve "auto" device once, here, so backends always see a real
    # device string. Order: explicit > CUDA > MPS > CPU.
    if device == "auto":
        try:
            import torch
            if torch.cuda.is_available():
                device = "cuda"
            elif (hasattr(torch.backends, "mps")
                  and torch.backends.mps.is_available()):
                device = "mps"
            else:
                device = "cpu"
        except ImportError:
            device = "cpu"

    # ── Custom model path → MACE backend ──────────────────────────────────
    if model_path is not None:
        return _load_mace(model, device=device, model_path=model_path, **kwargs)

    # ── Route by backend ──────────────────────────────────────────────────
    backend = _detect_backend(model)
    if backend == "classical":
        return _load_classical(model, device=device, **kwargs)
    elif backend == "mace":
        return _load_mace(model, device=device, **kwargs)
    elif backend == "chgnet":
        return _load_chgnet(device=device, **kwargs)
    elif backend == "sevennet":
        return _load_sevennet(model, device=device, **kwargs)
    else:
        # Defensive: _detect_backend only returns the four names above.
        raise ValueError(f"Unknown backend '{backend}' for model '{model}'")

# ── Deprecated alias for backward compatibility ───────────────────────────
def get_mace_calculator(
    model: str = "mace-mpa-0",
    device: str = "cuda",
    model_path: str | None = None,
    **kwargs,
) -> Any:
    """
    Return a MACE calculator (deprecated wrapper around the MACE loader).

    .. deprecated:: 2.0.0
        Use :func:`get_calculator` instead, which supports MACE and all
        other backends (CHGNet, SevenNet).

    Emits a ``DeprecationWarning`` attributed to the caller's line, then
    loads the MACE model directly.  Unlike :func:`get_calculator`, *device*
    is passed through as-is: ``"auto"`` is not resolved here.

    Parameters
    ----------
    model : str
        MACE model short name (e.g. ``"mace-mpa-0"``).
    device : str
        ``"cuda"`` or ``"cpu"``.
    model_path : str, optional
        Path to a local ``.model`` file.
    **kwargs
        Forwarded to ``MACECalculator`` or ``mace_mp()``.

    Returns
    -------
    ASE calculator
    """
    notice = (
        "get_mace_calculator() is deprecated. Use get_calculator() instead, "
        "which supports MACE and all other MLFF backends."
    )
    warnings.warn(notice, category=DeprecationWarning, stacklevel=2)
    return _load_mace(
        model,
        device=device,
        model_path=model_path,
        **kwargs,
    )

def list_models() -> None:
    """Print the described models to stdout, grouped by backend.

    Entries come from ``MODEL_DESCRIPTIONS`` in insertion order.  Each
    section keeps only the names its predicate accepts, so a described name
    that matches no section is not printed.
    """
    rule = "-" * 72
    # (section title, predicate deciding whether a model name belongs to it)
    groups = (
        ("MACE", lambda key: key.startswith("mace-")),
        ("CHGNet", lambda key: key == "chgnet"),
        ("SevenNet", lambda key: key == "sevennet" or key.startswith("7net")),
        ("Classical (requires classical_params in YAML)",
         lambda key: key in ("lennard-jones", "buckingham")),
    )

    print(f"\n{rule}")
    print(" Available foundation models (pass as --model NAME)")
    print(rule)

    for title, belongs in groups:
        padding = "-" * (60 - len(title))
        print(f"\n -- {title} {padding}")
        for key, text in MODEL_DESCRIPTIONS.items():
            if belongs(key):
                print(f" {key:<25s} {text}")

    print(f"\n{rule}")
    print(" Custom model: --model-path /path/to/my_finetuned.model")
    print(f"{rule}\n")