from aimstools.misc import *
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from ase.dft.kpoints import parse_path_string, BandPath
from ase.formula import Formula
from aimstools.density_of_states.utilities import gradient_fill
class PhononSpectrum:
    """Container for a phonon band-structure spectrum and its plotting metadata.

    Every constructor argument is optional and stored unchanged; each one is
    exposed through a read-only property of the same name.

    Attributes:
        atoms (ase.atoms.Atoms): ASE atoms object.
        qpoints (ndarray): (nqpoints, 3) array with q-points.
        qpoint_axis (ndarray): (nqpoints, 1) linear plotting axis.
        frequencies (ndarray): (nqpoints, nbands) array with frequencies.
        label_coords (list): Positions of the q-point labels on the plotting axis.
        qpoint_labels (list): Q-point labels.
        jumps (list): Positions of jumps between unconnected Brillouin-zone
            path sections on the plotting axis.
        unit (str): Unit of the frequencies.
        bandpath (str): Band path; stored as given and shown in ``repr``.
    """

    def __init__(
        self,
        atoms: "ase.atoms.Atoms" = None,
        qpoints: "numpy.ndarray" = None,
        qpoint_axis: "numpy.ndarray" = None,
        frequencies: "numpy.ndarray" = None,
        label_coords: list = None,
        qpoint_labels: list = None,
        jumps: list = None,
        unit: str = None,
        bandpath: str = None,
    ) -> None:
        # Structure and sampling.
        self._atoms = atoms
        self._qpoints = qpoints
        self._qpoint_axis = qpoint_axis
        self._frequencies = frequencies
        # Plot annotations.
        self._label_coords = label_coords
        self._qpoint_labels = qpoint_labels
        self._jumps = jumps
        # Metadata.
        self._unit = unit
        self._bandpath = bandpath

    def __repr__(self):
        return f"{type(self).__name__}(bandpath={self.bandpath}, unit={self.unit})"

    @property
    def atoms(self):
        """ASE atoms object (or None)."""
        return self._atoms

    @property
    def qpoints(self):
        """Array of q-points (or None)."""
        return self._qpoints

    @property
    def qpoint_axis(self):
        """Linear plotting axis along the q-point path (or None)."""
        return self._qpoint_axis

    @property
    def frequencies(self):
        """Frequencies per q-point and band (or None)."""
        return self._frequencies

    @property
    def label_coords(self):
        """Label positions on the plotting axis (or None)."""
        return self._label_coords

    @property
    def qpoint_labels(self):
        """Q-point labels (or None)."""
        return self._qpoint_labels

    @property
    def jumps(self):
        """Jump positions on the plotting axis (or None)."""
        return self._jumps

    @property
    def unit(self):
        """Frequency unit string (or None)."""
        return self._unit

    @property
    def bandpath(self):
        """Band path as passed to the constructor (or None)."""
        return self._bandpath
class PhononDOS:
    """Container for a phonon density-of-states (DOS) spectrum.

    Every constructor argument is optional and stored unchanged; each one is
    exposed through a read-only property of the same name.

    Attributes:
        atoms (ase.atoms.Atoms): ASE atoms object.
        frequencies (ndarray): Frequency grid on which the DOS is given.
        contributions (ndarray): DOS values; presumably one value per entry of
            ``frequencies`` (they are plotted against each other) — TODO
            confirm against the code that builds this object.
        unit (str): Unit of the frequencies.
    """

    def __init__(
        self,
        atoms: "ase.atoms.Atoms" = None,
        frequencies: "numpy.ndarray" = None,
        contributions: "numpy.ndarray" = None,
        unit: str = None,
    ) -> None:
        self._atoms = atoms
        self._frequencies = frequencies
        self._contributions = contributions
        self._unit = unit

    def __repr__(self):
        return "{}(unit={})".format(self.__class__.__name__, self.unit)

    @property
    def atoms(self):
        """ASE atoms object (or None)."""
        return self._atoms

    @property
    def frequencies(self):
        """Frequency grid of the DOS (or None)."""
        return self._frequencies

    @property
    def contributions(self):
        """DOS values (or None)."""
        return self._contributions

    @property
    def unit(self):
        """Frequency unit string (or None)."""
        return self._unit
