Source code for pyprocar.core.fermisurface

__author__ = "Pedram Tavadze and Logan Lang"
__maintainer__ = "Pedram Tavadze and Logan Lang"
__email__ = "petavazohi@mail.wvu.edu, lllang@mix.wvu.edu"
__date__ = "December 01, 2020"

import logging
import os
import re
import sys
from enum import Enum
from typing import List

import matplotlib as mpl
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import yaml
from matplotlib import cm
from matplotlib import colors as mpcolors
from matplotlib.collections import LineCollection
from scipy.interpolate import griddata
from skimage import measure

from pyprocar.utils import ROOT, ConfigManager

logger = logging.getLogger(__name__)
user_logger = logging.getLogger("user")


def validate_band_indices(band_indices):
    """Validate the band indices"""
    if band_indices is None:
        return None
    elif all(isinstance(x, tuple) for x in band_indices) or isinstance(band_indices, list):
        return band_indices
    elif all(isinstance(x, int) for x in band_indices):
        return [band_indices]
    else:
        raise ValueError(
            f"Invalid band indices: {band_indices}. Band indices must be a list of lists of integers or a list of integers. This represents selecting the bands for each spin.\n"
            "Example: \n [[0,1], [2,3]] means that the first band is selected for the first spin and the second band is selected for the second spin."
        )

class SpinProjection(Enum):
    """An enumeration for defining the spin projection"""
    SZ = "S$_z$"
    SY = "S$_y$"
    SX = "S$_x$"
    SX2 = "S$_x^2$"
    SY2 = "S$_y^2$"
    SZ2 = "S$_z^2$"
    
    @classmethod
    def from_str(cls, spin_projection: str):
        spin_projection = spin_projection.lower()
        if "z" in spin_projection and "2" in spin_projection:
            return cls.SZ2
        elif "y" in spin_projection and "2" in spin_projection:
            return cls.SY2
        elif "x" in spin_projection and "2" in spin_projection:
            return cls.SX2
        elif "z" in spin_projection:
            return cls.SZ
        elif "y" in spin_projection:
            return cls.SY
        elif "x" in spin_projection:
            return cls.SX
        else:
            raise ValueError(f"Invalid spin projection: {spin_projection}")
    

