"""Quantum ESPRESSO parsing front-end for PyProcar.

Scans a QE calculation directory for pw.x / projwfc.x inputs, outputs and
XML files, and lazily assembles PyProcar ``ElectronicBandStructure``,
``DensityOfStates`` and ``Structure`` objects from whatever was found.
"""
__author__ = "Logan Lang"
__maintainer__ = "Logan Lang"
__email__ = "lllang@mix.wvu.edu"
__date__ = "March 31, 2020"
import ast
import copy
import logging
import math
import os
import re
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from enum import Enum
from functools import cached_property
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from pyprocar.core import DensityOfStates, ElectronicBandStructure, KPath, Structure
from pyprocar.io.qe.projwfc import AtomicProjXML, ProjwfcDOS, ProjwfcIn, ProjwfcOut
from pyprocar.io.qe.pw import PwIn, PwOut, PwXML
from pyprocar.utils.units import AU_TO_ANG, HARTREE_TO_EV
# Module logger for developer-facing diagnostics (detection details, parse sources).
logger = logging.getLogger(__name__)
# Shared "user" logger; this module sends warnings meant for the end user here
# (missing files, missing quantities).
user_logger = logging.getLogger("user")
class QEParser:
    """Auto-detect Quantum ESPRESSO files in a directory and expose them lazily.

    Every file parser (``scf_in``, ``pw_xml``, ``atomic_proj_xml``, ...) and
    every derived quantity (``bands``, ``spd``, ``kpath``, ``ebs``, ...) is a
    ``cached_property``: it is built on first access and then reused.
    File parsers return ``None`` when their file is missing or fails to parse.

    Example
    -------
    parser = QEParser("/path/to/qe/calculation")
    print(parser.summary())
    ebs = parser.ebs
    dos = parser.dos
    structure = parser.structure
    """

    def __init__(self, dirpath: Union[str, Path]) -> None:
        """Record *dirpath* and immediately scan it for QE files.

        Parameters
        ----------
        dirpath : str or Path
            Root of the calculation directory. It is searched recursively,
            following symbolic links.
        """
        self._dirpath: Path = Path(dirpath)
        # Best detected file per role; "pdos_files" keeps every matching PDOS file.
        self._detected: Dict[str, Union[Path, List[Path], None]] = {
            "scf_in": None,
            "scf_out": None,
            "bands_in": None,
            "bands_out": None,
            "nscf_in": None,
            "nscf_out": None,
            "projwfc_in": None,
            "projwfc_out": None,
            "pdos_files": [],
            "atomic_proj_xml": None,
            "data_file_schema_xml": None,
            "data_xml": None,
            "pw_xml": None,
        }
        self.detect_files()

    # -------- file detection --------

    def detect_files(self) -> None:
        """Scan the directory tree and fill ``self._detected``.

        XML and PDOS files are recognised by name. ``.in`` / ``.out`` / ``.log``
        files are recognised by peeking at their content with each parser's
        ``is_file_of_type``. pw.x files are then split into scf / bands / nscf
        using filename hints. When several candidates share a role, the
        shallowest one wins, with the most recently modified breaking ties.
        The outcome is logged at INFO level.
        """
        if not self._dirpath.exists():
            user_logger.warning(f"Directory not found: {self._dirpath}")
            return

        files = self._collect_files()

        # XMLs
        atomic_proj_xml = [p for p in files if re.search(r"(?i)^atomic_proj\.xml$", p.name)]
        data_file_schema = [p for p in files if re.search(r"(?i)data-file-schema\.xml$", p.name)]
        data_xmls = [p for p in files if re.search(r"(?i)data-file\.xml$", p.name)]
        reserved_xml_names = {"atomic_proj.xml", "data-file-schema.xml", "data-file.xml"}
        pw_xmls = [
            p for p in files
            if p.suffix.lower() == ".xml" and p.name.lower() not in reserved_xml_names
        ]

        # Inputs: identify the program from the content, then use filename hints
        # to tell pw.x scf / bands / nscf runs apart.
        scf_ins: List[Path] = []
        bands_ins: List[Path] = []
        nscf_ins: List[Path] = []
        projwfc_ins: List[Path] = []
        for inf in (p for p in files if p.suffix.lower() == ".in"):
            try:
                if ProjwfcIn.is_file_of_type(inf):
                    projwfc_ins.append(inf)
                elif PwIn.is_file_of_type(inf):
                    if self._name_has_token(inf.name, "bands"):
                        bands_ins.append(inf)
                    elif self._name_has_token(inf.name, "nscf"):
                        nscf_ins.append(inf)
                    else:
                        scf_ins.append(inf)
            except Exception as exc:
                # Best-effort: an unreadable candidate is skipped, not fatal.
                logger.debug(f"Could not classify input file {inf}: {exc}")

        # Outputs: detect the producing program by peeking at the content.
        pwscf_outs: List[Path] = []
        projwfc_outs: List[Path] = []
        for of in (p for p in files if p.suffix.lower() in (".out", ".log")):
            try:
                if PwOut.is_file_of_type(of):
                    pwscf_outs.append(of)
                elif ProjwfcOut.is_file_of_type(of):
                    projwfc_outs.append(of)
            except Exception as exc:
                logger.debug(f"Could not classify output file {of}: {exc}")

        # Classify PWSCF out files by filename hints (best-effort)
        scf_outs = [p for p in pwscf_outs if self._name_has_token(p.name, "scf")]
        bands_outs = [p for p in pwscf_outs if self._name_has_token(p.name, "bands")]
        nscf_outs = [p for p in pwscf_outs if self._name_has_token(p.name, "nscf")]

        # PDOS files
        pdos_files = [p for p in files if re.search(r"(?i)pdos_atm#|pdos_tot", p.name)]

        def best(paths: List[Path]) -> Optional[Path]:
            ranked = self._sort_preferred(paths)
            return ranked[0] if ranked else None

        self._detected.update(
            {
                "scf_in": best(scf_ins),
                # Fall back to any pw.x output when none is named like an SCF run.
                "scf_out": best(scf_outs) if scf_outs else best(pwscf_outs),
                "bands_in": best(bands_ins),
                "bands_out": best(bands_outs),
                "nscf_in": best(nscf_ins),
                "nscf_out": best(nscf_outs),
                "projwfc_in": best(projwfc_ins),
                "projwfc_out": best(projwfc_outs),
                "pdos_files": self._sort_preferred(pdos_files),
                "atomic_proj_xml": best(atomic_proj_xml),
                "data_file_schema_xml": best(data_file_schema),
                "data_xml": best(data_xmls),
                "pw_xml": best(pw_xmls),
            }
        )

        log_lines = ["Detected files:"]
        for key, value in self.summary().items():
            if isinstance(value, dict):
                log_lines.append(f"{key}:")
                log_lines.extend(f" {sub_key}: {sub_value}" for sub_key, sub_value in value.items())
            else:
                log_lines.append(f"{key}: {value}")
        logger.info("\n".join(log_lines))

    def _collect_files(self) -> List[Path]:
        """Return every non-directory entry below the scan root, following symlinks.

        ``os.walk(followlinks=True)`` does not remember visited directories, so a
        symlink pointing at an ancestor would recurse forever. Directories are
        therefore de-duplicated by their resolved real path.
        """
        files: List[Path] = []
        seen_dirs = set()
        for root, dirs, filenames in os.walk(self._dirpath, followlinks=True):
            real_root = os.path.realpath(root)
            if real_root in seen_dirs:
                dirs[:] = []  # prune: this tree was already walked via another path
                continue
            seen_dirs.add(real_root)
            files.extend(Path(root) / name for name in filenames)
        return files

    def _sort_preferred(self, paths: List[Path]) -> List[Path]:
        """Order candidates shallowest-first, breaking ties by newest modification time.

        Returns *paths* unchanged if a path cannot be stat'ed or is not located
        under the scan root.
        """
        try:
            return sorted(
                paths,
                key=lambda p: (len(p.relative_to(self._dirpath).parts), -p.stat().st_mtime),
            )
        except (OSError, ValueError):
            return paths

    @staticmethod
    def _name_has_token(name: str, token: str) -> bool:
        """Return True if *token* occurs in *name* and is not preceded by a letter.

        The lookbehind keeps "nscf" from counting as "scf". Unlike ``\\b``, it
        still accepts underscore-separated names such as ``si_scf.out``.
        """
        return re.search(rf"(?i)(?<![a-z]){re.escape(token)}", name) is not None

    def summary(self) -> Dict[str, object]:
        """Return an overview of the detection results.

        Returns
        -------
        dict
            ``"dirpath"``: the scanned directory as a string.
            ``"files"``: detected path(s) per role, as strings or None.
            ``"parsers"``: a bool per role saying whether a file was detected.
        """
        def as_str(value: Union[Path, List[Path], None]) -> Union[str, List[str], None]:
            if value is None:
                return None
            if isinstance(value, list):
                return [str(x) for x in value]
            return str(value)

        detected = self._detected
        return {
            "dirpath": str(self._dirpath),
            "files": {k: as_str(v) for k, v in detected.items()},
            "parsers": {
                "scf_in": detected["scf_in"] is not None,
                "scf_out": detected["scf_out"] is not None,
                "bands_out": detected["bands_out"] is not None,
                "nscf_out": detected["nscf_out"] is not None,
                "projwfc_out": detected["projwfc_out"] is not None,
                "atomic_proj_xml": detected["atomic_proj_xml"] is not None,
                "pw_xml": detected["pw_xml"] is not None,
                "data_file_schema_xml": detected["data_file_schema_xml"] is not None,
                "pdos": bool(detected.get("pdos_files")),
            },
        }

    # -------- lazy parser properties --------

    def _parse_detected(self, key: str, parser_cls, label: str, warn: bool = False):
        """Instantiate *parser_cls* on the file detected under *key*.

        Returns None when no file was detected or when the parser raises.
        With ``warn=True`` both conditions are reported to the user logger.
        Otherwise a missing file is silent and a parse error goes to the module
        logger, so it is never swallowed without a trace.
        """
        fp = self._detected.get(key)
        if not fp:
            if warn:
                user_logger.warning(f"{label} not found")
            return None
        try:
            return parser_cls(fp)
        except Exception as exc:
            if warn:
                user_logger.warning(f"Error parsing {label}: {exc}")
            else:
                logger.warning(f"Error parsing {label} ({fp}): {exc}")
            return None

    @cached_property
    def scf_in(self) -> Optional[PwIn]:
        """Parsed SCF pw.x input (warns to the user when missing or broken)."""
        return self._parse_detected("scf_in", PwIn, "SCF input", warn=True)

    @cached_property
    def scf_out(self) -> Optional[PwOut]:
        """Parsed SCF pw.x output (warns to the user when missing or broken)."""
        return self._parse_detected("scf_out", PwOut, "SCF output", warn=True)

    @cached_property
    def bands_in(self) -> Optional[PwIn]:
        """Parsed bands pw.x input, or None."""
        return self._parse_detected("bands_in", PwIn, "bands input")

    @cached_property
    def bands_out(self) -> Optional[PwOut]:
        """Parsed bands pw.x output, or None."""
        return self._parse_detected("bands_out", PwOut, "bands output")

    @cached_property
    def nscf_in(self) -> Optional[PwIn]:
        """Parsed nscf pw.x input, or None."""
        return self._parse_detected("nscf_in", PwIn, "nscf input")

    @cached_property
    def nscf_out(self) -> Optional[PwOut]:
        """Parsed nscf pw.x output, or None."""
        return self._parse_detected("nscf_out", PwOut, "nscf output")

    @cached_property
    def projwfc_in(self) -> Optional[ProjwfcIn]:
        """Parsed projwfc.x input, or None."""
        return self._parse_detected("projwfc_in", ProjwfcIn, "projwfc input")

    @cached_property
    def projwfc_out(self) -> Optional[ProjwfcOut]:
        """Parsed projwfc.x output, or None."""
        return self._parse_detected("projwfc_out", ProjwfcOut, "projwfc output")

    @cached_property
    def projwfc_dos(self) -> Optional[ProjwfcDOS]:
        """PDOS data read from the scan directory, or None if no PDOS file was detected."""
        if not self._detected.get("pdos_files"):
            return None
        try:
            return ProjwfcDOS(self._dirpath)
        except Exception as exc:
            logger.warning(f"Error parsing PDOS files in {self._dirpath}: {exc}")
            return None

    @cached_property
    def atomic_proj_xml(self) -> Optional[AtomicProjXML]:
        """Parsed atomic_proj.xml, or None."""
        return self._parse_detected("atomic_proj_xml", AtomicProjXML, "atomic_proj.xml")

    @cached_property
    def pw_xml(self) -> Optional[PwXML]:
        """Parsed generic pw.x XML output, or None."""
        return self._parse_detected("pw_xml", PwXML, "pw xml")

    @cached_property
    def data_file_schema_xml(self) -> Optional[PwXML]:
        """Parsed data-file-schema.xml, or None."""
        return self._parse_detected("data_file_schema_xml", PwXML, "data-file-schema.xml")

    @cached_property
    def alat(self) -> Optional[float]:
        """Lattice parameter ``alat``.

        The scf.out value is multiplied by ``AU_TO_ANG``; the XML values are
        returned as stored.
        NOTE(review): confirm the XML parsers already return angstrom so that
        all sources agree.
        """
        if self.scf_out is not None and self.scf_out.alat is not None:
            logger.info("Parsing alat from scf.out")
            alat = self.scf_out.alat * AU_TO_ANG
        elif self.pw_xml is not None and self.pw_xml.alat is not None:
            logger.info("Parsing alat from pw.xml")
            alat = self.pw_xml.alat
        elif self.data_file_schema_xml is not None and self.data_file_schema_xml.alat is not None:
            logger.info("Parsing alat from data_file_schema.xml")
            alat = self.data_file_schema_xml.alat
        else:
            user_logger.warning("No alat found in scf.out or pw.xml")
            return None
        logger.debug(f"alat: {alat}")
        return alat

    @cached_property
    def _raw_kpoints(self) -> Optional[np.ndarray]:
        """Kpoints as parsed, converted to fractional (reciprocal-lattice) coordinates.

        The source values are scaled by ``2*pi/alat`` and then mapped through
        the inverse of ``reciprocal_lattice``. Returns None if no source, no
        alat or no reciprocal lattice is available.
        """
        if self.atomic_proj_xml is not None:
            logger.info("Parsing kpoints from atomic_proj.xml")
            kpoints_cart = self.atomic_proj_xml.kpoints
        elif self.projwfc_out is not None and self.projwfc_out.kpoints is not None:
            logger.info("Parsing kpoints from projwfc.out")
            kpoints_cart = self.projwfc_out.kpoints
        elif (
            self.bands_in is not None
            and self.bands_in.kpoints_card is not None
            and self.bands_in.kpoints_card.kpoints is not None
        ):
            logger.info("Parsing kpoints from bands.in")
            kpoints_cart = self.bands_in.kpoints_card.kpoints
        elif self.pw_xml is not None and self.pw_xml.kpoints is not None:
            logger.info("Parsing kpoints from pw.xml")
            kpoints_cart = self.pw_xml.kpoints
        elif self.data_file_schema_xml is not None and self.data_file_schema_xml.kpoints is not None:
            logger.info("Parsing kpoints from data_file_schema.xml")
            kpoints_cart = self.data_file_schema_xml.kpoints
        else:
            user_logger.warning("No kpoints found in atomic_proj.xml or projwfc.out or bands.in")
            return None

        alat = self.alat
        reciprocal_lattice = self.reciprocal_lattice
        if alat is None or reciprocal_lattice is None:
            user_logger.warning("Cannot convert kpoints to fractional coordinates: alat or reciprocal lattice is missing")
            return None

        scaled_kpoints_cart = np.asarray(kpoints_cart) * (2 * np.pi / alat)
        return np.around(scaled_kpoints_cart.dot(np.linalg.inv(reciprocal_lattice)), decimals=8)

    @cached_property
    def kpath(self) -> Optional[KPath]:
        """Band-structure path built from the K_POINTS card of bands.in.

        As a side effect this records the tick indices in ``self._kticks``.
        Returns None for DOS calculations, when bands.in or its path
        information is missing, or when the high-symmetry points cannot be
        located in the parsed kpoints.
        """
        if self.is_dos_calculation:
            logger.info("No kpath found for DOS calculation")
            return None
        if self.bands_in is None:
            logger.info("No bands.in file found, therefore not parsing kpath")
            return None
        kpoints_card = self.bands_in.kpoints_card
        if kpoints_card is None:
            logger.info("No kpoints_card found in bands.in, therefore not parsing kpath")
            return None
        if kpoints_card.modified_knames is None:
            logger.info("No modified_knames found in bands.in, therefore not parsing kpath")
            return None
        raw_kpoints = self._raw_kpoints
        if raw_kpoints is None:
            logger.info("No raw kpoints available, therefore not parsing kpath")
            return None

        try:
            kticks = find_high_symmetry_ticks(raw_kpoints, kpoints_card.high_symmetry_points)
        except ValueError as exc:
            user_logger.warning(f"Could not locate the high-symmetry points along the kpath: {exc}")
            return None
        self._kticks = kticks

        # Every segment except the last gets one extra point. This presumably
        # matches the tick duplicated by insert_continuous_points — TODO confirm.
        n_segments = len(kpoint_grids := kpoints_card.ngrids)
        ngrids = [grid + 1 if i != n_segments - 1 else grid for i, grid in enumerate(kpoint_grids)]
        return KPath(
            knames=kpoints_card.modified_knames,
            kticks=kticks,
            special_kpoints=kpoints_card.special_kpoints,
            ngrids=ngrids,
            has_time_reversal=True,
        )

    @cached_property
    def kticks(self) -> List[int]:
        """Indices of the high-symmetry points in the raw kpoints ([] if there is no kpath)."""
        # kpath records the ticks as a side effect. Evaluate it first, so reading
        # kticks before kpath cannot cache an empty list for good.
        _ = self.kpath
        return getattr(self, "_kticks", [])

    @property
    def kpoints(self) -> Optional[np.ndarray]:
        """Fractional kpoints, with interior ticks duplicated when a kpath exists."""
        kpoints = self._raw_kpoints
        if self.kpath is not None:
            logger.info("Parsing kpoints from kpath")
            kpoints = insert_continuous_points(self._raw_kpoints, self.kticks)
        return kpoints

    def _grid_value(self, attr: str) -> Optional[int]:
        """Look up k-grid attribute *attr* (nk1..nk3, sk1..sk3).

        Sources are tried in priority order: the nscf K_POINTS card, pw xml,
        then data-file-schema.xml.
        """
        card = self.nscf_in.kpoints_card if self.nscf_in is not None else None
        if card is not None and getattr(card, attr) is not None:
            return getattr(card, attr)
        if self.pw_xml is not None and getattr(self.pw_xml, attr) is not None:
            return getattr(self.pw_xml, attr)
        if self.data_file_schema_xml is not None and getattr(self.data_file_schema_xml, attr) is not None:
            logger.info(f"Parsing {attr} from data_file_schema.xml")
            return getattr(self.data_file_schema_xml, attr)
        return None

    @cached_property
    def nk1(self) -> Optional[int]:
        return self._grid_value("nk1")

    @cached_property
    def nk2(self) -> Optional[int]:
        return self._grid_value("nk2")

    @cached_property
    def nk3(self) -> Optional[int]:
        return self._grid_value("nk3")

    @cached_property
    def sk1(self) -> Optional[int]:
        return self._grid_value("sk1")

    @cached_property
    def sk2(self) -> Optional[int]:
        return self._grid_value("sk2")

    @cached_property
    def sk3(self) -> Optional[int]:
        return self._grid_value("sk3")

    @cached_property
    def reciprocal_lattice(self) -> Optional[np.ndarray]:
        """Reciprocal lattice vectors from the first available source, scaled by ``2*pi/alat``.

        Returns None if no source or no alat is available.
        """
        if self.pw_xml is not None and self.pw_xml.reciprocal_lattice is not None:
            logger.info("Parsing reciprocal lattice from pw.xml")
            reciprocal_lattice = self.pw_xml.reciprocal_lattice
        elif self.data_file_schema_xml is not None and self.data_file_schema_xml.reciprocal_lattice is not None:
            logger.info("Parsing reciprocal lattice from data_file_schema.xml")
            reciprocal_lattice = self.data_file_schema_xml.reciprocal_lattice
        elif self.scf_out is not None and self.scf_out.reciprocal_axes is not None:
            logger.info("Parsing reciprocal lattice from scf.out")
            reciprocal_lattice = self.scf_out.reciprocal_axes
        elif self.bands_out is not None and self.bands_out.reciprocal_axes is not None:
            logger.info("Parsing reciprocal lattice from bands.out")
            reciprocal_lattice = self.bands_out.reciprocal_axes
        elif self.nscf_out is not None and self.nscf_out.reciprocal_axes is not None:
            logger.info("Parsing reciprocal lattice from nscf.out")
            reciprocal_lattice = self.nscf_out.reciprocal_axes
        else:
            logger.warning("No reciprocal lattice found in pw.xml or scf.out or bands.out or nscf.out")
            return None
        alat = self.alat
        if alat is None:
            logger.warning("Cannot scale the reciprocal lattice: alat is missing")
            return None
        return (2 * np.pi / alat) * np.asarray(reciprocal_lattice)

    @cached_property
    def fermi(self) -> Optional[float]:
        """Fermi energy from scf.out (eV), else from the XML outputs, else None."""
        if self.scf_out is not None and self.scf_out.fermi_energy_ev is not None:
            logger.debug(f"Fermi energy found in {self.scf_out.fermi_energy_ev}")
            return self.scf_out.fermi_energy_ev
        if self.pw_xml is not None and self.pw_xml.fermi is not None:
            return self.pw_xml.fermi
        if self.data_file_schema_xml is not None and self.data_file_schema_xml.fermi is not None:
            return self.data_file_schema_xml.fermi
        return None

    @cached_property
    def bands(self) -> Optional[np.ndarray]:
        """Band energies from the first available source, with kpath ticks duplicated.

        Values from projwfc.out and the XML outputs are multiplied by
        ``HARTREE_TO_EV``; atomic_proj.xml values are used as returned.
        """
        if self.atomic_proj_xml is not None and self.atomic_proj_xml.bands is not None:
            logger.info("Parsing bands from atomic_proj.xml")
            bands = self.atomic_proj_xml.bands
        elif self.projwfc_out is not None and self.projwfc_out.bands is not None:
            logger.info("Parsing bands from projwfc.out")
            bands = HARTREE_TO_EV * self.projwfc_out.bands
        elif self.pw_xml is not None and self.pw_xml.bands is not None:
            logger.info("Parsing bands from pw.xml")
            bands = HARTREE_TO_EV * self.pw_xml.bands
        elif self.data_file_schema_xml is not None and self.data_file_schema_xml.bands is not None:
            logger.info("Parsing bands from data_file_schema.xml")
            bands = HARTREE_TO_EV * self.data_file_schema_xml.bands
        else:
            user_logger.warning("No bands found in atomic_proj.xml or projwfc.out or pw.xml")
            return None
        if self.kpath is not None:
            bands = insert_continuous_points(bands, self.kticks)
        logger.debug(f"Bands: {bands.shape}")
        return bands

    @cached_property
    def spd_phase(self) -> Optional[np.ndarray]:
        """Projection coefficients regrouped per (atom, orbital).

        This needs BOTH atomic_proj.xml (the coefficients, last axis indexed by
        state) and projwfc.out (the state -> atom/orbital mapping). Returns None
        if either is missing.

        Returns
        -------
        np.ndarray
            Shape (n_kpoints, n_bands, n_atoms, 1, n_orbitals, n_spin_channels),
            with kpath ticks duplicated along axis 0 when a kpath exists.
        """
        if self.atomic_proj_xml is None or self.projwfc_out is None:
            return None

        atomic_proj = self.atomic_proj_xml
        projwfc = self.projwfc_out
        projections = atomic_proj.projections
        orbitals = projwfc.orbitals
        n_kpoints = atomic_proj.n_kpoints
        n_bands = atomic_proj.n_bands
        n_spin_channels = atomic_proj.n_spin_channels
        n_atoms = projwfc.n_atoms
        n_orbitals = projwfc.n_orbitals
        n_principals = 1

        phase = np.zeros(
            shape=(n_kpoints, n_bands, n_spin_channels, n_atoms, n_orbitals),
            dtype=projections.dtype,
        )
        for state_num, wfc_info in projwfc.wfc_mapping.items():
            # States with m_j are keyed by (l, j, m_j); otherwise by (l, m).
            if wfc_info["m_j"] is not None:
                orbital_dict = {"l": wfc_info["l"], "j": wfc_info["j"], "m_j": wfc_info["m_j"]}
            else:
                orbital_dict = {"l": wfc_info["l"], "m": wfc_info["m"]}
            i_orbital = orbitals.index(orbital_dict)
            # atm_num and state_num are 1-based.
            phase[..., wfc_info["atm_num"] - 1, i_orbital] += projections[..., state_num - 1]

        # Move spin channels to the last axis and add a singleton "principal" axis
        # to match the pyprocar projection layout.
        phase = np.moveaxis(phase, 2, -1)
        phase = phase.reshape((n_kpoints, n_bands, n_atoms, n_principals, n_orbitals, n_spin_channels))
        if self.kpath is not None:
            phase = insert_continuous_points(phase, self.kticks)
        logger.debug(f"Spd Phase: {phase.shape}")
        return phase

    @cached_property
    def spd(self) -> Optional[np.ndarray]:
        """Squared magnitude of ``spd_phase``, or None when ``spd_phase`` is unavailable."""
        spd_phase = self.spd_phase
        if spd_phase is None:
            return None
        logger.info("Parsing spd from spd phase")
        spd = np.absolute(spd_phase) ** 2
        # spd_phase normally already carries the duplicated ticks; only insert
        # them if the kpoint counts still disagree.
        if self.kpath is not None and spd_phase.shape[0] != self.kpoints.shape[0]:
            spd = insert_continuous_points(spd, self.kticks)
        logger.debug(f"Spd: {spd.shape}")
        return spd

    @cached_property
    def orbitals(self) -> Optional[List]:
        """Orbital descriptors from projwfc.out (preferred) or atomic_proj.xml, else None."""
        if self.projwfc_out is not None:
            logger.info("Parsing orbitals from projwfc.out")
            return self.projwfc_out.orbitals
        if self.atomic_proj_xml is not None:
            logger.info("Parsing orbitals from atomic_proj.xml")
            return self.atomic_proj_xml.orbitals
        logger.info("No orbitals found in projwfc.out or atomic_proj.xml")
        return None

    # -------- computed properties --------

    @cached_property
    def ebs(self) -> Optional[ElectronicBandStructure]:
        """Electronic band structure, or None (with a warning) if kpoints or bands are missing."""
        kpoints = self.kpoints
        bands = self.bands
        if kpoints is None or bands is None:
            user_logger.warning("Cannot build the electronic band structure: kpoints or bands are missing")
            return None
        return ElectronicBandStructure(
            kpoints=kpoints,
            n_kx=self.nk1,
            n_ky=self.nk2,
            n_kz=self.nk3,
            bands=bands,
            projected=self.spd,
            efermi=self.fermi,
            kpath=self.kpath,
            projected_phase=self.spd_phase,
            labels=self.orbitals,
            reciprocal_lattice=self.reciprocal_lattice,
        )

    @cached_property
    def projected_dos(self) -> Optional[np.ndarray]:
        """Projected DOS reshaped to (n_atoms, 1, n_orbitals, n_spin_channels, n_energies)."""
        if self.projwfc_dos is None:
            return None
        # assumed (n_energies, n_spin_channels, n_atoms, n_orbitals) — TODO confirm in ProjwfcDOS
        raw_pdos = self.projwfc_dos.projected_dos
        n_energies = self.projwfc_dos.n_energies
        n_spin_channels = self.projwfc_dos.n_spin_channels
        n_atoms = self.projwfc_dos.n_atoms
        # projwfc.out may be absent; fall back to the orbital axis of the data itself.
        if self.projwfc_out is not None:
            n_orbitals = self.projwfc_out.n_orbitals
        else:
            n_orbitals = np.shape(raw_pdos)[-1]
        n_principals = 1

        projected_dos = np.moveaxis(raw_pdos, 1, -1)  # (n_energies, n_atoms, n_orbitals, n_spin_channels)
        projected_dos = np.moveaxis(projected_dos, 0, -1)  # (n_atoms, n_orbitals, n_spin_channels, n_energies)
        projected_dos = projected_dos.reshape(n_atoms, n_principals, n_orbitals, n_spin_channels, n_energies)
        logger.debug(f"projected_dos: {projected_dos.shape}")
        return projected_dos

    @cached_property
    def total_dos(self) -> Optional[np.ndarray]:
        """Total DOS reshaped to (n_spin_channels, n_energies), or None without PDOS data."""
        if self.projwfc_dos is None:
            return None
        n_spin_channels = self.projwfc_dos.n_spin_channels
        n_energies = self.projwfc_dos.n_energies
        logger.debug(f"total_dos: {self.projwfc_dos.total_dos.shape}")
        return self.projwfc_dos.total_dos.reshape((n_spin_channels, n_energies), order="C")

    @cached_property
    def energies(self) -> Optional[np.ndarray]:
        """DOS energy grid shifted by the Fermi energy.

        The grid is returned unshifted (with a warning) if the Fermi energy is
        unknown.
        """
        if self.projwfc_dos is None:
            return None
        energies = self.projwfc_dos.bands[0]
        if self.fermi is None:
            user_logger.warning("Fermi energy not found; DOS energies are not shifted")
            return energies
        return energies - self.fermi

    @cached_property
    def is_dos_calculation(self) -> bool:
        """True when projwfc.x was run without k-resolution, i.e. the PDOS is a plain DOS."""
        logger.info("Checking if DOS calculation")
        if self.projwfc_in is None:
            # Without the projwfc.x input, ``kresolved`` is unknown. projwfc.x
            # defaults to a non-k-resolved run, so treat detected PDOS output
            # as a DOS calculation, and no PDOS at all as "not DOS".
            is_dos_calculation = self.projwfc_dos is not None
        else:
            is_dos_calculation = not self.projwfc_in.is_kresolved
        logger.info(f"Is DOS calculation: {is_dos_calculation}")
        return is_dos_calculation

    @cached_property
    def dos(self) -> Optional[DensityOfStates]:
        """Density of states for non-k-resolved projwfc runs, else None."""
        if self.projwfc_dos is None:
            user_logger.warning("No PDOS files found for DOS construction")
            return None
        if not self.is_dos_calculation:
            return None
        energies = self.energies
        total = self.total_dos
        projected = self.projected_dos
        logger.debug(f"energies: {energies.shape}")
        logger.debug(f"total_dos: {total.shape}")
        logger.debug(f"projected_dos: {projected.shape}")
        return DensityOfStates(
            energies=energies,
            total=total,
            efermi=self.fermi,
            projected=projected,
        )

    def _xml_attribute(self, attr: str, description: str):
        """Return *attr* from pw xml, else from data-file-schema.xml.

        XML files are only parsed as needed. Warns the user and returns None
        when neither provides the attribute.
        """
        for source_name in ("pw_xml", "data_file_schema_xml"):
            source = getattr(self, source_name)
            if source is not None and getattr(source, attr) is not None:
                return getattr(source, attr)
        user_logger.warning(f"No {description} found in any input or output file")
        return None

    @cached_property
    def species(self) -> Optional[List[str]]:
        return self._xml_attribute("atomic_species", "atomic species")

    @cached_property
    def direct_lattice(self) -> Optional[np.ndarray]:
        return self._xml_attribute("direct_lattice", "direct lattice")

    @cached_property
    def atomic_positions(self) -> Optional[np.ndarray]:
        return self._xml_attribute("atomic_positions", "atomic positions")

    @cached_property
    def rotations(self) -> Optional[np.ndarray]:
        return self._xml_attribute("rotations", "rotations")

    @cached_property
    def structure(self) -> Optional[Structure]:
        """Crystal structure, or None (with a warning) if species, lattice or positions are missing."""
        species = self.species
        lattice = self.direct_lattice
        positions = self.atomic_positions
        if species is None or lattice is None or positions is None:
            user_logger.warning("Cannot build the structure: species, lattice or atomic positions are missing")
            return None
        return Structure(
            atoms=species,
            lattice=lattice,
            fractional_coordinates=positions,
            rotations=self.rotations,
        )
