# Copyright (c) Scanlon Materials Theory Group
# Distributed under the terms of the MIT License.
"""
Subpackage providing helper functions for generating publication ready plots.
"""
import os
from functools import wraps
import matplotlib.pyplot
import numpy as np
from matplotlib import rcParams
from matplotlib.collections import LineCollection
try:
from importlib.resources import files as ilr_files
except ImportError: # Python < 3.9
from importlib_resources import files as ilr_files
# Module-level cache dictionary.
# NOTE(review): not referenced by the helpers defined in this module — confirm
# where (and whether) it is populated.
colour_cache = {}

# Locations of the matplotlib style sheets bundled inside the
# ``sumo.plotting`` package, resolved via importlib.resources.
sumo_base_style = ilr_files("sumo.plotting").joinpath("sumo_base.mplstyle")
sumo_dos_style = ilr_files("sumo.plotting").joinpath("sumo_dos.mplstyle")
sumo_bs_style = ilr_files("sumo.plotting").joinpath("sumo_bs.mplstyle")
sumo_phonon_style = ilr_files("sumo.plotting").joinpath("sumo_phonon.mplstyle")
sumo_optics_style = ilr_files("sumo.plotting").joinpath("sumo_optics.mplstyle")
[docs]
def styled_plot(*style_sheets):
    """Return a decorator that will apply matplotlib style sheets to a plot.

    ``style_sheets`` is a base set of styles, which will be ignored if
    ``no_base_style`` is set in the decorated function arguments.
    The style will further be overwritten by any styles in the ``style``
    optional argument of the decorated function.

    The decorated function gains three keyword-only arguments, which are
    consumed by the decorator and not forwarded:

    - ``fonts``: font family name(s) used as ``font.sans-serif`` (with
      ``font.family`` set to ``"sans-serif"``).
    - ``style``: an extra style, or a list/tuple of styles, applied after the
      base styles.
    - ``no_base_style``: if ``True``, skip ``style_sheets``.

    Styles are applied with :obj:`matplotlib.pyplot.style.use`, which updates
    the global ``rcParams``; the settings therefore remain active after the
    decorated function returns.

    Args:
        style_sheets (:obj:`str`, :obj:`dict`, or path-like): One or more
            matplotlib-supported style definitions, given as separate
            positional arguments.
    """

    def decorator(get_plot):
        @wraps(get_plot)
        def wrapper(*args, fonts=None, style=None, no_base_style=False, **kwargs):
            list_style = [] if no_base_style else list(style_sheets)

            if style is not None:
                # A list *or tuple* holds several styles; anything else
                # (str, dict, Path) is a single style. Previously a tuple was
                # appended as one element, which style.use cannot apply.
                if isinstance(style, (list, tuple)):
                    list_style.extend(style)
                else:
                    list_style.append(style)

            if fonts is not None:
                list_style.append(
                    {"font.family": "sans-serif", "font.sans-serif": fonts}
                )

            matplotlib.pyplot.style.use(list_style)
            return get_plot(*args, **kwargs)

        return wrapper

    return decorator
[docs]
def pretty_plot(width=None, height=None, plt=None, dpi=None):
    """Get a :obj:`matplotlib.pyplot` object with publication ready defaults.

    A new figure containing a single set of axes is created only when ``plt``
    is not supplied; a supplied ``plt`` is returned unchanged so callers can
    keep drawing on their current figure.

    Args:
        width (:obj:`float`, optional): The width of the plot. Defaults to the
            width in ``rcParams["figure.figsize"]``.
        height (:obj:`float`, optional): The height of the plot. Defaults to
            the height in ``rcParams["figure.figsize"]``.
        plt (:obj:`matplotlib.pyplot`, optional): A :obj:`matplotlib.pyplot`
            object to use for plotting.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the new figure. Defaults to ``rcParams["figure.dpi"]``.

    Returns:
        :obj:`matplotlib.pyplot`: A :obj:`matplotlib.pyplot` object with
        publication ready defaults set.
    """
    if plt is not None:
        return plt

    plt = matplotlib.pyplot
    if width is None:
        width = rcParams["figure.figsize"][0]
    if height is None:
        height = rcParams["figure.figsize"][1]

    # Pass dpi to this figure only instead of writing rcParams["figure.dpi"],
    # so the value does not leak into every figure created afterwards
    # (pretty_subplot likewise passes dpi per figure). dpi=None falls back to
    # rcParams["figure.dpi"].
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    fig.add_subplot(1, 1, 1)
    return plt
