Source code for tp.plot.utilities

"""Utilities to aid the plotting scripts."""

#Functions
#---------
#
#    colour_scale:
#        sorts colour limits and colourbar format.
#    parse_colours:
#        returns a colour map from a range of input types.
#    scale_to_axis:
#        scales data to axes limits.
#    set_locators:
#        set locators one-liner.

import matplotlib as mpl
import numpy as np
import tp

def _sorted_finite_nonzero(c):
    """Return the finite, non-zero entries of c as a flat, ascending array."""

    masked = np.ma.masked_invalid(np.ma.masked_equal(c, 0))
    return np.sort(np.ravel(masked.compressed()))


def _percentile_value(sorted_values, percent):
    """Return the entry of an ascending array at the given percentile.

    The index is clamped at zero: for short arrays the rounded index is
    negative, and an unclamped negative index would wrap around and
    return the *largest* value instead of the smallest.
    """

    index = int(round(len(sorted_values) * percent / 100 - 1, 0))
    return sorted_values[max(index, 0)]


def colour_scale(c, name, cmap, cmin=None, cmax=None, cscale=None,
                 unoccupied='grey'):
    """Formats the colour scale for phono3py quantities.

    Attempts to set the limits to maximise visibility of the most
    important data and determines if the colourbar should be extended.

    * 'frequency' and 'heat_capacity' use a linear scale spanning the
      full (non-NaN) data range.
    * 'occupation' spans from the data minimum (or 1, see unoccupied)
      to the 99.9th percentile of the finite, non-zero data, and is
      linear if cmax/cmin < 10, logarithmic otherwise.
    * anything else uses a logarithmic scale between the 1st and
      99.9th percentiles of the finite, non-zero data.

    Arguments
    ---------

        c : array-like
            colour data.
        name : str
            name of colour variable.
        cmap : colormap
            colourmap. Its "under" colour is modified in place if the
            unoccupied feature is triggered.

        cmin : float, optional
            colour scale minimum. Default: display 99 % data.
        cmax : float, optional
            colour scale maximum. Default: display 99.9 % data.
        cscale : str, optional
            override colour scale (linear/ log). Any other value is
            ignored. Default: None.
        unoccupied : str, optional
            if the colour variable is occupation, values below 1 are
            coloured in this colour. If set to None, or cmin is set,
            this feature is turned off. Default: grey.

    Returns
    -------

        cnorm
            colour normalisation object (linear or logarithmic).
        extend : str
            which direction to extend the colourbar in: 'both', 'min',
            'max' or 'neither'.

    Raises
    ------

        IndexError
            if a percentile limit is needed but c contains no finite,
            non-zero values.
    """

    extend_min = False
    extend_max = False

    if name in ['frequency', 'heat_capacity']:
        if cmin is None:
            cmin = np.nanmin(c)
        elif cmin > np.nanmin(c):
            extend_min = True
        if cmax is None:
            cmax = np.nanmax(c)
        elif cmax < np.nanmax(c):
            extend_max = True
        cnorm = mpl.colors.Normalize(vmin=cmin, vmax=cmax)

    elif name == 'occupation':
        csort = _sorted_finite_nonzero(c)
        clim = _percentile_value(csort, 99.9)
        if cmin is None:
            cmin = np.nanmin(c)
            if unoccupied is not None and cmin < 1:
                # Everything below single occupation gets the flat
                # "unoccupied" colour rather than a place on the scale.
                cmin = 1
                extend_min = True
                cmap.set_under(unoccupied)
        elif cmin > np.nanmin(c):
            extend_min = True
        if cmax is None:
            # The default upper limit clips the top 0.1 % of the data.
            cmax = clim
            extend_max = True
        elif cmax < np.nanmax(c):
            extend_max = True
        if cmax / cmin < 10:
            cnorm = mpl.colors.Normalize(vmin=cmin, vmax=cmax)
        else:
            cnorm = mpl.colors.LogNorm(vmin=cmin, vmax=cmax)

    else:
        csort = _sorted_finite_nonzero(c)
        lower = _percentile_value(csort, 1)
        upper = _percentile_value(csort, 99.9)
        if cmin is None:
            cmin = lower
        elif cmin > lower:
            extend_min = True
        if cmax is None:
            cmax = upper
        elif cmax < upper:
            extend_max = True
        cnorm = mpl.colors.LogNorm(vmin=cmin, vmax=cmax)

    # An explicit cscale overrides whichever normalisation was chosen above.
    if cscale == 'linear':
        cnorm = mpl.colors.Normalize(vmin=cmin, vmax=cmax)
    elif cscale == 'log':
        cnorm = mpl.colors.LogNorm(vmin=cmin, vmax=cmax)

    extend_names = {(True, True): 'both',
                    (True, False): 'min',
                    (False, True): 'max',
                    (False, False): 'neither'}

    return cnorm, extend_names[(extend_min, extend_max)]