class PhononPlot:
    """Context to draw a phonon band structure onto a matplotlib axes.

    Args:
        main (bool): Only when True are grid lines, jump lines, bands and the
            acoustic-band overlay drawn; axis labels, limits and ticks are
            always applied.

    Keyword Args:
        ax (matplotlib.axes.Axes): Target axes; must be a single axes, not a list.
        spectrum (PhononSpectrum): Provides q-point axis, frequencies, labels,
            label coordinates, jumps and unit.
        show_grid_lines (bool): Draw grid lines (default True).
        show_grid_lines_axes (str): ``axis`` argument for ``Axes.grid`` (default "x").
        grid_linestyle, grid_linewidth, grid_linecolor: Grid line style.
        show_jumps (bool): Draw vertical lines at path jumps (default True).
        jumps_linewidth, jumps_linestyle, jumps_linecolor: Jump line style.
        show_bandstructure (bool): Draw the bands (default True).
        bands_color / color: Band color; ``bands_color`` takes precedence.
        bands_linewidth / linewidth: Band line width; ``bands_linewidth`` takes precedence.
        bands_linestyle / linestyle: Band line style; ``bands_linestyle`` takes precedence.
        bands_alpha / alpha: Band opacity; ``bands_alpha`` takes precedence.
        show_acoustic_bands (bool): Redraw the first three bands in their own
            color (default True).
        acoustic_bands_color: Color for that overlay (default "royalblue").
        y_tick_locator (float): Major tick spacing on the frequency axis (default 100).
    """

    def __init__(self, main=True, **kwargs) -> None:
        self.ax = kwargs.get("ax", None)
        assert not isinstance(
            self.ax, list
        ), "Axes object must be a single matplotlib.axes.Axes, not list."
        self.spectrum = kwargs.get("spectrum", None)
        self.set_data_from_spectrum()

        default_linewidth = plt.rcParams["lines.linewidth"]

        self.show_grid_lines = kwargs.get("show_grid_lines", True)
        self.grid_lines_axes = kwargs.get("show_grid_lines_axes", "x")
        self.grid_linestyle = kwargs.get("grid_linestyle", (0, (1, 1)))
        self.grid_linewidth = kwargs.get("grid_linewidth", 1.0)
        self.grid_linecolor = kwargs.get("grid_linecolor", mutedblack)

        self.show_jumps = kwargs.get("show_jumps", True)
        self.jumps_linewidth = kwargs.get("jumps_linewidth", default_linewidth)
        self.jumps_linestyle = kwargs.get("jumps_linestyle", "-")
        self.jumps_linecolor = kwargs.get("jumps_linecolor", mutedblack)

        self.show_bandstructure = kwargs.get("show_bandstructure", True)
        # The specific "bands_*" keys win over the generic aliases. Previously
        # the generic lookup unconditionally overwrote the specific one, so
        # e.g. bands_color=... was silently ignored.
        self.bands_color = self._first_kwarg(
            kwargs, ("bands_color", "color"), mutedblack
        )
        self.bands_linewidth = self._first_kwarg(
            kwargs, ("bands_linewidth", "linewidth"), default_linewidth
        )
        self.bands_linestyle = self._first_kwarg(
            kwargs, ("bands_linestyle", "linestyle"), "-"
        )
        self.bands_alpha = self._first_kwarg(kwargs, ("bands_alpha", "alpha"), 1.0)

        self.show_acoustic_bands = kwargs.get("show_acoustic_bands", True)
        self.acoustic_bands_color = kwargs.get("acoustic_bands_color", "royalblue")
        self.y_tick_locator = kwargs.get("y_tick_locator", 100)

        self.set_xy_axes_labels()
        self.set_qpoint_labels()
        self.set_x_limits()
        self.main = main

    @staticmethod
    def _first_kwarg(kwargs, keys, default):
        """Return the value of the first key in *keys* present in *kwargs*, else *default*."""
        for key in keys:
            if key in kwargs:
                return kwargs[key]
        return default

    def set_data_from_spectrum(self):
        """Copy labels, label coordinates, jumps, axis, frequencies and unit from ``self.spectrum``."""
        spectrum = self.spectrum
        self.labels = spectrum.qpoint_labels.copy()
        self.labelcoords = spectrum.label_coords.copy()
        self.jumps = spectrum.jumps.copy()
        self.x = spectrum.qpoint_axis.copy()
        self.y = spectrum.frequencies.copy()
        self.unit = spectrum.unit

    def draw(self):
        """Draw the band structure and its decorations onto ``self.ax``."""
        self.ax.yaxis.set_major_locator(
            ticker.MultipleLocator(base=self.y_tick_locator)
        )
        labelsize = plt.rcParams["axes.labelsize"]
        self.ax.set_xlabel(self.xlabel, fontsize=labelsize)
        self.ax.set_ylabel(self.ylabel, fontsize=labelsize)
        self.ax.set_xlim(self.xlimits)
        self.ax.set_xticks(self.xlabelcoords)
        self.ax.set_xticklabels(self.xlabels, fontsize=labelsize)
        # Q-point labels are shown without tick marks.
        self.ax.tick_params(axis="x", which="both", length=0)
        if not self.main:
            return
        if self.show_grid_lines:
            # Visibility is passed positionally: the keyword was called ``b``
            # in older matplotlib and ``visible`` in newer releases, which no
            # longer accept ``b=``.
            self.ax.grid(
                self.show_grid_lines,
                which="major",
                axis=self.grid_lines_axes,
                linestyle=self.grid_linestyle,
                linewidth=self.grid_linewidth,
                color=self.grid_linecolor,
            )
        if self.show_jumps:
            for jump in self.jumps:
                self.ax.axvline(
                    x=jump,
                    linestyle=self.jumps_linestyle,
                    color=self.jumps_linecolor,
                    linewidth=self.jumps_linewidth,
                )
        if self.show_bandstructure:
            self.ax.plot(
                self.x,
                self.y,
                color=self.bands_color,
                alpha=self.bands_alpha,
                linewidth=self.bands_linewidth,
                linestyle=self.bands_linestyle,
            )
        if self.show_acoustic_bands:
            self._show_accoustic()

    def set_xy_axes_labels(self):
        """Set an empty x label and a frequency y label carrying ``self.unit``."""
        self.xlabel = ""
        self.ylabel = "frequency [{}]".format(self.unit)

    def set_qpoint_labels(self):
        """Build ``xlabels``/``xlabelcoords`` from ``labels``/``labelcoords``.

        "G" becomes a LaTeX Gamma, two-character labels get a subscript
        second character, and consecutive labels at the same coordinate are
        merged into one "A|B" label. ``labels`` and ``labelcoords`` are left
        unmodified, and ``labelcoords`` may be any sequence (list or ndarray).
        """

        def pretty(kpt):
            if kpt == "G":
                return r"$\Gamma$"
            if len(kpt) == 2:
                return kpt[0] + "$_" + kpt[1] + "$"
            return kpt

        merged_labels = []
        merged_coords = []
        for label, coord in zip(self.labels, list(self.labelcoords)):
            label = pretty(label)
            if merged_coords and merged_coords[-1] == coord:
                # Same position as the previous label: path segments meet here.
                merged_labels[-1] = merged_labels[-1] + "|" + label
            else:
                merged_labels.append(label)
                merged_coords.append(coord)
        self.xlabels = merged_labels
        self.xlabelcoords = merged_coords

    def set_x_limits(self):
        """Set ``xlimits`` to span from 0 to the end of the q-point axis."""
        self.xlimits = (0.0, np.max(self.x))

    def _show_accoustic(self):
        """Overlay the first three band columns in ``acoustic_bands_color``.

        Presumably these are the acoustic branches, i.e. frequencies are sorted
        ascending per q-point — TODO confirm.
        """
        self.ax.plot(
            self.x,
            self.y[:, :3],
            color=self.acoustic_bands_color,
            linewidth=self.bands_linewidth,
            linestyle=self.bands_linestyle,
            alpha=self.bands_alpha,
        )
