"""Source code for pyprocar.scripts.scriptBandsplot (the ``bandsplot`` band-structure plotting function)."""

# Module authorship metadata (the first name was previously misspelled as `_author__`).
__author__ = "Pedram Tavadze and Logan Lang"
__maintainer__ = "Pedram Tavadze and Logan Lang"
__email__ = "petavazohi@mail.wvu.edu, lllang@mix.wvu.edu"
__date__ = "March 31, 2020"

import logging
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np

from pyprocar import io
from pyprocar.cfg import ConfigFactory, ConfigManager, PlotType
from pyprocar.plotter import EBSPlot
from pyprocar.utils import data_utils, welcome
from pyprocar.utils.info import orbital_names
from pyprocar.utils.log_utils import set_verbose_level

# Module-scoped logger; in this file it carries the detailed progress/cache messages.
logger = logging.getLogger(name=__name__)
# Shared logger looked up by the fixed name "user" for messages shown to the end user.
user_logger = logging.getLogger(name="user")


def _has_entries(values) -> bool:
    """Return True when *values* is not None and contains at least one element.

    Works for lists and NumPy arrays alike; plain truthiness (``if values:``)
    raises ``ValueError`` for NumPy arrays with more than one element.
    """
    return values is not None and len(values) > 0


def _species_to_atom_indices(structure, species) -> np.ndarray:
    """Expand species symbols into the integer indices of all matching atoms in *structure*."""
    # np.unique both de-duplicates and sorts the requested species, as before.
    index_groups = [np.where(structure.atoms == symbol)[0] for symbol in np.unique(species)]
    if not index_groups:
        return np.array([], dtype=int)
    return np.concatenate(index_groups).astype(int)


def _orbital_names_to_indices(names) -> np.ndarray:
    """Expand orbital labels (keys of ``orbital_names``) into a flat integer index array."""
    indices = []
    for name in names:
        # np.ravel accepts either a scalar or a list entry in orbital_names.
        indices.extend(np.ravel(orbital_names[name]).tolist())
    return np.asarray(indices, dtype=int)


def _load_band_data(code: str, dirname: str, use_cache: bool):
    """Return ``(ebs, structure, kpath)`` for *dirname*, parsing it or reading the pickle cache."""
    ebs_pkl_filepath = os.path.join(dirname, "ebs.pkl")
    structure_pkl_filepath = os.path.join(dirname, "structure.pkl")
    kpath_pkl_filepath = os.path.join(dirname, "kpath.pkl")

    if not use_cache:
        user_logger.warning("Not using cache, removing existing cache files")
        for label, path in (
            ("structure", structure_pkl_filepath),
            ("kpath", kpath_pkl_filepath),
            ("EBS", ebs_pkl_filepath),
        ):
            if os.path.exists(path):
                logger.info(f"Removing existing {label} file: {path}")
                os.remove(path)

    # Both pickles are needed; a half-written cache falls back to re-parsing
    # instead of failing on a missing structure.pkl.
    if os.path.exists(ebs_pkl_filepath) and os.path.exists(structure_pkl_filepath):
        logger.info(
            f"Loading EBS, Structure, and Kpath from cached Pickle files in {dirname}"
        )
        # NOTE(review): data_utils.load_pickle presumably un-pickles the file;
        # only point `dirname` at trusted directories.
        ebs = data_utils.load_pickle(ebs_pkl_filepath)
        structure = data_utils.load_pickle(structure_pkl_filepath)
    else:
        logger.info(f"Parsing EBS from {dirname}")
        parser = io.Parser(code=code, dirpath=dirname)
        ebs = parser.ebs
        structure = parser.structure
        data_utils.save_pickle(ebs, ebs_pkl_filepath)
        data_utils.save_pickle(structure, structure_pkl_filepath)
    return ebs, structure, ebs.kpath


def _export_filename(export_data_file: str, mode: str, append_mode: bool) -> str:
    """Return the export path, inserting ``_<mode>`` before the extension when *append_mode*."""
    if not append_mode:
        return export_data_file
    # splitext copes with names that have no dot or several dots
    # (str.split(".") unpacking raised ValueError for those).
    root, extension = os.path.splitext(export_data_file)
    return f"{root}_{mode}{extension}"