[docs]
def pretty_subplot(
    nrows,
    ncols,
    width=None,
    height=None,
    sharex=True,
    sharey=True,
    dpi=None,
    plt=None,
    gridspec_kw=None,
):
    """Get a :obj:`matplotlib.pyplot` subplot object with pretty defaults.

    An ``nrows`` x ``ncols`` grid of axes on a white background is created
    only when ``plt`` is not supplied; a supplied ``plt`` is returned as-is.

    Args:
        nrows (int): The number of rows in the subplot.
        ncols (int): The number of columns in the subplot.
        width (:obj:`float`, optional): The width of the plot. Defaults to the
            width in ``rcParams["figure.figsize"]``.
        height (:obj:`float`, optional): The height of the plot. Defaults to
            the height in ``rcParams["figure.figsize"]``.
        sharex (:obj:`bool`, optional): All subplots share the same x-axis.
            Defaults to ``True``.
        sharey (:obj:`bool`, optional): All subplots share the same y-axis.
            Defaults to ``True``.
        dpi (:obj:`int`, optional): The dots-per-inch (pixel density) for
            the plot.
        plt (:obj:`matplotlib.pyplot`, optional): A :obj:`matplotlib.pyplot`
            object to use for plotting.
        gridspec_kw (:obj:`dict`, optional): Gridspec parameters. Please see:
            :obj:`matplotlib.pyplot.subplot` for more information. Defaults
            to ``None``.

    Returns:
        :obj:`matplotlib.pyplot`: A :obj:`matplotlib.pyplot` subplot object
        with publication ready defaults set.
    """
    fig_width = rcParams["figure.figsize"][0] if width is None else width
    fig_height = rcParams["figure.figsize"][1] if height is None else height

    # TODO: Make this work if plt is already set...
    if plt is not None:
        return plt

    pyplot = matplotlib.pyplot
    pyplot.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=sharex,
        sharey=sharey,
        dpi=dpi,
        figsize=(fig_width, fig_height),
        facecolor="w",
        gridspec_kw=gridspec_kw,
    )
    return pyplot
[docs]
def curry_power_tick(times_sign=r"\times"):
    """Build a two-argument tick formatter bound to a multiplication sign.

    Args:
        times_sign (:obj:`str`, optional): Mathtext command used between the
            coefficient and the power of ten.

    Returns:
        callable: A function ``formatter(val, pos)`` that forwards to
        :func:`power_tick` with the given ``times_sign``.
    """

    def _formatter(val, pos):
        return power_tick(val, pos, times_sign=times_sign)

    return _formatter
[docs]
def power_tick(val, pos, times_sign=r"\times"):
    r"""Custom power ticker function.

    Formats a tick value as ``coefficient × 10^exponent`` in mathtext, with
    the coefficient normalised to ``1 <= |coefficient| < 10``.

    Args:
        val (float): The tick value.
        pos (int): The tick position. Unused; present for the
            ``f(val, pos)`` formatter signature.
        times_sign (:obj:`str`, optional): Mathtext command used as the
            multiplication sign.

    Returns:
        str: A mathtext string such as ``$5\mathrm{\times}10^{-2}$``, or
        ``$\mathregular{0}$`` for zero.
    """
    if val == 0:
        return r"$\mathregular{0}$"

    # floor, not int(): int() truncates toward zero, which for |val| < 1 gave
    # coefficients below 1 (e.g. 0.05 -> 0.5 x 10^-1 instead of 5 x 10^-2).
    exponent = int(np.floor(np.log10(abs(val))))
    coeff = val / 10**exponent

    # Guard against log10 rounding error landing just across a power of ten.
    if abs(coeff) >= 10:
        coeff /= 10
        exponent += 1
    elif abs(coeff) < 1:
        coeff *= 10
        exponent -= 1

    # Print integral coefficients without a decimal; tolerate float noise
    # (e.g. 2.0000000000000004) that an exact ``% 1 == 0`` test rejects.
    prec = 0 if np.isclose(coeff, round(coeff)) else 1
    return rf"${coeff:.{prec}f}\mathrm{{{times_sign}}}10^{{{exponent:2d}}}$"
[docs]
def colorline(
    x,
    y,
    weights,
    color1="#FF0000",
    color2="#00FF00",
    color3="#0000FF",
    colorspace="lab",
    linestyles="solid",
    linewidth=2.5,
):
    """Get a RGB coloured line for plotting.

    Each segment between consecutive points is coloured by interpolating
    ``color1``/``color2``/``color3`` with the mean of the weights at the
    segment's two endpoints.

    Args:
        x (list): x-axis data.
        y (list): y-axis data (can be multidimensional array).
        weights (list): The weights of the color1, color2, and color3 channels.
            Given as an array with the shape (n, 3), where n is the same length
            as the x and y data. For multidimensional ``y``, one such array
            per row of ``y``.
        color1 (str): A color specified in any way supported by matplotlib.
        color2 (str): A color specified in any way supported by matplotlib.
        color3 (str): A color specified in any way supported by matplotlib.
        colorspace (str): The colorspace in which to perform the interpolation.
            The allowed values are rgb, hsv, lab, luvlch, lablch, and xyz.
        linestyles (:obj:`str`, optional): Linestyle for plot. Options are
            ``"solid"`` or ``"dotted"``.
        linewidth (:obj:`float`, optional): Width of the line.

    Returns:
        :obj:`matplotlib.collections.LineCollection`: The coloured segments
        (rasterized).
    """
    y = np.array(y)
    # Always an ndarray: with a nested Python list, ``ww[:-1] + ww[1:]`` would
    # concatenate lists rather than add the weights element-wise.
    weights = np.asarray(weights)
    if y.ndim == 1:
        # Promote a single line to a batch of one so the loop handles both.
        y = y[np.newaxis]
        weights = weights[np.newaxis]

    segments = []
    colours = []
    for yy, ww in zip(y, weights):
        pts = np.array([x, yy]).T.reshape(-1, 1, 2)
        npts = len(pts)
        # A segment joins two consecutive points, so at least two are needed.
        if npts > 1:
            segments.extend(np.concatenate([pts[:-1], pts[1:]], axis=1))
            # Segment weight = mean of the weights at its two endpoints.
            seg_weights = 0.5 * (ww[: npts - 1] + ww[1:npts])
            c = get_interpolated_colors(
                color1, color2, color3, seg_weights, colorspace=colorspace
            )
            colours.extend(c.tolist())

    return LineCollection(
        segments,
        colors=colours,
        rasterized=True,
        linewidth=linewidth,
        linestyles=linestyles,
    )