[docs] class FermiSurface: """A class for plotting and analyzing 2D Fermi surfaces. This class provides comprehensive functionality for visualizing 2D Fermi surfaces from electronic band structure calculations. It supports plotting Fermi surface contours, spin texture analysis, and various customization options for scientific visualization. The class handles interpolation of k-point data, contour generation, and multiple plotting modes including line segments, scatter plots, and vector fields. Parameters ---------- kpoints : np.ndarray Array of k-points in Cartesian coordinates with shape (n_kpoints, 3). These should be the k-points from the electronic structure calculation. bands : np.ndarray Array of band energies with shape (n_kpoints, n_bands, n_spins). The Fermi energy should already be subtracted from these values. spd : np.ndarray Array of spin-projected density with shape (n_kpoints, n_bands, n_spins, n_orbitals, n_atoms). Contains the orbital and atomic projections for each k-point and band. figsize : tuple of float, optional Figure size as (width, height) in inches, by default (6, 6). ax : matplotlib.axes.Axes or None, optional Matplotlib axes object to plot on. If None, a new figure and axes will be created, by default None. **kwargs Additional keyword arguments passed to matplotlib functions. Attributes ---------- fig : matplotlib.figure.Figure The matplotlib figure object. ax : matplotlib.axes.Axes The matplotlib axes object. handles : list List of matplotlib handles for plotted elements. kpoints : np.ndarray The k-points array. bands : np.ndarray The band energies array. spd : np.ndarray The spin-projected density array. energy : float or None The energy level for Fermi surface analysis. useful_bands_by_spins : list or None List of band indices that cross the specified energy for each spin. x_limits : tuple The x-axis limits for plotting. y_limits : tuple The y-axis limits for plotting. clim : tuple The color limits for color mapping. Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt >>> # Create sample data >>> kpoints = np.random.rand(100, 3) >>> bands = np.random.rand(100, 10, 1) - 0.5 # Centered around 0 >>> spd = np.random.rand(100, 10, 1, 4, 1) >>> >>> # Initialize FermiSurface >>> fs = FermiSurface(kpoints, bands, spd) >>> >>> # Find bands crossing the Fermi level >>> fs.find_energy(0.0) >>> >>> # Generate and plot contours >>> contour_data = fs.generate_contours() >>> fs.plot_band_spin_contour_line_segments(contour_data) >>> >>> # Customize the plot >>> fs.set_xlabel('$k_x$ (Å$^{-1}$)') >>> fs.set_ylabel('$k_y$ (Å$^{-1}$)') >>> fs.show() Notes ----- The class assumes that the input band energies have the Fermi energy already subtracted, so that the Fermi level corresponds to energy = 0. The k-points should be in Cartesian coordinates and ready for plotting. For spin texture analysis, additional spin arrays (sx, sy, sz) are required and should be passed to the appropriate spin texture methods. """
[docs] def __init__( self, kpoints, bands, spd, figsize: tuple = (6, 6), ax: plt.Axes | None = None, **kwargs, ): """Initialize the FermiSurface object. Parameters ---------- kpoints : np.ndarray Array of k-points in Cartesian coordinates with shape (n_kpoints, 3). bands : np.ndarray Array of band energies with shape (n_kpoints, n_bands, n_spins). spd : np.ndarray Array of spin-projected density with shape (n_kpoints, n_bands, n_spins, n_orbitals, n_atoms). figsize : tuple of float, optional Figure size as (width, height) in inches, by default (6, 6). ax : matplotlib.axes.Axes or None, optional Matplotlib axes object to plot on, by default None. **kwargs Additional keyword arguments. """ if ax is None: self.fig = plt.figure( figsize=figsize, ) self.ax = self.fig.add_subplot(111) else: self.fig = plt.gcf() self.ax = ax self.handles = [] # Since some time ago Kpoints are in cartesian coords (ready to use) self.kpoints = kpoints self.bands = bands self.spd = spd self.useful = None # List of useful bands (filled in findEnergy) self.energy = None logger.debug("FermiSurface.init: ...") logger.info("Kpoints.shape : " + str(self.kpoints.shape)) logger.info("bands.shape : " + str(self.bands.shape)) logger.info("spd.shape : " + str(self.spd.shape)) logger.debug("FermiSurface.init: ...Done") return None
[docs] def find_energy(self, energy): """A method to find bands which are near a given energy Parameters ---------- energy : float The energy to search for bands around. Returns ------- None None Raises ------ RuntimeError If no bands are found, raise an error. """ self.energy = energy logger.info("Energy : " + str(energy)) # searching for bands crossing the desired energy bands_to_plot = False self.useful_bands_by_spins = [] for i_spin in range(self.bands.shape[2]): bands = self.bands[:, :, i_spin] indices = np.where( np.logical_and(bands.min(axis=0) < energy, bands.max(axis=0) > energy) )[0] self.useful_bands_by_spins.append(indices) if len(indices) != 0: bands_to_plot = True print( f"Band indices near iso-surface: (bands.shape={bands.shape}) spin-{i_spin} | bands-{indices}" ) if not bands_to_plot: user_logger.error( f"Could not find any bands crossing the energy ({energy} eV) relative to the fermi energy.\n" "Please check the energy and the bands:\n\n" "1. Try shifting the energy to find crossings of the bands and the energy.\n" "2. Check the density of states to see where the bands are in terms of energy.\n" "3. Check the bands to see if they are crossing the energy." ) raise RuntimeError("No bands to plot") return None
[docs] def select_bands(self, band_indices:tuple[tuple[int, int], tuple[int, int]] = None): """Select specific bands for plotting based on band indices. Parameters ---------- band_indices : tuple of tuples of int, optional Band indices for each spin channel. Format: ((band_indices_spin0,), (band_indices_spin1,)). If None, uses all available bands, by default None. Returns ------- dict Dictionary containing selected bands data with keys: - 'bands': List of band arrays for each spin - 'spd': List of spin-projected density arrays for each spin - 'band_labels': List of unique band labels for each spin """ bands_data= { "bands": [], "spd": [], "band_labels": [], } band_indices = validate_band_indices(band_indices) n_spins = self.bands.shape[2] for i_spin in range(n_spins): if i_spin==1 and len(band_indices) == 1: continue if len(band_indices[i_spin]) == 0: continue bands = self.bands[:, band_indices[i_spin], i_spin].transpose() spd = self.spd[:, band_indices[i_spin], i_spin].transpose() band_labels = np.unique(band_indices[i_spin]) if spd.shape[0] == 0: continue bands_data["bands"].append(bands) bands_data["spd"].append(spd) bands_data["band_labels"].append(band_labels) return bands_data
[docs] def generate_band_colors(self, bands, i_spin:int = 0, cmap:str = "plasma"): """Generate colors for each band using a colormap. Parameters ---------- bands : np.ndarray Array of band energies with shape (n_bands, n_kpoints). i_spin : int, optional Spin index (0 or 1), by default 0. cmap : str, optional Colormap name, by default "plasma". Returns ------- np.ndarray Array of RGBA colors with shape (n_bands, 4). """ n_bands = bands.shape[0] cmap = cm.get_cmap(cmap) norm = mpcolors.Normalize(0, 1) factor = 0.25 if i_spin == 1 else 0 solid_color_surface = np.arange(n_bands) / n_bands + factor band_colors = np.array( [cmap(norm(x)) for x in solid_color_surface[:]] ).reshape(-1, 4) return band_colors
[docs] def generate_contours(self, band_indices: list[list[int]] = None, interpolation=500, ignore_scalars: bool = False): """ Generate 2D Fermi surface contours for selected bands. This method interpolates the band energies onto a regular grid in the kx-ky plane and extracts contour lines at the specified energy level (typically the Fermi energy). It supports selection of specific bands and spin channels, and can optionally ignore scalar values associated with the bands. Parameters ---------- band_indices : list of list of int, optional List of band indices to include for each spin channel. If None, uses the bands identified by `find_energy()`. interpolation : int, optional Number of grid points along each axis for interpolation, by default 500. ignore_scalars : bool, optional If True, scalar values (e.g., projections) are ignored in the output, by default False. Returns ------- dict Dictionary containing interpolated band energies, spin-projected densities, and band labels for each spin channel. Raises ------ RuntimeError If `find_energy()` has not been called prior to this method. """ logger.debug("Plot: ...") if self.useful_bands_by_spins is None: raise RuntimeError("self.find_energy() must be called before Plotting") # selecting components of K-points x, y = self.kpoints[:, 0], self.kpoints[:, 1] logger.debug("k_x[:10], k_y[:10] values" + str([x[:10], y[:10]])) # and new, interpolated component xmax, xmin = x.max(), x.min() ymax, ymin = y.max(), y.min() logger.debug("xlim = " + str([xmin, xmax]) + " ylim = " + str([ymin, ymax])) xnew, ynew = np.mgrid[ xmin : xmax : interpolation * 1j, ymin : ymax : interpolation * 1j ] unique_x = xnew[:, 0] unique_y = ynew[0, :] self.x_limits = (xmin, xmax) self.y_limits = (ymin, ymax) # Selecting bands to plot bands_indices = band_indices if band_indices is not None else self.useful_bands_by_spins bands_data = self.select_bands(bands_indices) bands_spin_contour_data = {} self.selected_bands = bands_data for i_spin, (bands, spd, band_labels) in enumerate(zip(bands_data["bands"], bands_data["spd"], bands_data["band_labels"])): # Interpolating band energies on to new grid bnew = [] logger.debug("Interpolating ...") for i_band, band in enumerate(bands): bnew.append(griddata((x, y), band, (xnew, ynew), method="cubic")) bnew = np.array(bnew) # Generates colors per band plots = [] for i_band, band_energies in enumerate(bnew): contour_data = {"lines": [],"scalars": [], "labels": []} contours = measure.find_contours(band_energies, self.energy) for i_contour, contour in enumerate(contours): # measure.find contours returns a list of coordinates indcies of the mesh. # However, due to the algorithm they take values that are in between mesh points. # We need to interpolate the values to the original kmesh x_vals = contour[:, 0] y_vals = contour[:, 1] x_interp = np.interp( x_vals, np.arange(0, unique_x.shape[0]), unique_x ) y_interp = np.interp( y_vals, np.arange(0, unique_y.shape[0]), unique_y ) points = np.array([[x_interp, y_interp]]) points = np.moveaxis(points, -1, 0) segments = np.concatenate([points[:-1], points[1:]], axis=1) contour_data["lines"].append(segments) if spd is not None and not ignore_scalars: scalars = griddata((x, y), spd[i_band, :], (x_interp, y_interp), method="cubic") contour_data["scalars"].append(scalars) bands_spin_contour_data[(i_band, i_spin)] = contour_data return bands_spin_contour_data
[docs] def plot_contour_line_segments(self, contour_data:dict, label:str = None, cmap:str = "plasma", norm:mpcolors.Normalize = None, clim:tuple = (None, None), linewidth:float = 0.2, linestyle:str = "solid", color:str = "k", alpha:float = 1.0, line_collection_kwargs:dict = None, ): """Plot contour line segments from contour data. Parameters ---------- contour_data : dict Dictionary containing contour lines and scalars data. label : str, optional Label for the plot legend, by default None. cmap : str, optional Colormap name, by default "plasma". norm : matplotlib.colors.Normalize, optional Normalization instance for color mapping, by default None. clim : tuple of float, optional Color limits as (vmin, vmax), by default (None, None). linewidth : float, optional Width of the contour lines, by default 0.2. linestyle : str, optional Style of the contour lines, by default "solid". color : str, optional Color of the lines when not using scalar coloring, by default "k". alpha : float, optional Transparency level (0-1), by default 1.0. line_collection_kwargs : dict, optional Additional kwargs for LineCollection, by default None. Returns ------- list List of matplotlib LineCollection handles. """ line_handles = [] lines = contour_data["lines"] scalars = contour_data["scalars"] if len(scalars) > 0: vmin = clim[0] vmax = clim[1] if vmin is None and vmax is None: for i_segment, segment_scalars in enumerate(scalars): vmin = min(vmin, segment_scalars.min()) vmax = max(vmax, segment_scalars.max()) clim = (vmin, vmax) if not hasattr(self, "norm") or not hasattr(self, "clim") or not hasattr(self, "cmap"): self.set_scalar_mappable(norm=norm, clim=clim, cmap=cmap) for i_segment, segments in enumerate(lines): lc = LineCollection(segments, linestyle=linestyle, linewidth=linewidth, alpha=alpha, **line_collection_kwargs) if len(scalars) > 0: lc.set_array(scalars[i_segment]) lc.set_cmap(self.cmap) lc.set_norm(self.norm) else: lc.set_color(color) if i_segment == 0: lc.set_label(label) line_handles.append(self.ax.add_collection(lc)) self.handles.extend(line_handles) return line_handles
[docs] def plot_band_spin_contour_line_segments(self, bands_spin_contour_data:dict, linestyles:tuple[str, str] = ("solid", "dashed"), colors:tuple[str, str] = None, linewidths:tuple[float, float] = (0.2, 0.2), alphas:tuple[float, float] = (1.0, 1.0), cmap:str = "plasma", norm:mpcolors.Normalize = None, clim:tuple = (None, None), line_collection_kwargs:dict = None ): """Plot contour line segments for multiple bands and spins. Parameters ---------- bands_spin_contour_data : dict Dictionary containing contour data for each (band, spin) combination. linestyles : tuple of str, optional Line styles for each spin channel, by default ("solid", "dashed"). colors : tuple of str, optional Colors for each spin channel. If None, uses automatic coloring, by default None. linewidths : tuple of float, optional Line widths for each spin channel, by default (0.2, 0.2). alphas : tuple of float, optional Transparency levels for each spin channel, by default (1.0, 1.0). cmap : str, optional Colormap name, by default "plasma". norm : matplotlib.colors.Normalize, optional Normalization instance for color mapping, by default None. clim : tuple of float, optional Color limits as (vmin, vmax), by default (None, None). line_collection_kwargs : dict, optional Additional kwargs for LineCollection, by default None. Returns ------- list List of matplotlib LineCollection handles. """ plot_contour_line_segments_kwargs = {} if not hasattr(self, "norm") or not hasattr(self, "clim") or not hasattr(self, "cmap"): self.set_scalar_mappable(norm=norm, clim=clim, cmap=cmap) plot_contour_line_segments_kwargs["norm"] = self.norm plot_contour_line_segments_kwargs["clim"] = self.clim plot_contour_line_segments_kwargs["cmap"] = self.cmap plot_contour_line_segments_kwargs["line_collection_kwargs"] = line_collection_kwargs if line_collection_kwargs is not None else {} band_spin_handles = [] for (i_band, i_spin), contour_data in bands_spin_contour_data.items(): label = f"Band {i_band}, Spin {i_spin}" bands = self.selected_bands["bands"][i_spin] band_colors = self.generate_band_colors(bands, i_spin) line_kwargs = plot_contour_line_segments_kwargs.copy() line_kwargs["label"] = label line_kwargs["linestyle"] = linestyles[i_spin] line_kwargs["linewidth"] = linewidths[i_spin] line_kwargs["alpha"] = alphas[i_spin] if colors is not None: line_kwargs["color"] = colors[i_spin] else: line_kwargs["color"] = band_colors[i_band] handles = self.plot_contour_line_segments(contour_data, **line_kwargs) band_spin_handles.extend(handles) self.handles.extend(band_spin_handles) return band_spin_handles
[docs] def generate_spin_texture_contours(self, sx, sy, sz, band_indices:tuple[int, int] | None = None, point_density:int = 10, spin_projection: SpinProjection | str = "z^2", interpolation=300): """This method generates the spin texture contours""" logger.debug("spin_texture: ...") point_density = 50 // point_density if isinstance(spin_projection, str): spin_projection = SpinProjection.from_str(spin_projection) if self.useful_bands_by_spins is None: raise RuntimeError("self.find_energy() must be called before plotting") # selecting components of K-points x, y = self.kpoints[:, 0], self.kpoints[:, 1] if band_indices is None: bands = self.bands[:, self.useful_bands_by_spins[0], 0].transpose() band_labels = np.unique(self.useful_bands_by_spins[0]) sx = sx[:, self.useful_bands_by_spins[0]].transpose() sy = sy[:, self.useful_bands_by_spins[0]].transpose() sz = sz[:, self.useful_bands_by_spins[0]].transpose() else: bands = self.bands[:, band_indices[0], 0].transpose() band_labels = np.unique(band_indices[0]) sx = sx[:, band_indices[0]].transpose() sy = sy[:, band_indices[0]].transpose() sz = sz[:, band_indices[0]].transpose() # and new, interpolated component xmax, xmin = x.max(), x.min() ymax, ymin = y.max(), y.min() logger.debug("xlim = " + str([xmin, xmax]) + " ylim = " + str([ymin, ymax])) xnew, ynew = np.mgrid[ xmin : xmax : interpolation * 1j, ymin : ymax : interpolation * 1j ] # interpolation bnew = [] for band in bands: logger.debug("Interpolating ...") interp_bands = griddata((x, y), band, (xnew, ynew), method="cubic") bnew.append(interp_bands) spin_texture_contours = [] # plt.ioff() # Turn off interactive mode fig, ax = plt.subplots() for z in bnew: spin_texture_contours.append(ax.contour(xnew, ynew, z, [self.energy])) plt.close(fig) if len(spin_texture_contours) == 0: raise RuntimeError("Could not find any contours at this energy") self.x_limits = [0, 0] self.y_limits = [0, 0] spin_texture_contour_data = { "contours": [], "points": [], "sx": [], "sy": [], "sz": [], "scalars": [], } for i_band, (contour, spinX, spinY, spinZ) in enumerate(zip(spin_texture_contours, sx, sy, sz)): # The previous interp. yields the level curves, nothing more is # useful from there paths = contour.get_paths() if paths: verts = [path.vertices for path in paths] points = np.concatenate(verts) logger.debug("Fermi surf. points.shape: " + str(points.shape)) newSx = griddata((x, y), spinX, (points[:, 0], points[:, 1])) newSy = griddata((x, y), spinY, (points[:, 0], points[:, 1])) newSz = griddata((x, y), spinZ, (points[:, 0], points[:, 1])) # This is so the density scales the way you think. increasing number means increasing density. # The number in the numerator is so it scales reasonable with 0-20 if spin_projection == SpinProjection.SZ: scalars = newSz[::point_density] elif spin_projection == SpinProjection.SY: scalars = newSy[::point_density] elif spin_projection == SpinProjection.SX: scalars = newSx[::point_density] elif spin_projection == SpinProjection.SX2: scalars = newSx[::point_density] ** 2 elif spin_projection == SpinProjection.SY2: scalars = newSy[::point_density] ** 2 elif spin_projection == SpinProjection.SZ2: scalars = newSz[::point_density] ** 2 spin_texture_contour_data["sx"].append(newSx[::point_density]) spin_texture_contour_data["sy"].append(newSy[::point_density]) spin_texture_contour_data["sz"].append(newSz[::point_density]) spin_texture_contour_data["points"].append(points[::point_density]) spin_texture_contour_data["contours"].append(contour) spin_texture_contour_data["scalars"].append(scalars) self.spin_texture_contour_data = spin_texture_contour_data return self.spin_texture_contour_data
[docs] def plot_spin_texture_contours(self, spin_texture_contour_data:dict, alpha:float = 1.0, linewidth:float = 0.2, colors:str = None, linestyles:str = "solid", norm:mpcolors.Normalize = None, clim:tuple = (None, None), cmap:str = "plasma", countour_kwargs: dict = None): """Plot spin texture contours from spin texture data. Parameters ---------- spin_texture_contour_data : dict Dictionary containing spin texture contour data with contours, scalars, etc. clim : tuple of float, optional Color limits as (vmin, vmax), by default (None, None). cmap : str, optional Colormap name, by default "plasma". alpha : float, optional Transparency level (0-1), by default 1.0. linewidth : float, optional Width of the contour lines, by default 0.2. colors : str, optional Color for the contours, by default None. linestyles : str, optional Style of the contour lines, by default "solid". norm : matplotlib.colors.Normalize, optional Normalization instance for color mapping, by default None. clim : tuple of float, optional Color limits as (vmin, vmax), by default (None, None). cmap : str, optional Colormap name, by default "plasma". countour_kwargs : dict, optional Additional kwargs for matplotlib contour function, by default None. Returns ------- list List of matplotlib contour handles. """ if not hasattr(self, "norm") or not hasattr(self, "clim") or not hasattr(self, "cmap"): self.set_scalar_mappable(norm=norm, clim=clim, cmap=cmap) countour_kwargs = {} if countour_kwargs is None else countour_kwargs countour_kwargs.setdefault("linewidths", linewidth) countour_kwargs.setdefault("colors", colors) countour_kwargs.setdefault("linestyles", linestyles) countour_kwargs.setdefault("alpha", alpha) if colors is not None: countour_kwargs.setdefault("colors", colors) else: countour_kwargs.setdefault("cmap", self.cmap) countour_kwargs.setdefault("norm", self.norm) countour_kwargs.setdefault("vmin", self.clim[0]) countour_kwargs.setdefault("vmax", self.clim[1]) self.contour_handles = [] for z in spin_texture_contour_data["contours"]: self.contour_handles.append(self.ax.contour(z, [self.energy], **countour_kwargs)) self.handles.extend(self.contour_handles) return self.contour_handles
[docs] def plot_spin_texture_scatter(self, spin_texture_contour_data:dict, s:int =50, edgecolor:str = "none", alpha:float = 1.0, marker:str = ".", color:str = "k", scatter_kwargs:dict = None, ): """Plot spin texture as scatter points on Fermi surface contours. Parameters ---------- spin_texture_contour_data : dict Dictionary containing spin texture contour data with points and scalars. s : int, optional Size of scatter points, by default 50. edgecolor : str, optional Color of point edges, by default "none". alpha : float, optional Transparency level (0-1), by default 1.0. marker : str, optional Marker style for scatter points, by default ".". color : str, optional Color for points when not using scalar coloring, by default "k". scatter_kwargs : dict, optional Additional kwargs for matplotlib scatter function, by default None. Returns ------- list List of matplotlib scatter handles. """ x_limits = [0, 0] y_limits = [0, 0] self.scatter_handles = [] for i_band, (sx, sy, sz, points, scalars) in enumerate(zip(spin_texture_contour_data["sx"], spin_texture_contour_data["sy"], spin_texture_contour_data["sz"], spin_texture_contour_data["points"], spin_texture_contour_data["scalars"])): x_limits = [ min(x_limits[0], points[:, 0].min()), max(x_limits[1], points[:, 0].max()), ] y_limits = [ min(y_limits[0], points[:, 1].min()), max(y_limits[1], points[:, 1].max()), ] scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs scatter_kwargs.setdefault("s", s) scatter_kwargs.setdefault("edgecolor", edgecolor) scatter_kwargs.setdefault("alpha", alpha) scatter_kwargs.setdefault("marker", marker) scatter_kwargs.setdefault("c", color) scatter_kwargs.setdefault("cmap", self.cmap) scatter_kwargs.setdefault("norm", self.norm) scatter_kwargs.setdefault("vmin", self.clim[0]) scatter_kwargs.setdefault("vmax", self.clim[1]) self.scatter_handles.append(self.ax.scatter( points[:, 0], points[:, 1], **scatter_kwargs )) self.x_limits = ( x_limits[0] - abs(x_limits[0]) * 0.1, x_limits[1] + abs(x_limits[1]) * 0.1, ) self.y_limits = ( y_limits[0] - abs(y_limits[0]) * 0.1, y_limits[1] + abs(y_limits[1]) * 0.1, ) self.handles.extend(self.scatter_handles) return self.scatter_handles
[docs] def plot_spin_texture_arrows(self, spin_texture_contour_data:dict, arrow_color: str | list[str]| None = None, scale:float | None = None, scale_units:str = "inches", units:str = "inches", angles:str = "uv", quiver_kwargs: dict = None, ): """Plot spin texture as arrows (vectors) on Fermi surface contours. This method visualizes the spin texture by drawing arrows that represent the in-plane spin components (sx, sy) at points along the Fermi surface contours. The arrow direction indicates the spin orientation and can be colored by various spin projections. Parameters ---------- spin_texture_contour_data : dict Dictionary containing spin texture contour data with points, sx, sy, sz, and scalars. arrow_color : str, list of str, or None, optional Color(s) for the arrows. If str, all arrows use same color. If list, different colors for each band. If None, colors based on scalars, by default None. scale : float or None, optional Scale factor for arrow length. If None, matplotlib auto-scales, by default None. scale_units : str, optional Units for the scale parameter, by default "inches". units : str, optional Units for arrow dimensions, by default "inches". angles : str, optional How to interpret arrow angles ('uv' for x,y components), by default "uv". cmap : str, optional Colormap name for scalar coloring, by default "plasma". norm : matplotlib.colors.Normalize, optional Normalization instance for color mapping, by default None. quiver_kwargs : dict, optional Additional kwargs for matplotlib quiver function, by default None. Returns ------- None This method adds quiver plots to the axes but returns None. Notes ----- The arrows represent the in-plane spin components (sx, sy) while the color can represent any spin projection specified during contour generation. """ quiver_kwargs = {} if quiver_kwargs is None else quiver_kwargs quiver_kwargs.setdefault("scale", 1 / scale) quiver_kwargs.setdefault("scale_units", scale_units) quiver_kwargs.setdefault("angles", angles) x_limits = [0, 0] y_limits = [0, 0] self.quiver_handles = [] for i_band, (sx, sy, sz, points, scalars) in enumerate(zip(spin_texture_contour_data["sx"], spin_texture_contour_data["sy"], spin_texture_contour_data["sz"], spin_texture_contour_data["points"], spin_texture_contour_data["scalars"])): x_limits = [ min(x_limits[0], points[:, 0].min()), max(x_limits[1], points[:, 0].max()), ] y_limits = [ min(y_limits[0], points[:, 1].min()), max(y_limits[1], points[:, 1].max()), ] quiver_args = [ points[:, 0], # Arrow position x-component points[:, 1], # Arrow position y-component sx[:], # Arrow direction x-component sy[:], # Arrow direction y-component ] band_quiver_kwargs = quiver_kwargs.copy() if isinstance(arrow_color, list): band_quiver_kwargs["color"] = arrow_color[i_band] elif isinstance(arrow_color, str): band_quiver_kwargs["color"] = arrow_color else: quiver_args.append(scalars) band_quiver_kwargs.setdefault("cmap", self.cmap) band_quiver_kwargs.setdefault("norm", self.norm) band_quiver_kwargs.setdefault("clim", self.clim) band_quiver_kwargs["color"] = None self.quiver_handles.append(self.ax.quiver(*quiver_args,**band_quiver_kwargs)) self.x_limits = ( x_limits[0] - abs(x_limits[0]) * 0.1, x_limits[1] + abs(x_limits[1]) * 0.1, ) self.y_limits = ( y_limits[0] - abs(y_limits[0]) * 0.1, y_limits[1] + abs(y_limits[1]) * 0.1, ) self.handles.extend(self.quiver_handles) return None
[docs] def show_colorbar(self, label:str = "", n_ticks:int = 5, cmap:str = "plasma", clim:tuple = (None, None), colorbar_kwargs:dict = None): """Add a colorbar to the plot. Parameters ---------- label : str, optional Label for the colorbar, by default "". n_ticks : int, optional Number of ticks on the colorbar, by default 5. cmap : str, optional Colormap name, by default "plasma". norm : matplotlib.colors.Normalize, optional Normalization instance for color mapping, by default None. clim : tuple of float, optional Color limits as (vmin, vmax), by default (None, None). colorbar_kwargs : dict, optional Additional kwargs for matplotlib colorbar function, by default None. """ self.colorbar = self.fig.colorbar( self.cm, ax=self.ax, label=label, **colorbar_kwargs)
[docs] def set_scalar_mappable(self, norm:mpcolors.Normalize = None, clim:tuple = (None, None), cmap:str = "plasma"): vmin = clim[0] vmax = clim[1] if vmin is None: vmin=-0.5 if vmax is None: vmax=0.5 if norm is None: norm = mpcolors.Normalize norm = norm(vmin, vmax) self.norm = norm self.clim = clim self.cmap = cmap self.cm = cm.ScalarMappable(norm=norm, cmap=cmap)
[docs] def set_colorbar_ticks(self, n_ticks:int = 5, tick_labels = None, tick_positions = None, **kwargs): """Set the tick positions and labels for the colorbar. Parameters ---------- n_ticks : int, optional Number of ticks to use if tick_positions is None, by default 5. tick_labels : array-like, optional Custom tick labels, by default None. tick_positions : array-like, optional Custom tick positions, by default None. **kwargs Additional keyword arguments passed to matplotlib set_yticks. """ if tick_positions is None: tick_positions = np.linspace(self.colorbar.vmin, self.colorbar.vmax, n_ticks) if tick_labels is None: tick_labels = np.linspace(self.colorbar.vmin, self.colorbar.vmax, n_ticks) set_y_tick_kwargs = {} set_y_tick_kwargs.setdefault("ticks", tick_positions) set_y_tick_kwargs.setdefault("labels", tick_labels) set_y_tick_kwargs.update(kwargs) self.colorbar.ax.set_yticks(**set_y_tick_kwargs)
[docs] def set_xticks(self, n_ticks:int = 5, tick_labels = None, tick_positions = None, **kwargs): """Set the tick positions and labels for the colorbar. Parameters ---------- n_ticks : int, optional Number of ticks to use if tick_positions is None, by default 5. tick_labels : array-like, optional Custom tick labels, by default None. tick_positions : array-like, optional Custom tick positions, by default None. **kwargs Additional keyword arguments passed to matplotlib set_yticks. """ if tick_positions is None: tick_positions = np.linspace(self.x_limits[0], self.x_limits[1], n_ticks) if tick_labels is None: tick_labels = np.linspace(self.x_limits[0], self.x_limits[1], n_ticks) set_x_tick_kwargs = {} set_x_tick_kwargs.setdefault("ticks", tick_positions) set_x_tick_kwargs.setdefault("labels", tick_labels) set_x_tick_kwargs.update(kwargs) self.ax.set_xticks(**set_x_tick_kwargs)
[docs] def set_yticks(self, n_ticks:int = 5, tick_labels = None, tick_positions = None, **kwargs): """Set the tick positions and labels for the colorbar. Parameters ---------- n_ticks : int, optional Number of ticks to use if tick_positions is None, by default 5. tick_labels : array-like, optional Custom tick labels, by default None. tick_positions : array-like, optional Custom tick positions, by default None. **kwargs Additional keyword arguments passed to matplotlib set_yticks. """ if tick_positions is None: tick_positions = np.linspace(self.y_limits[0], self.y_limits[1], n_ticks) if tick_labels is None: tick_labels = np.linspace(self.y_limits[0], self.y_limits[1], n_ticks) set_y_tick_kwargs = {} set_y_tick_kwargs.setdefault("ticks", tick_positions) set_y_tick_kwargs.setdefault("labels", tick_labels) set_y_tick_kwargs.update(kwargs) self.ax.set_yticks(**set_y_tick_kwargs)
[docs] def set_colorbar_tick_params(self, **kwargs): """Set the tick parameters for the colorbar. Parameters ---------- **kwargs Keyword arguments passed to matplotlib tick_params. """ self.colorbar.ax.tick_params(**kwargs)
[docs] def set_colorbar_label(self, label:str|None = None, **kwargs): """Set the label for the colorbar. Parameters ---------- label : str or None, optional Label text. If None, uses current label, by default None. **kwargs Additional keyword arguments passed to matplotlib set_ylabel. """ if label is None: label = self.colorbar.ax.get_yaxis().label.get_text() self.colorbar.ax.set_ylabel(label, **kwargs)
[docs] def set_xlim(self, xlimits = None, **kwargs): """Set the x-axis limits. Parameters ---------- xlimits : tuple of float, optional X-axis limits as (xmin, xmax). If None, uses automatically determined limits, by default None. **kwargs Additional keyword arguments passed to matplotlib set_xlim. """ if xlimits is None: xlimits = self.x_limits self.ax.set_xlim(xlimits, **kwargs)
[docs] def set_ylim(self, ylimits =None, **kwargs): """Set the y-axis limits. Parameters ---------- ylimits : tuple of float, optional Y-axis limits as (ymin, ymax). If None, uses automatically determined limits, by default None. **kwargs Additional keyword arguments passed to matplotlib set_ylim. """ if ylimits is None: ylimits = self.y_limits self.ax.set_ylim(ylimits, **kwargs)
[docs] def set_xlabel(self, xlabel = '$k_{x}$ ($\AA^{-1}$)', **kwargs): """Set the x-axis label. Parameters ---------- xlabel : str, optional X-axis label text, by default '$k_{x}$ ($\AA^{-1}$)'. **kwargs Additional keyword arguments passed to matplotlib set_xlabel. """ self.ax.set_xlabel(xlabel, **kwargs)
[docs] def set_ylabel(self, ylabel = '$k_{y}$ ($\AA^{-1}$)', **kwargs): """Set the y-axis label. Parameters ---------- ylabel : str, optional Y-axis label text, by default '$k_{y}$ ($\AA^{-1}$)'. **kwargs Additional keyword arguments passed to matplotlib set_ylabel. """ self.ax.set_ylabel(ylabel, **kwargs)
[docs] def set_tick_params(self, axis: str = "both", which: str = "major", reset: bool = False, **kwargs): """Set the tick parameters for the axes. Parameters ---------- axis : str, optional Which axis to apply the parameters to ('x', 'y', or 'both'), by default "both". which : str, optional Which ticks to apply the parameters to ('major', 'minor', or 'both'), by default "major". reset : bool, optional Whether to reset all parameters before setting new ones, by default False. **kwargs Additional keyword arguments passed to matplotlib tick_params. """ self.ax.tick_params(axis=axis, which=which, reset=reset, **kwargs)
[docs] def get_colorbar(self): """Get the colorbar object. Returns ------- matplotlib.colorbar.Colorbar The colorbar instance associated with this plot. """ return self.colorbar
[docs] def add_legend(self, **kwargs): """Add a legend to the plot. Parameters ---------- **kwargs Additional keyword arguments passed to matplotlib legend function. """ self.ax.legend()
[docs] def set_aspect(self, aspect: float | str = "equal", **kwargs): """Set the aspect ratio of the plot. Parameters """ self.ax.set_aspect(aspect, **kwargs)
[docs] def savefig(self, savefig, dpi: int | str = "figure", **kwargs): """Save the figure to a file. Parameters ---------- savefig : str or path-like The filename or path where the figure should be saved. dpi : int or str, optional The resolution in dots per inch. Can be 'figure' to use figure's dpi, by default "figure". **kwargs Additional keyword arguments passed to matplotlib savefig function. """ self.fig.savefig(savefig, dpi=dpi, bbox_inches="tight", **kwargs)
[docs] def show(self, **kwargs): """Display the plot. Parameters ---------- **kwargs Additional keyword arguments passed to matplotlib show function. """ plt.show(**kwargs)