def bandsplot(
    code: str,
    dirname: str,
    mode: str = "plain",
    spins: List[int] = None,
    atoms: List[int] = None,
    orbitals: List[int] = None,
    items: dict = None,
    fermi: float = None,
    fermi_shift: float = 0,
    interpolation_factor: int = 1,
    interpolation_type: str = "cubic",
    projection_mask: np.ndarray = None,
    kticks=None,
    knames=None,
    kdirect: bool = True,
    elimit: List[float] = None,
    ax: "plt.Axes" = None,
    show: bool = True,
    savefig: str = None,
    print_plot_opts: bool = False,
    export_data_file: str = None,
    export_append_mode: bool = True,
    ktick_limit: List[float] = None,
    x_limit: List[float] = None,
    use_cache: bool = True,
    verbose: int = 1,
    quiet_welcome: bool = False,
    **kwargs,
):
    """Plot the band structure of the calculation found in ``dirname``.

    Parameters
    ----------
    code : str
        Name of the code that produced the calculation (e.g. "vasp", "qe",
        "elk"); passed to ``io.Parser``.
    dirname : str
        Directory containing the calculation output. The cache files
        ``ebs.pkl`` and ``structure.pkl`` are written here.
    mode : str, optional
        Plot mode, by default "plain". Handled values: "plain", "ipr",
        "overlay", "overlay_species", "overlay_orbitals", "parametric",
        "scatter", "atomic". Any other value logs a warning and only the
        axes are decorated.
    spins : List[int], optional
        Spin channels to plot, by default None. ``[-1, 1]`` or ``[1, -1]``
        requests merging both collinear channels into one via
        ``ebs.fix_collinear_spin()``; when that succeeds ``spins`` becomes ``[0]``.
    atoms : List[int] or List[str], optional
        Atom indices, or species symbols that are expanded to the indices of
        every matching atom (projection modes), by default None.
    orbitals : List[int] or List[str], optional
        Orbital indices, or orbital labels looked up in ``orbital_names``
        (projection modes), by default None.
    items : dict or list of dict, optional
        For "overlay" mode: mapping of species symbol to a list of orbital
        indices or orbital labels, e.g. ``{"Fe": ["d"], "O": [1, 2, 3]}``.
        By default None, treated as an empty dict.
    fermi : float, optional
        Fermi energy subtracted from the bands, by default None. For "qe" and
        "elk" the parsed Fermi energy is used when this is None; for other
        codes the bands are not shifted unless a value is given. For band
        structure (non-self-consistent) runs the parsed value may be
        inaccurate; pass the one from the self-consistent calculation.
    fermi_shift : float, optional
        Energy added back to the bands after subtracting ``fermi``; the Fermi
        line is drawn at this value. Only used when a Fermi energy is set.
        By default 0.
    interpolation_factor : int, optional
        Currently not used by this function, by default 1.
    interpolation_type : str, optional
        Currently not used by this function, by default "cubic".
    projection_mask : np.ndarray, optional
        Mask used as both the color and width mask of projection plots,
        by default None.
    kticks : list, optional
        K-point tick positions passed to ``EBSPlot.set_xticks``, by default None.
    knames : list, optional
        Tick labels matching ``kticks``, by default None.
    kdirect : bool, optional
        Forwarded to ``EBSPlot``, by default True.
    elimit : List[float], optional
        Energy window used for the y-limits and y-ticks, by default None.
    ax : plt.Axes, optional
        Matplotlib axes to draw on, by default None.
    show : bool, optional
        Whether to show the figure, by default True.
    savefig : str, optional
        Path to save the figure to, by default None (not saved).
    print_plot_opts : bool, optional
        Log every default plot option, by default False.
    export_data_file : str, optional
        File to export the plotted data to; nothing is exported when None.
    export_append_mode : bool, optional
        If True (default), ``_<mode>`` is inserted before the file extension
        (``bands.csv`` -> ``bands_plain.csv``); otherwise the name is used as is.
    ktick_limit : List[float], optional
        K-tick interval forwarded to ``EBSPlot.set_xlim``, by default None.
    x_limit : List[float], optional
        X-axis interval forwarded to ``EBSPlot.set_xlim``, by default None.
    use_cache : bool, optional
        Reuse ``ebs.pkl``/``structure.pkl`` from ``dirname`` when both exist,
        by default True. When False the cache files are deleted and the
        calculation is re-parsed.
    verbose : int, optional
        Verbosity level (2 or more for detailed logs), by default 1.
    quiet_welcome : bool, optional
        Suppress the welcome banner and the plot-options message,
        by default False.
    **kwargs
        Plot options merged into the default band-structure configuration.

    Returns
    -------
    tuple
        ``(fig, ax)`` of the underlying ``EBSPlot``.

    Raises
    ------
    ValueError
        In "atomic" mode when the calculation does not contain exactly one k-point.
    """
    if items is None:
        items = {}

    set_verbose_level(verbose)
    user_logger.info("If you want more detailed logs, set verbose to 2 or more")
    user_logger.info("_" * 100)
    if not quiet_welcome:
        welcome()

    default_config = ConfigFactory.create_config(PlotType.BAND_STRUCTURE)
    config = ConfigManager.merge_configs(default_config, kwargs)

    user_logger.info("_" * 100)
    modes_txt = " , ".join(config.modes)
    message = f"""
            There are additional plot options that are defined in the configuration file.
            You can change these configurations by passing the keyword argument to the function.
            To print a list of all plot options set `print_plot_opts=True`

            Here is a list modes : {modes_txt}
            """
    if not quiet_welcome:
        user_logger.info(message)
    if print_plot_opts:
        for key, value in default_config.as_dict().items():
            user_logger.info(f"{key} : {value}")
    user_logger.info("_" * 100)

    ebs, structure, kpath = _load_band_data(code, dirname, use_cache)

    codes_with_scf_fermi = ["qe", "elk"]
    if code in codes_with_scf_fermi and fermi is None:
        logger.info(f"No fermi given, using the found fermi energy: {ebs.efermi}")
        fermi = ebs.efermi

    fermi_level = None
    if fermi is not None:
        logger.info(f"Shifting Fermi energy to zero: {fermi}")
        # The shift happens after the cache was written, so the pickles keep raw energies.
        ebs.bands -= fermi
        ebs.bands += fermi_shift
        fermi_level = fermi_shift
        y_label = r"E - E$_F$ (eV)"
    else:
        y_label = r"E (eV)"
        user_logger.warning(
            "`fermi` is not set! Set `fermi={value}`. The plot did not shift the bands by the Fermi energy."
        )

    # Plot the two collinear spin channels as one (down is negative).
    if np.array_equal(spins, [-1, 1]) or np.array_equal(spins, [1, -1]):
        if ebs.fix_collinear_spin():
            spins = [0]

    ebs_plot = EBSPlot(ebs, kpath, ax, spins, kdirect=kdirect, config=config)

    projection_labels = []
    labels = []
    if mode == "plain":
        user_logger.info("Plotting bands in plain mode")
        ebs_plot.plot_bands()

    elif mode == "ipr":
        user_logger.info("Plotting bands in IPR mode")
        weights = ebs_plot.ebs.ebs_ipr()
        ebs_plot.plot_parameteric(
            color_weights=weights if config.weighted_color else None,
            width_weights=weights if config.weighted_width else None,
            color_mask=projection_mask,
            width_mask=projection_mask,
            spins=spins,
            elimit=elimit,
        )
        ebs_plot.set_colorbar_title(title="Inverse Participation Ratio")

    elif mode in ["overlay", "overlay_species", "overlay_orbitals"]:
        weights = []
        if mode == "overlay_species":
            if orbitals is None:
                orbitals = list(np.arange(len(ebs_plot.ebs.projected[0][0]), dtype=int))
            user_logger.info("Plotting bands in overlay species mode")
            for ispc in structure.species:
                labels.append(ispc)
                species_atoms = np.where(structure.atoms == ispc)[0]
                projection_labels.append(
                    f"atom-{ispc}_orbitals-" + ",".join(str(x) for x in orbitals)
                )
                weights.append(
                    ebs_plot.ebs.ebs_sum(
                        atoms=species_atoms,
                        principal_q_numbers=[-1],
                        orbitals=orbitals,
                        spins=spins,
                    )
                )

        elif mode == "overlay_orbitals":
            user_logger.info("Plotting bands in overlay orbitals mode")
            atom_label = ""
            if _has_entries(atoms):
                atom_labels = ",".join(str(x) for x in atoms)
                atom_label = f"atom-{atom_labels}_"
            for orb in ["s", "p", "d", "f"]:
                # f orbitals only exist when more than 9 orbitals are projected.
                if orb == "f" and ebs_plot.ebs.norbitals <= 9:
                    continue
                labels.append(orb)
                projection_labels.append(f"{atom_label}orbitals-{orb}")
                weights.append(
                    ebs_plot.ebs.ebs_sum(
                        atoms=atoms,
                        principal_q_numbers=[-1],
                        orbitals=orbital_names[orb],
                        spins=spins,
                    )
                )

        else:  # mode == "overlay"
            user_logger.info("Plotting bands in overlay mode")
            if isinstance(items, dict):
                items = [items]
            if isinstance(items, list):
                for item in items:
                    for ispc in item:
                        species_atoms = np.where(structure.atoms == ispc)[0]
                        selection = item[ispc]
                        if isinstance(selection[0], str):
                            item_orbitals = _orbital_names_to_indices(selection)
                            labels.append(ispc + "-" + "".join(selection))
                        else:
                            item_orbitals = selection
                            labels.append(ispc + "-" + "_".join(str(x) for x in selection))
                        atom_labels = ",".join(str(x) for x in species_atoms)
                        orbital_labels = ",".join(str(x) for x in item_orbitals)
                        projection_labels.append(
                            f"atoms-{atom_labels}_orbitals-{orbital_labels}"
                        )
                        weights.append(
                            ebs_plot.ebs.ebs_sum(
                                atoms=species_atoms,
                                principal_q_numbers=[-1],
                                orbitals=item_orbitals,
                                spins=spins,
                            )
                        )
        ebs_plot.plot_parameteric_overlay(
            spins=spins, weights=weights, labels=projection_labels
        )

    else:
        # Species symbols / orbital labels are expanded to integer indices.
        if _has_entries(atoms) and isinstance(atoms[0], str):
            atoms = _species_to_atom_indices(structure, atoms)
        if _has_entries(orbitals) and isinstance(orbitals[0], str):
            orbitals = _orbital_names_to_indices(orbitals)

        projection_label = ""
        if _has_entries(atoms):
            atoms_labels = ",".join(str(x) for x in atoms)
            projection_label += f"atoms-{atoms_labels}"
        if _has_entries(orbitals):
            orbital_labels = ",".join(str(x) for x in orbitals)
            if projection_label:
                projection_label += "_"
            projection_label += f"orbitals-{orbital_labels}"
        projection_labels = [projection_label]

        weights = ebs_plot.ebs.ebs_sum(
            atoms=atoms, principal_q_numbers=[-1], orbitals=orbitals, spins=spins
        )
        color_weights = weights if config.weighted_color else None
        width_weights = weights if config.weighted_width else None
        color_mask = projection_mask
        width_mask = projection_mask

        if mode == "parametric":
            user_logger.info("Plotting bands in parametric mode")
            ebs_plot.plot_parameteric(
                color_weights=color_weights,
                width_weights=width_weights,
                color_mask=color_mask,
                width_mask=width_mask,
                spins=spins,
                labels=projection_labels,
            )
            ebs_plot.set_colorbar_title()
        elif mode == "scatter":
            user_logger.info("Plotting bands in scatter mode")
            ebs_plot.plot_scatter(
                color_weights=color_weights,
                width_weights=width_weights,
                color_mask=color_mask,
                width_mask=width_mask,
                spins=spins,
                labels=projection_labels,
            )
            ebs_plot.set_colorbar_title()
        elif mode == "atomic":
            user_logger.info("Plotting bands in atomic mode")
            if ebs.kpoints.shape[0] != 1:
                raise ValueError("Must use a single kpoint")
            if color_weights is not None:
                # Duplicate the single k-point row, as the original implementation did.
                color_weights = np.vstack((color_weights, color_weights))
            ebs_plot.plot_atomic_levels(
                color_weights=color_weights,
                width_weights=width_weights,
                color_mask=color_mask,
                width_mask=width_mask,
                spins=spins,
                elimit=elimit,
                labels=projection_labels,
            )
            ebs_plot.set_xlabel(label=config.x_label)
            ebs_plot.set_colorbar_title()
        else:
            user_logger.warning(
                f"Selected mode {mode} not valid. Please check the spelling"
            )

    ebs_plot.set_xticks(kticks, knames)
    ebs_plot.set_yticks(interval=elimit)
    ebs_plot.set_xlim(interval=x_limit, ktick_interval=ktick_limit)
    ebs_plot.set_ylim(elimit)
    ebs_plot.set_ylabel(label=y_label)
    ebs_plot.set_xlabel(label=config.x_label)
    if fermi is not None:
        ebs_plot.draw_fermi(fermi_level=fermi_level)
    ebs_plot.set_title()
    ebs_plot.grid()
    ebs_plot.legend(labels)

    if savefig is not None:
        ebs_plot.save(savefig)
    if show:
        ebs_plot.show()

    if export_data_file is not None:
        ebs_plot.export_data(
            _export_filename(export_data_file, mode, export_append_mode)
        )

    return ebs_plot.fig, ebs_plot.ax