[docs]
def get_interpolated_colors(color1, color2, color3, weights, colorspace="lab"):
    """
    Interpolate colors at a number of points within a colorspace.

    The three basis colours are converted into ``colorspace``. Each output
    colour is their weighted sum there, converted back to sRGB.

    Args:
        color1 (str): A color specified in any way supported by matplotlib.
        color2 (str): A color specified in any way supported by matplotlib.
        color3 (str): A color specified in any way supported by matplotlib.
        weights (list): A list of weights with the shape (n, 3).
            Where the 3 values of the last axis give the amount of
            color1, color2, and color3. A single triple of shape (3,) is
            treated as n = 1.
        colorspace (str): The colorspace in which to perform the interpolation.
            The allowed values are rgb, hsv, lab, luvlch, lablch, and xyz.

    Returns:
        A list of colors, specified in the rgb format as a (n, 3) array with
        every value clipped to the range [0, 1].

    Raises:
        ValueError: If ``colorspace`` is not one of the allowed values.
    """
    # Imported inside the function, so colormath is only needed when this
    # function is actually called.
    from colormath.color_conversions import convert_color
    from colormath.color_objects import (
        HSVColor,
        LabColor,
        LCHabColor,
        LCHuvColor,
        XYZColor,
        sRGBColor,
    )
    from matplotlib.colors import to_rgb

    colorspace_mapping = {
        "rgb": sRGBColor,
        "hsv": HSVColor,
        "lab": LabColor,
        "luvlch": LCHuvColor,
        "lablch": LCHabColor,
        "xyz": XYZColor,
    }
    if colorspace not in colorspace_mapping:
        raise ValueError(f"colorspace must be one of {colorspace_mapping.keys()}")
    color_class = colorspace_mapping[colorspace]

    def _to_basis(color):
        # matplotlib colour spec -> colormath sRGB -> interpolation colorspace
        rgb = sRGBColor(*to_rgb(color))
        converted = convert_color(rgb, color_class, target_illuminant="d50")
        return np.array(converted.get_value_tuple())

    basis1, basis2, basis3 = (_to_basis(c) for c in (color1, color2, color3))

    # Promote a single (3,) weight triple to shape (1, 3).
    weights = np.atleast_2d(np.asarray(weights))

    # perform the interpolation in the colorspace basis
    colors = (
        basis1 * weights[:, 0][:, None]
        + basis2 * weights[:, 1][:, None]
        + basis3 * weights[:, 2][:, None]
    )

    rgb_colors = [
        convert_color(color_class(*c), sRGBColor).get_value_tuple() for c in colors
    ]
    # Interpolating outside RGB can land slightly outside the sRGB gamut on
    # either side (> 1 or < 0). Matplotlib rejects values outside [0, 1], so
    # clip both ends rather than only the upper one.
    return np.clip(rgb_colors, 0, 1)
[docs]
def draw_themed_line(y, ax, orientation="horizontal", **kwargs):
    """Draw a dashed horizontal or vertical reference line using theme settings.

    The colour comes from ``rcParams["grid.color"]`` and the width from
    ``rcParams["ytick.major.width"]``. The line is dashed (5 on, 2 off) and
    drawn at ``zorder=0``.

    Args:
        y (float): Position of line in data coordinates
        ax (Axes): Matplotlib Axes on which line is drawn
        orientation (str, optional): Orientation of line. Options are
            ``"horizontal"`` or ``"vertical"``.
        **kwargs: Additional keyword arguments passed to ``ax.axhline`` or
            ``ax.axvline``, which can be used to override the theme settings.

    Raises:
        ValueError: If ``orientation`` is not ``"horizontal"`` or
            ``"vertical"``.
    """
    # Note to future developers: feel free to add plenty more optional
    # arguments to this to mess with linestyle, zorder etc.
    # Caller-supplied kwargs are merged last, so they override the theme.
    line_kwargs = {
        "color": rcParams["grid.color"],
        "linestyle": "--",
        "dashes": (5, 2),
        "zorder": 0,
        "linewidth": rcParams["ytick.major.width"],
        **kwargs,
    }

    if orientation == "horizontal":
        draw = ax.axhline
    elif orientation == "vertical":
        draw = ax.axvline
    else:
        raise ValueError(f'Line orientation "{orientation}" not supported')

    draw(y, **line_kwargs)