class PhononDOSPlot:
    """Context to draw a phonon DOS plot. Handles axis labels, limits and ticks.

    Args:
        main (bool): Only when True are grid lines drawn.
        dos (PhononDOS): Provides ``frequencies``, ``contributions`` and ``unit``.

    Keyword Args:
        ax (matplotlib.axes.Axes): Target axes; must be a single axes, not a list.
        flip_axes (bool): True (default) puts the DOS on x and frequency on y.
        dos_linewidth, dos_linestyle: DOS line style.
        show_grid_lines (bool): Draw grid lines (default False).
        show_grid_lines_axes (str): ``axis`` argument for ``Axes.grid`` (default "x").
        grid_linestyle, grid_linewidth, grid_linecolor: Grid line style.
        color: DOS line (and fill) color.
        energy_tick_locator (float): Major tick spacing on the frequency axis (default 100).
        dos_tick_locator ("auto" | int | float): Major tick spacing on the DOS
            axis; "auto" gives roughly three ticks over the DOS window.
        fill (str): "gradient" (default) adds a gradient fill under the curve.
    """

    def __init__(
        self,
        main: bool = True,
        dos: "aimstools.phonons.utilities.PhononDOS" = None,
        **kwargs
    ) -> None:
        self.ax = kwargs.get("ax", None)
        assert not isinstance(
            self.ax, list
        ), "Axes object must be a single matplotlib.axes.Axes, not list."
        self.dos = dos
        self.energies = self.dos.frequencies
        self.contributions = self.dos.contributions
        self.unit = self.dos.unit
        self.flip_axes = kwargs.get("flip_axes", True)
        self.main = main
        self.dos_linewidth = kwargs.get(
            "dos_linewidth", plt.rcParams["lines.linewidth"]
        )
        self.dos_linestyle = kwargs.get("dos_linestyle", "-")
        self.show_grid_lines = kwargs.get("show_grid_lines", False)
        self.grid_lines_axes = kwargs.get("show_grid_lines_axes", "x")
        self.grid_linestyle = kwargs.get("grid_linestyle", (0, (1, 1)))
        self.grid_linewidth = kwargs.get("grid_linewidth", 1.0)
        self.grid_linecolor = kwargs.get("grid_linecolor", mutedblack)
        self.color = kwargs.get("color", mutedblack)
        self.energy_tick_locator = kwargs.get("energy_tick_locator", 100)
        self.dos_tick_locator = kwargs.get("dos_tick_locator", "auto")
        self.fill = kwargs.get("fill", "gradient")
        # Order matters: the "auto" tick spacing needs the DOS window.
        self.set_dos_window()
        self.set_xy_axes_labels()
        self.set_dos_tick_locator()

    def set_dos_tick_locator(self):
        """Resolve ``dos_tick_locator``; "auto" gives about three ticks over the DOS window.

        Raises:
            AssertionError: if a non-"auto" value is not an int or float.
        """
        if self.dos_tick_locator == "auto":
            span = abs(self.upper_dos_limit - self.lower_dos_limit)
            spacing = round(span / 3, 1)
            if spacing <= 0:
                # Rounding to one decimal collapses small DOS windows to 0,
                # which is not a usable MultipleLocator base. Fall back to the
                # unrounded spacing, or 1.0 for an empty (all-zero) window.
                spacing = span / 3 if span > 0 else 1.0
            self.dos_tick_locator = spacing
        else:
            assert isinstance(
                self.dos_tick_locator, (int, float)
            ), "DOS tick locator must be int or float."

    def set_xy_axes_labels(self):
        """Set the DOS axis label and a frequency label carrying ``self.unit``."""
        self.dos_label = r"DOS"
        self.energy_label = "frequency [{}]".format(self.unit)

    def set_dos_window(self):
        """Set the DOS axis window from 0 to 5 % above the maximum DOS value."""
        self.lower_dos_limit = 0
        self.upper_dos_limit = np.max(self.contributions) * 1.05

    def _axis_layout(self):
        """Return ``((label, limits, tick_base) for x, (label, limits, tick_base) for y)``.

        The frequency axis spans the full frequency range; the DOS axis uses
        the window from ``set_dos_window``. ``flip_axes`` decides which goes on x.
        """
        energy_axis = (
            self.energy_label,
            (np.min(self.energies), np.max(self.energies)),
            self.energy_tick_locator,
        )
        dos_axis = (
            self.dos_label,
            (self.lower_dos_limit, self.upper_dos_limit),
            self.dos_tick_locator,
        )
        if self.flip_axes:
            return dos_axis, energy_axis
        return energy_axis, dos_axis

    def draw(self):
        """Draw the DOS curve (plus optional fill and grid) onto ``self.ax``."""
        energies = self.energies.copy()
        values = self.contributions.copy()
        (xlabel, xlimits, xbase), (ylabel, ylimits, ybase) = self._axis_layout()
        self.ax.xaxis.set_major_locator(ticker.MultipleLocator(base=xbase))
        self.ax.yaxis.set_major_locator(ticker.MultipleLocator(base=ybase))
        self.ax.set_xlabel(xlabel, fontsize=plt.rcParams["axes.labelsize"])
        self.ax.set_ylabel(ylabel, fontsize=plt.rcParams["axes.labelsize"])
        self.ax.set_xlim(xlimits)
        self.ax.set_ylim(ylimits)
        if self.flip_axes:
            x, y = values, energies
        else:
            x, y = energies, values
        self.ax.plot(
            x,
            y,
            color=self.color,
            linewidth=self.dos_linewidth,
            linestyle=self.dos_linestyle,
        )
        if self.fill == "gradient":
            self.ax = gradient_fill(x, y, self.ax, self.color, flip=self.flip_axes)
        if self.show_grid_lines and self.main:
            # Visibility is passed positionally: the keyword was called ``b``
            # in older matplotlib and ``visible`` in newer releases, which no
            # longer accept ``b=``.
            self.ax.grid(
                self.show_grid_lines,
                which="major",
                axis=self.grid_lines_axes,
                linestyle=self.grid_linestyle,
                linewidth=self.grid_linewidth,
                color=self.grid_linecolor,
            )