def parse_colours(colour):
    """Parses a range of inputs into a colourmap.

    The input is first offered to matplotlib as a colourmap (or
    colourmap name); a copy of the result is returned so the registered
    colourmap is never modified. If matplotlib rejects it with a
    ValueError, a colourmap is built with tp.plot.colour.uniform
    instead.

    Argument
    --------

        colour : colourmap or str or array-like or dict
            colourmap or colourmap name or #rrggbb highlight colour or
            highlight, min, max colours in that order, or dictionary
            with cmid and cmin and/or cmax keys.

    Returns
    -------

        colourmap
            colourmap

    Raises
    ------

        Exception
            if colour is none of the supported types.
    """

    from copy import copy

    try:
        return copy(mpl.cm.get_cmap(colour))
    except ValueError:
        # Not a colourmap matplotlib knows: build one from the colour(s).
        if isinstance(colour, mpl.colors.ListedColormap):
            return copy(colour)

        if isinstance(colour, str):
            return tp.plot.colour.uniform(colour)

        if isinstance(colour, list):
            try:
                # a single rgb colour given as an array
                return tp.plot.colour.uniform(colour)
            except ValueError:
                # a sequence of hex or named colours
                return tp.plot.colour.uniform(*colour)

        if isinstance(colour, dict):
            return tp.plot.colour.uniform(**colour)

        raise Exception('colour must be a colourmap, colourmap '
                        'name, single highlight colour or '
                        'highlight, min, max colours in that '
                        'order, or a dictionary with cmid and '
                        'cmin and/or cmax keys. Colour format must '
                        'be hex rgb (array) or a named colour '
                        'recognised by matplotlib. ')
def scale_to_axis(ax, data, exclude=(), scale=None, axis='y'):
    """Scale data to fit an axis.

    Maps the data linearly so that zero lands on the lower limit and
    the largest non-excluded value lands on the upper limit. On a log
    axis the mapping is done in log10 space and then exponentiated, so
    the data still appear linear on screen. The data are assumed to
    contain at least one positive value: the scale factor divides by
    the largest value found, which starts from zero.

    Arguments
    ---------

        ax : axes
            axes to fit to.
        data : array-like or dict
            data to scale. A dict is modified in place and returned.

        exclude : array-like, optional
            keys to exclude from scaling. Default: none.
        scale : array-like, optional
            force scaling limits to [min, max]. Accepts a list, tuple
            or numpy array. Default: axis limits.
        axis : str, optional
            axis to scale to (x or y). Default: y.

    Returns
    -------

        data
            scaled data, in the same form (bare or dict) as the input.
    """

    if scale is not None:
        assert isinstance(scale, (list, tuple, np.ndarray)) \
               and len(scale) == 2, \
               'if specified, scale must be a two element array-like [min,max]'
    assert axis in ['x', 'y'], 'axis must be x or y'

    # Treat bare data as a one-entry dict so both forms share one code
    # path; a flag (rather than a magic key) records that it must be
    # unwrapped again, so no user key can ever be mistaken for it.
    unwrap = not isinstance(data, dict)
    if unwrap:
        data = {'data': data}
        exclude = ()

    keys = [key for key in data if key not in exclude]
    if not keys:
        return data

    dmax = 0
    for key in keys:
        dmax = max(dmax, float(np.amax(data[key])))

    if axis == 'x':
        limits = ax.get_xlim()
        log = ax.get_xaxis().get_scale() == 'log'
    else:
        limits = ax.get_ylim()
        log = ax.get_yaxis().get_scale() == 'log'

    if scale is None:
        scale = limits
    if log:
        # work in exponent space so the result is linear on screen
        scale = np.log10(scale)

    factor = (scale[1] - scale[0]) / dmax
    for key in keys:
        scaled = np.add(scale[0], np.multiply(data[key], factor))
        if log:
            scaled = np.power(10, scaled)
        data[key] = scaled

    if unwrap:
        return data['data']
    return data
def set_locators(ax, x=None, y=None, dos=False):
    """Set locators quickly. If log, sets scale as well.

    Arguments
    ---------

        ax : axes
            axes to format

        x : str, optional
            locator on x axis. Accepts linear, log or null.
            Default: do nothing.
        y : str, optional
            locator on y axis. Accepts linear, log or null.
            Default: do nothing.
        dos : bool, optional
            removes axes ticks and ticklabels and y axis label. Runs
            first, so ticks can be reinstated on x, e.g. for waterfall
            plots using x='log'. Default: False.

    Returns
    -------

        None
            formats ax directly.
    """

    assert isinstance(dos, bool), 'dos must be True or False.'
    from tp.settings import locator

    # (major, minor) locator names per scale; None leaves that one untouched.
    locator_names = {'linear': ('major', 'minor'),
                     'log': ('log', None),
                     'null': ('null', 'null')}

    if dos:
        # locator() is called afresh for every assignment so that no
        # locator object is shared between axes.
        for axis_obj in (ax.xaxis, ax.yaxis):
            axis_obj.set_major_locator(locator()['null'])
            axis_obj.set_minor_locator(locator()['null'])
        ax.set_ylabel('')

    if not (x is None or x == 'null'):
        ax.set_xscale(x)
    if not (y is None or y == 'null'):
        ax.set_yscale(y)

    for axis_obj, kind in ((ax.xaxis, x), (ax.yaxis, y)):
        if kind is None:
            continue
        names = locator_names.get(kind)
        if names is None:
            raise Exception('x and y must be "linear", "log" or '
                            '"null" if specified.')
        major, minor = names
        axis_obj.set_major_locator(locator()[major])
        if minor is not None:
            axis_obj.set_minor_locator(locator()[minor])