Source code for aimstools.bandstructures.regular_bandstructure

from logging import raiseExceptions
from aimstools.misc import *
from aimstools.bandstructures.base import BandStructureBaseClass
from aimstools.bandstructures.utilities import BandStructurePlot, BandSpectrum

from ase.dft.kpoints import parse_path_string

from collections import namedtuple

import numpy as np


class RegularBandStructure(BandStructureBaseClass):
    """Band structure assembled from the per-section band files of a calculation.

    The band files belonging to the ``regular`` band sections are read on
    construction and stored in :attr:`bands`, keyed by ``(symbol1, symbol2)``
    tuples for both traversal directions of every section. A
    ``BandSpectrum`` along the current band path is built from these pieces
    by :meth:`set_spectrum`.

    Args:
        outputfile: Forwarded unchanged to ``BandStructureBaseClass``.
        soc (bool): Whether to read the band files including spin-orbit
            coupling. Defaults to False.
    """

    def __init__(self, outputfile, soc=False) -> None:
        super().__init__(outputfile)
        self.soc = soc
        self.spin = "none" if self.control["spin"] != "collinear" else "collinear"
        self.task = "band structure"
        self.band_sections = self.band_sections.regular
        self._set_bandpath_from_sections()
        self._spectrum = None

        if self.spin == "collinear" and soc:
            logger.warning(
                "FHI-aims puts out two band files for spin=collinear with SOC, but only the first one seems to be meaningful. Please check for yourself."
            )
            # Treat the SOC result as a single spin channel. The original
            # code spelled this as a comparison (``==``), which had no effect
            # and left the object flagged as collinear.
            self.spin = "none"

        if self.spin == "none":
            bandfiles = self.get_bandfiles(spin="none", soc=soc).regular
            # read_bandfiles expects one (channel 0, channel 1) pair per
            # section; with a single channel only the first entry is used.
            self.bands = self.read_bandfiles(zip(bandfiles, bandfiles))
        else:
            bandfiles_up = self.get_bandfiles(spin="up", soc=soc).regular
            bandfiles_dn = self.get_bandfiles(spin="dn", soc=soc).regular
            self.bands = self.read_bandfiles(zip(bandfiles_up, bandfiles_dn))

        # set_spectrum stores its result on self._spectrum and returns None,
        # so its return value must not be assigned to self._spectrum.
        self.set_spectrum(bandpath=None, reference=None)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(outputfile={self.outputfile!r}, "
            f"spin_orbit_coupling={self.soc})"
        )

    def read_bandfiles(self, bandfiles):
        """Read the band files of all band sections.

        Args:
            bandfiles: Iterable yielding, per band section, an indexable pair
                of band files (one entry per spin channel). It is consumed in
                lockstep with ``self.band_sections``.

        Returns:
            dict: Maps ``(symbol1, symbol2)`` and the reversed
            ``(symbol2, symbol1)`` to a ``band`` namedtuple with the fields
            ``kpoints``, ``occupations`` and ``eigenvalues``. Occupations and
            eigenvalues of the spin channels are stacked along axis 1.
        """
        band = namedtuple("band", ["kpoints", "occupations", "eigenvalues"])
        nspins = 2 if self.spin == "collinear" else 1
        bands = {}
        for section, bandfile in zip(self.band_sections, bandfiles):
            ev, occ, rocc, revs = [], [], [], []
            for s in range(nspins):
                # Columns: index, k1, k2, k3, occ, ev, occ, ev ...
                # The leading index column is dropped right away.
                data = np.loadtxt(bandfile[s])[:, 1:]
                points = data[:, :3]
                # NOTE(review): these stop values exclude the last
                # (occupation, eigenvalue) column pair, i.e. the last band of
                # the file is not read - confirm that this is intended.
                occupations = data[:, 3:-2:2]
                eigenvalues = data[:, 4:-1:2]
                # Reversed copies describe the section walked backwards.
                rpoints = points[::-1].copy()
                ev.append(eigenvalues)
                revs.append(eigenvalues[::-1].copy())
                occ.append(occupations)
                rocc.append(occupations[::-1].copy())
            ev = np.stack(ev, axis=1)
            occ = np.stack(occ, axis=1)
            rocc = np.stack(rocc, axis=1)
            revs = np.stack(revs, axis=1)
            # The k-points are taken from the last spin channel read.
            bands[(section.symbol1, section.symbol2)] = band(points, occ, ev)
            bands[(section.symbol2, section.symbol1)] = band(rpoints, rocc, revs)
        return bands

    def set_spectrum(self, bandpath=None, reference=None):
        """Build the band spectrum along a band path and store it.

        The result is stored as ``self._spectrum``; nothing is returned.

        Args:
            bandpath: If not None, passed to ``self.set_bandpath`` before the
                spectrum is assembled. Defaults to None.
            reference: Passed to ``self.set_energy_reference`` together with
                ``self.soc``. Defaults to None.

        Raises:
            KeyError: If the path contains a pair of labels for which no
                band section was read.
        """
        bands = self.bands
        if bandpath is not None:
            self.set_bandpath(bandpath)
        bp = parse_path_string(self.bandpath.path)

        jumps = []
        kps = []
        occs = []
        kpoint_axis = []
        kpoint_labels = []
        label_coords = []
        spectrum = []
        # 2*pi times the transposed pseudo-inverse of the cell; presumably
        # converts the fractional k-points of the band files to Cartesian
        # coordinates - confirm against the band file convention.
        icell_cv = 2 * np.pi * np.linalg.pinv(self.structure.cell).T

        for segment in bp:
            kpoint_labels.append(segment[0])
            # A new segment starts where the previous one ended on the axis.
            label_coords.append(label_coords[-1] if label_coords else 0.0)
            for s1, s2 in zip(segment[:-1], segment[1:]):
                section = bands[(s1, s2)]
                kpoints = np.dot(section.kpoints, icell_cv)
                # Points are spread evenly over the straight-line distance
                # between the first and last k-point of the section.
                kstep = np.linalg.norm(kpoints[-1, :] - kpoints[0, :])
                kaxis = np.linspace(0, kstep, kpoints.shape[0]) + label_coords[-1]
                kstep += label_coords[-1]
                spectrum.append(section.eigenvalues)
                kps.append(kpoints)
                occs.append(section.occupations)
                kpoint_axis.append(kaxis)
                label_coords.append(kstep)
                kpoint_labels.append(s2)
            jumps.append(label_coords[-1])
        # The end of the last segment is the end of the axis, not a jump.
        jumps = jumps[:-1]

        spectrum = np.concatenate(spectrum, axis=0)
        kps = np.concatenate(kps, axis=0)
        kpoint_axis = np.concatenate(kpoint_axis, axis=0)
        occs = np.concatenate(occs, axis=0)

        fermi_level = self.fermi_level.soc if self.soc else self.fermi_level.scalar
        self.set_energy_reference(reference, self.soc)
        reference, shift = self.energy_reference
        self._spectrum = BandSpectrum(
            atoms=self.structure.atoms,
            kpoints=kps,
            kpoint_axis=kpoint_axis,
            eigenvalues=spectrum,
            occupations=occs,
            label_coords=label_coords,
            kpoint_labels=kpoint_labels,
            jumps=jumps,
            fermi_level=fermi_level,
            reference=reference,
            shift=shift,
            bandpath=bandpath,
        )

    @property
    def spectrum(self):
        """The stored band spectrum; built with default settings if missing."""
        # Identity check: avoids invoking any __eq__ of the spectrum object.
        if self._spectrum is None:
            self.set_spectrum(bandpath=None, reference=None)
        return self._spectrum

    def get_spectrum(self, bandpath=None, reference=None):
        """Rebuild the spectrum for the given settings and return it.

        Args:
            bandpath: See :meth:`set_spectrum`. Defaults to None.
            reference: See :meth:`set_spectrum`. Defaults to None.

        Returns:
            The spectrum object stored by :meth:`set_spectrum`.
        """
        self.set_spectrum(reference=reference, bandpath=bandpath)
        return self.spectrum

    def _process_kwargs(self, kwargs):
        """Return a cleaned copy of the plotting keywords.

        Deprecated keywords are removed with a warning and ``spin`` is
        replaced by the result of ``self.spin2index``. The dictionary passed
        in is not modified.
        """
        kwargs = kwargs.copy()
        spin = kwargs.pop("spin", None)

        deprecated = ["title", "mark_fermi_level", "mark_band_gap"]
        for dep in deprecated:
            if dep in kwargs:
                kwargs.pop(dep)
                logger.warning(
                    f"Keyword {dep} is deprecated. Please do not use this anymore."
                )

        kwargs["spin"] = self.spin2index(spin)
        return kwargs

    def plot(self, axes=None, **kwargs):
        """Main function to handle plotting of band structures.

        Keywords other than ``bandpath`` and ``reference`` are forwarded to
        ``AxesContext`` and
        :class:`~aimstools.bandstructures.utilities.BandStructurePlot`.

        Example:
            >>> from aimstools.bandstructures import RegularBandStructure as RBS
            >>> bs = RBS("path/to/dir")
            >>> bs.plot()

        Args:
            axes (matplotlib.axes.Axes): Axes to draw on, defaults to None.
            figsize (tuple): Figure size in inches. Defaults to (5,5).
            filename (str): Saves figure to file. Defaults to None.
            spin (int): Spin channel, can be "up", "dn", 0 or 1. Defaults to 0.
            bandpath (str): Band path for plotting of form "GMK,GA".
            reference (str): Energy reference for plotting, e.g., "VBM", "middle", "fermi level". Defaults to None.
            show_fermi_level (bool): Show Fermi level. Defaults to True.
            fermi_level_color (str): Color of Fermi level line. Defaults to fermi_color.
            fermi_level_alpha (float): Alpha channel of Fermi level line. Defaults to 1.0.
            fermi_level_linestyle (str): Line style of Fermi level line. Defaults to "--".
            fermi_level_linewidth (float): Line width of Fermi level line. Defaults to mpllinewidth.
            show_grid_lines (bool): Show grid lines for axes ticks. Defaults to True.
            grid_lines_axes (str): Show grid lines for given axes. Defaults to "x".
            grid_linestyle (tuple): Grid lines linestyle. Defaults to (0, (1, 1)).
            grid_linewidth (float): Width of grid lines. Defaults to 1.0.
            grid_linecolor (str): Grid lines color. Defaults to mutedblack.
            show_jumps (bool): Show jumps between Brillouin zone sections by darker vertical lines. Defaults to True.
            jumps_linewidth (float): Width of jump lines. Defaults to mpllinewidth.
            jumps_linestyle (str): Line style of the jump lines. Defaults to "-".
            jumps_linecolor (str): Color of the jump lines. Defaults to mutedblack.
            show_bandstructure (bool): Show band structure lines. Defaults to True.
            bands_color (str): Color of the band structure lines. Synonymous with color. Defaults to mutedblack.
            bands_linewidth (float): Line width of band structure lines. Synonymous with linewidth. Defaults to mpllinewidth.
            bands_linestyle (str): Band structure lines linestyle. Synonymous with linestyle. Defaults to "-".
            bands_alpha (float): Band structure lines alpha channel. Synonymous with alpha. Defaults to 1.0.
            show_bandgap_vertices (bool): Show direct and indirect band gap transitions. Defaults to True.
            window (tuple): Window on energy axis, can be float or tuple of two floats in eV. Defaults to 3 eV.
            y_tick_locator (float): Places ticks on energy axis on regular intervals. Defaults to 0.5 eV.

        Returns:
            axes: Axes object.
        """
        kwargs = self._process_kwargs(kwargs)
        bandpath = kwargs.pop("bandpath", None)
        reference = kwargs.pop("reference", None)

        spectrum = self.get_spectrum(bandpath=bandpath, reference=reference)

        with AxesContext(ax=axes, **kwargs) as axes:
            bs = BandStructurePlot(ax=axes, spectrum=spectrum, **kwargs)
            bs.draw()

        return axes