def find_high_symmetry_ticks(raw_kpoints, high_sym_points, atol=1e-4):
    """
    Locate each high-symmetry point along a kpoint path.

    The high-symmetry points are handled in order. For each one, the
    returned index is the earliest raw kpoint that lies strictly after the
    previously returned index and is closer than ``atol`` in Euclidean
    distance. Repeated high-symmetry points are therefore matched to
    successive occurrences along the path.

    Parameters
    ----------
    raw_kpoints : (N, 3) array-like
        Kpoints along the path.
    high_sym_points : (M, 3) array-like
        Special kpoints to locate, in path order (duplicates allowed).
    atol : float
        Strict upper bound on the distance accepted as a match.

    Returns
    -------
    kticks : list
        One index into ``raw_kpoints`` per high-symmetry point, strictly increasing.

    Raises
    ------
    ValueError
        If a high-symmetry point has no match after the previous one.
    """
    path = np.asarray(raw_kpoints)
    targets = np.asarray(high_sym_points)
    # (N, M) table: entry [i, j] is the distance between path[i] and targets[j].
    distance_table = np.linalg.norm(path[:, None, :] - targets[None, :, :], axis=-1)

    ticks = []
    start = 0  # first raw index still eligible, so matches only move forward
    for column in range(distance_table.shape[1]):
        hits = np.flatnonzero(distance_table[start:, column] < atol)
        if hits.size == 0:
            raise ValueError(
                f"No match found for high_sym_point {column}: {targets[column]}"
            )
        tick = start + hits[0]
        ticks.append(tick)
        start = tick + 1
    return ticks
def insert_continuous_points(arr: np.ndarray, tick_indices: np.ndarray) -> np.ndarray:
    """
    Duplicate the rows at interior tick indices to get VASP-style repeated kpoints.

    Every tick except the first and the last marks a point shared by two
    consecutive path segments. The row at each such index is inserted again
    directly after itself, so that both segments own a copy.

    Parameters
    ----------
    arr : np.ndarray
        Array with shape (nk, ...), where axis 0 indexes kpoints.
    tick_indices : array-like of int
        Indices into axis 0 of ``arr`` marking segment boundaries. With fewer
        than three ticks nothing is duplicated (an empty sequence is allowed).

    Returns
    -------
    np.ndarray
        New array with ``max(len(tick_indices) - 2, 0)`` extra rows along axis 0.
    """
    # Cast explicitly: np.asarray([]) is float64, which numpy rejects as an index.
    tick_indices = np.asarray(tick_indices, dtype=np.intp)
    # Interior ticks only: the first and last points belong to a single segment.
    interior_ticks = tick_indices[1:-1]
    rows_to_insert = arr[interior_ticks]
    # np.insert interprets every position relative to the ORIGINAL array, so
    # "+ 1" puts each duplicate right after its source row without extra offsets.
    return np.insert(arr, interior_ticks + 1, rows_to_insert, axis=0)