Source code for arviz_plots.backend.matplotlib

# pylint: disable=no-self-use
"""Matplotlib interface layer.

Notes
-----
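Sets ``zorder`` of most non-text "geoms" to ``2`` so that elements plotted later
on are on top of previous ones. ``vline``, ``hline``, ``vspan`` and ``hspan``
default to ``zorder=0`` instead, and ``fill_between_y`` sets no default ``zorder``.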
"""

import warnings
from typing import Any, Dict

import matplotlib.colors as mcolors
import matplotlib.scale as mscale
import matplotlib.transforms as mtransforms
import numpy as np
from matplotlib import ticker
from matplotlib.cbook import normalize_kwargs
from matplotlib.collections import LineCollection, PathCollection
from matplotlib.lines import Line2D
from matplotlib.pyplot import rcParams
from matplotlib.pyplot import show as _show
from matplotlib.pyplot import subplots
from matplotlib.text import Text

from ..none import get_default_aes as get_agnostic_default_aes
from .legend import legend


class UnsetDefault:
    """Sentinel type used to flag an aesthetic that was not given any value."""


# Module-level sentinel instance; code in this module tests for it with ``is``.
unset = UnsetDefault()


class SquareRootScale(mscale.ScaleBase):
    """Axis scale that positions data proportionally to its square root.

    The class is registered with matplotlib under the name ``"sqrt"`` right
    after its definition.
    """

    name = "sqrt"

    def __init__(self, axis, **kwargs):  # pylint: disable=unused-argument
        """Initialize the scale; extra keyword arguments are accepted and ignored."""
        super().__init__(axis)

    def set_default_locators_and_formatters(self, axis):
        """Set automatic major ticks with scalar labels and disable minor ticks."""
        axis.set_major_locator(ticker.AutoLocator())
        axis.set_major_formatter(ticker.ScalarFormatter())
        axis.set_minor_locator(ticker.NullLocator())
        axis.set_minor_formatter(ticker.NullFormatter())

    def limit_range_for_scale(self, vmin, vmax, minpos):  # pylint: disable=unused-argument
        """Clip the lower limit to 0 and return the upper limit unchanged."""
        return max(0.0, vmin), vmax

    class SquareRootTransform(mtransforms.Transform):
        """Forward transformation: data values to their square root."""

        input_dims = 1
        output_dims = 1
        is_separable = True

        def transform_non_affine(self, values):
            """Return the element-wise square root of `values`."""
            return np.array(values) ** 0.5

        def inverted(self):
            """Return the transformation that undoes this one."""
            return SquareRootScale.InvertedSquareRootTransform()

    class InvertedSquareRootTransform(mtransforms.Transform):
        """Inverse transformation: square-root space back to data values."""

        input_dims = 1
        output_dims = 1
        is_separable = True

        # Override ``transform_non_affine`` (not ``transform``): composite and
        # blended transforms call the children's ``transform_non_affine``
        # directly, so overriding only ``transform`` would make the inverse
        # behave as the identity whenever it is nested in another transform.
        def transform_non_affine(self, values):
            """Return the element-wise square of `values`."""
            return np.array(values) ** 2

        def inverted(self):
            """Return the transformation that undoes this one."""
            return SquareRootScale.SquareRootTransform()

    def get_transform(self):
        """Return a new forward square root transformation."""
        return self.SquareRootTransform()


mscale.register_scale(SquareRootScale)


def get_background_color():
    """Return the ``figure.facecolor`` rcParam as a hex color string.

    Returns
    -------
    str
        Hex representation of the configured figure face color. If matplotlib
        cannot convert the configured value, a warning is issued and
        ``"#ffffff"`` is returned instead.
    """
    configured = rcParams["figure.facecolor"]
    try:
        return mcolors.to_hex(configured)
    except ValueError:
        warnings.warn(
            "The background color is not a valid matplotlib color. "
            "Returning the default value '#ffffff'."
        )
    return "#ffffff"


# generation of default values for aesthetics
def get_default_aes(aes_key, n, kwargs=None):
    """Generate `n` *matplotlib valid* default values for a given aesthetics keyword.

    Parameters
    ----------
    aes_key : str
        Aesthetic whose default values are requested.
    n : int
        Number of values, forwarded to the backend agnostic generator.
    kwargs : dict, optional
        When it contains `aes_key`, it is forwarded as is to the backend
        agnostic generator and no matplotlib specific defaults are computed.

    Returns
    -------
    The output of the backend agnostic ``get_default_aes``.
    """
    if kwargs is not None and aes_key in kwargs:
        return get_agnostic_default_aes(aes_key, n, kwargs)

    prop_cycle = rcParams["axes.prop_cycle"].by_key()
    # Each hard-coded list below is only used when the active property cycle
    # has no entry for the corresponding property.
    if ("color" in aes_key) or aes_key == "c":
        # fmt: off
        fallback = [
            '#3f90da', '#ffa90e', '#bd1f01', '#94a4a2', '#832db6',
            '#a96b59', '#e76300', '#b9ac70', '#717581', '#92dadd'
        ]
        # fmt: on
        vals = prop_cycle.get("color", fallback)
    elif aes_key in {"linestyle", "ls"}:
        vals = prop_cycle.get("linestyle", ["-", "--", ":", "-."])
    elif aes_key in {"marker", "m"}:
        vals = prop_cycle.get("marker", ["o", "+", "^", "x", "d", "s", "."])
    elif aes_key in prop_cycle:
        vals = prop_cycle[aes_key]
    else:
        # nothing matplotlib specific is known about this aesthetic
        return get_agnostic_default_aes(aes_key, n)
    return get_agnostic_default_aes(aes_key, n, {aes_key: vals})


def scale_fig_size(figsize, rows=1, cols=1, figsize_units=None):
    """Compute the figure size in dots from `figsize`, `rows` and `cols`.

    Parameters
    ----------
    figsize : (float, float) or None
        Size of figure in `figsize_units`. If ``None``, a size is derived from
        the ``figure.figsize`` rcParam and the number of rows and columns.
    rows : int
        Number of rows
    cols : int
        Number of columns
    figsize_units : {"inches", "dots"}, optional
        Units of `figsize`, ``None`` is treated as "inches".
        Ignored if `figsize` is ``None``

    Returns
    -------
    figsize : (float, float)
        Width and height of the figure in dots.

    Raises
    ------
    ValueError
        If `figsize` is given and `figsize_units` is neither "inches" nor "dots".
    """
    units = "inches" if figsize_units is None else figsize_units
    if figsize is None:
        base_width, base_height = rcParams["figure.figsize"]
        # columns get narrower once there are 4 or more of them
        column_width = base_width if cols < 4 else 0.6 * base_width
        width = cols * column_width
        height = base_height / 4 * (rows + 1) ** 1.1
        units = "inches"  # the rcParam based size is always in inches
    else:
        width, height = figsize

    if units == "dots":
        return (width, height)
    if units != "inches":
        raise ValueError(f"figsize_units must be 'dots' or 'inches', but got {units}")
    dpi = rcParams["figure.dpi"]
    return (width * dpi, height * dpi)


# object creation and i/o
def show(figure):  # pylint: disable=unused-argument
    """Show all existing matplotlib figures.

    `figure` is accepted but not used: the call is delegated to
    ``matplotlib.pyplot.show`` without arguments.
    """
    _show()
def savefig(figure, path, **kwargs):
    """Write the figure to a file.

    Parameters
    ----------
    figure : `~matplotlib.figure.Figure`
        The figure to save.
    path : pathlib.Path
        Path to the file where the figure will be saved.
    **kwargs : dict, optional
        Additional keyword arguments passed to the ``savefig`` method of `figure`.
    """
    figure.savefig(path, **kwargs)


def get_figsize(plot_collection):
    """Get the size of the :term:`figure` element and its units.

    Returns
    -------
    size
        Output of ``get_size_inches`` called on the figure stored in
        ``plot_collection.viz["figure"]``.
    units : str
        Always ``"inches"``.
    """
    fig = plot_collection.viz["figure"].item()
    return fig.get_size_inches(), "inches"
def create_plotting_grid(
    number,
    rows=1,
    cols=1,
    *,
    figsize=None,
    figsize_units="inches",
    squeeze=True,
    sharex=False,
    sharey=False,
    polar=False,
    width_ratios=None,
    height_ratios=None,
    plot_hspace=None,
    subplot_kws=None,
    **kwargs,
):
    """Create a figure with a grid of plotting targets in it.

    Parameters
    ----------
    number : int
        Number of axes required. Axes of the grid beyond this number
        (in row-major order) have their axis switched off.
    rows, cols : int, default 1
        Number of rows and columns.
    figsize : (float, float), optional
        Size of the figure in `figsize_units`.
    figsize_units : {"inches", "dots"}, default "inches"
        Units in which `figsize` is given. Only checked when `figsize` is given.
    squeeze : bool, default True
    sharex, sharey : bool, default False
    polar : bool
        If true, ``projection="polar"`` is added to `subplot_kws`.
    width_ratios, height_ratios : optional
        Passed to :func:`~matplotlib.pyplot.subplots`
    plot_hspace : optional
        If given, used as the ``wspace`` entry of ``gridspec_kw`` unless
        ``gridspec_kw`` already defines one.
    subplot_kws : dict, optional
        Passed to :func:`~matplotlib.pyplot.subplots` as ``subplot_kw``
    **kwargs: dict, optional
        Passed to :func:`~matplotlib.pyplot.subplots`

    Returns
    -------
    `~matplotlib.figure.Figure`
    `~matplotlib.axes.Axes` or ndarray of `~matplotlib.axes.Axes`

    Raises
    ------
    ValueError
        If `figsize` is given and `figsize_units` is neither "inches" nor "dots".
    """
    # always work on a copy so the caller's dictionary is never modified
    subplot_kws = {} if subplot_kws is None else subplot_kws.copy()
    if polar:
        subplot_kws["projection"] = "polar"

    if plot_hspace is not None:
        gridspec_kw = kwargs.get("gridspec_kw", {}).copy()
        gridspec_kw.setdefault("wspace", plot_hspace)
        kwargs["gridspec_kw"] = gridspec_kw

    if figsize is not None:
        if figsize_units == "dots":
            # matplotlib expects inches
            dpi = rcParams["figure.dpi"]
            figsize = (figsize[0] / dpi, figsize[1] / dpi)
        elif figsize_units != "inches":
            raise ValueError(f"figsize_units must be 'dots' or 'inches', but got {figsize_units}")

    fig, axes = subplots(
        rows,
        cols,
        sharex=sharex,
        sharey=sharey,
        squeeze=squeeze,
        width_ratios=width_ratios,
        height_ratios=height_ratios,
        figsize=figsize,
        subplot_kw=subplot_kws,
        **kwargs,
    )

    if rows * cols > number:
        # the grid has more cells than requested axes: blank the surplus ones
        for idx, ax in enumerate(axes.ravel("C")):
            if idx < number:
                continue
            ax.set_axis_off()
    return fig, axes
# helper functions
def _filter_kwargs(kwargs, visual, artist_kws):
    """Merge `kwargs` into `artist_kws` after dropping every ``unset`` value.

    When `visual` is not ``None``, a copy of `artist_kws` is first normalized
    with matplotlib's ``normalize_kwargs`` for that artist class, to avoid
    clashing of aliases with their extended version. Entries of `kwargs`
    take precedence over entries of `artist_kws`.
    """
    explicit = {name: val for name, val in kwargs.items() if val is not unset}
    if visual is None:
        return {**artist_kws, **explicit}
    return {**normalize_kwargs(artist_kws.copy(), visual), **explicit}


# "geoms"
def hist(
    y,
    l_e,
    r_e,
    target,
    *,
    bottom=0,
    color=unset,
    facecolor=unset,
    edgecolor=unset,
    alpha=unset,
    **artist_kws,
):
    """Interface to matplotlib for a histogram bar plot.

    The histogram is drawn with ``target.fill_between`` using ``step="post"``.
    `color` is used for `facecolor` and `edgecolor` whenever those are unset.
    """
    artist_kws.setdefault("zorder", 2)
    height = y - bottom if np.any(bottom != 0) else y
    if color is not unset:
        facecolor = color if facecolor is unset else facecolor
        edgecolor = color if edgecolor is unset else edgecolor
    # repeat the last edge/height so that the final bin is drawn by the "post" step
    edges = np.r_[l_e, r_e[-1]]
    heights = np.r_[height, height[-1]]
    kwargs = {"color": facecolor, "edgecolor": edgecolor, "alpha": alpha}
    return target.fill_between(
        edges, heights, step="post", **_filter_kwargs(kwargs, None, artist_kws)
    )
def line(x, y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
    """Interface to matplotlib for a line plot.

    Returns the first element of the list returned by ``target.plot``.
    """
    artist_kws.setdefault("zorder", 2)
    line_kws = _filter_kwargs(
        {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle},
        Line2D,
        artist_kws,
    )
    artists = target.plot(x, y, **line_kws)
    return artists[0]
def multiple_lines(
    x, y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws
):
    """Interface to matplotlib for a multiple line plot using a single LineCollection.

    Parameters
    ----------
    x : array-like
        x coordinates shared by all lines.
    y : array-like
        y coordinates. A 2D input is read column-wise, ``y[:, i]`` holds the
        values of the i-th line. A 0D or 1D input is treated as a single line.
    target : object
        Plotting target; must provide ``add_collection`` and ``autoscale_view``.
    color, alpha, width, linestyle : optional
        Passed to the collection as ``colors``, ``alpha``, ``linewidths`` and
        ``linestyles`` when set.
    **artist_kws
        Extra keyword arguments for `~matplotlib.collections.LineCollection`.
        ``zorder`` defaults to 2.

    Returns
    -------
    `~matplotlib.collections.LineCollection`
    """
    artist_kws.setdefault("zorder", 2)
    y_2d = np.asanyarray(y)
    if y_2d.ndim < 2:
        # A 1D input is one line, i.e. one *column*. ``np.atleast_2d`` would
        # turn it into one row instead, which can't be stacked against `x`.
        y_2d = y_2d.reshape(-1, 1)
    segments = [np.column_stack([x, y_col]) for y_col in y_2d.T]
    plot_kwargs = {"colors": color, "alpha": alpha, "linewidths": width, "linestyles": linestyle}
    filtered_kwargs = _filter_kwargs(plot_kwargs, LineCollection, artist_kws)
    line_collection = LineCollection(segments, **filtered_kwargs)
    target.add_collection(line_collection)
    # refresh the view limits now that the collection is on the target
    target.autoscale_view()
    return line_collection
def scatter(
    x,
    y,
    target,
    *,
    size=unset,
    marker=unset,
    alpha=unset,
    color=unset,
    facecolor=unset,
    edgecolor=unset,
    width=unset,
    **artist_kws,
):
    """Interface to matplotlib for a scatter plot.

    When `color` is set it fills in whichever of `facecolor` and `edgecolor`
    is unset. `edgecolor` is only filled in for fillable markers, that is,
    when `marker` is unset or one of ``Line2D.filled_markers``.
    """
    artist_kws.setdefault("zorder", 2)
    fillable_marker = (marker is unset) or (marker in Line2D.filled_markers)
    if color is not unset:
        if edgecolor is unset and fillable_marker:
            edgecolor = color
        if facecolor is unset:
            facecolor = color
    scatter_kws = _filter_kwargs(
        {
            "s": size,
            "marker": marker,
            "alpha": alpha,
            "c": facecolor,
            "edgecolors": edgecolor,
            "linewidths": width,
        },
        None,
        artist_kws,
    )
    return target.scatter(x, y, **scatter_kws)
def step(
    x,
    y,
    target,
    *,
    color=unset,
    alpha=unset,
    width=unset,
    linestyle=unset,
    step_mode=unset,
    **artist_kws,
):
    """Interface to matplotlib for a step line.

    When `step_mode` is set it is translated to the ``where`` argument of
    ``target.step``: "before" becomes "pre", "after" becomes "post" and any
    other value becomes "mid".
    """
    artist_kws.setdefault("zorder", 2)
    kwargs = {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle}
    if step_mode is not unset:
        if step_mode == "before":
            where = "pre"
        else:
            where = "post" if step_mode == "after" else "mid"
        kwargs["where"] = where
    artists = target.step(x, y, **_filter_kwargs(kwargs, Line2D, artist_kws))
    return artists[0]
def text(
    x,
    y,
    string,
    target,
    *,
    size=unset,
    alpha=unset,
    color=unset,
    vertical_align="center",
    horizontal_align="center",
    **artist_kws,
):
    """Interface to matplotlib for adding text to a plot.

    `size` is passed as ``fontsize`` and the two alignment arguments as
    ``verticalalignment`` and ``horizontalalignment``.
    """
    text_kws = _filter_kwargs(
        {
            "fontsize": size,
            "alpha": alpha,
            "color": color,
            "horizontalalignment": horizontal_align,
            "verticalalignment": vertical_align,
        },
        Text,
        artist_kws,
    )
    return target.text(x, y, string, **text_kws)
def fill_between_y(x, y_bottom, y_top, target, **artist_kws):
    """Fill the area between y_bottom and y_top.

    ``linewidth`` defaults to 0; no default ``zorder`` is set here.
    """
    artist_kws.setdefault("linewidth", 0)
    return target.fill_between(x, y_bottom, y_top, **artist_kws)


def vline(x, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
    """Interface to matplotlib for a vertical line spanning the whole axes.

    ``zorder`` defaults to 0.
    """
    artist_kws.setdefault("zorder", 0)
    line_kws = _filter_kwargs(
        {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle},
        Line2D,
        artist_kws,
    )
    return target.axvline(x, **line_kws)


def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
    """Interface to matplotlib for a horizontal line spanning the whole axes.

    ``zorder`` defaults to 0.
    """
    artist_kws.setdefault("zorder", 0)
    line_kws = _filter_kwargs(
        {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle},
        Line2D,
        artist_kws,
    )
    return target.axhline(y, **line_kws)


def vspan(xmin, xmax, target, *, color=unset, alpha=unset, **artist_kws):
    """Interface to matplotlib for a vertical shaded region spanning the whole axes.

    ``zorder`` defaults to 0.
    """
    artist_kws.setdefault("zorder", 0)
    span_kws = _filter_kwargs({"color": color, "alpha": alpha}, None, artist_kws)
    return target.axvspan(xmin, xmax, **span_kws)


def hspan(ymin, y_max, target, *, color=unset, alpha=unset, **artist_kws):
    """Interface to matplotlib for a horizontal shaded region spanning the whole axes.

    ``zorder`` defaults to 0.
    """
    artist_kws.setdefault("zorder", 0)
    span_kws = _filter_kwargs({"color": color, "alpha": alpha}, None, artist_kws)
    return target.axhspan(ymin, y_max, **span_kws)


def ciliney(
    x,
    y_bottom,
    y_top,
    target,
    *,
    color=unset,
    alpha=unset,
    width=unset,
    linestyle=unset,
    **artist_kws,
):
    """Interface to matplotlib for a line from y_bottom to y_top at given value of x.

    Returns the first element of the list returned by ``target.plot``.
    """
    artist_kws.setdefault("zorder", 2)
    line_kws = _filter_kwargs(
        {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle},
        Line2D,
        artist_kws,
    )
    artists = target.plot([x, x], [y_bottom, y_top], **line_kws)
    return artists[0]


# general plot appearance
def title(string, target, *, size=unset, color=unset, **artist_kws):
    """Interface to matplotlib for adding a title to a plot.

    `size` is passed to ``target.set_title`` as ``fontsize``.
    """
    title_kws = _filter_kwargs({"fontsize": size, "color": color}, Text, artist_kws)
    return target.set_title(string, **title_kws)
def ylabel(string, target, *, size=unset, color=unset, **artist_kws):
    """Interface to matplotlib for adding a label to the y axis.

    `size` is passed to ``target.set_ylabel`` as ``fontsize``.
    """
    label_kws = _filter_kwargs({"fontsize": size, "color": color}, Text, artist_kws)
    return target.set_ylabel(string, **label_kws)
def xlabel(string, target, *, size=unset, color=unset, **artist_kws):
    """Interface to matplotlib for adding a label to the x axis.

    `size` is passed to ``target.set_xlabel`` as ``fontsize``.
    """
    label_kws = _filter_kwargs({"fontsize": size, "color": color}, Text, artist_kws)
    return target.set_xlabel(string, **label_kws)
def xticks(ticks, labels, target, *, rotation=unset, **artist_kws):
    """Interface to matplotlib for adding x ticks and labels to a plot.

    `rotation` is only forwarded to ``target.set_xticks`` when it is set.
    """
    extra = {} if rotation is unset else {"rotation": rotation}
    return target.set_xticks(ticks, labels, **artist_kws, **extra)
def yticks(ticks, labels, target, *, rotation=unset, **artist_kws):
    """Interface to matplotlib for adding y ticks and labels to a plot.

    `rotation` is only forwarded to ``target.set_yticks`` when it is set.
    """
    extra = {} if rotation is unset else {"rotation": rotation}
    return target.set_yticks(ticks, labels, **artist_kws, **extra)
def set_ticklabel_visibility(target, *, axis="both", visible=True):
    """Interface to matplotlib for setting visibility of tick labels.

    Parameters
    ----------
    target : object
        Plotting target; its ``tick_params`` method is called.
    axis : {"both", "x", "y"}, default "both"
        "x" toggles ``labelbottom``, "y" toggles ``labelleft``, "both" toggles both.
    visible : bool, default True

    Raises
    ------
    ValueError
        If `axis` is not one of the accepted values.
    """
    if axis == "both":
        label_flags = {"labelbottom": visible, "labelleft": visible}
    elif axis == "x":
        label_flags = {"labelbottom": visible}
    elif axis == "y":
        label_flags = {"labelleft": visible}
    else:
        raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")
    target.tick_params(axis=axis, **label_flags)


def xlim(lims, target, **artist_kws):
    """Interface to matplotlib for setting limits for the x axis.

    Calls ``target.set_xlim``; nothing is returned.
    """
    target.set_xlim(lims, **artist_kws)


def ylim(lims, target, **artist_kws):
    """Interface to matplotlib for setting limits for the y axis.

    Calls ``target.set_ylim``; nothing is returned.
    """
    target.set_ylim(lims, **artist_kws)
def ticklabel_props(target, *, axis="both", size=unset, color=unset, **artist_kws):
    """Interface to matplotlib for setting ticks size.

    `size` and `color` are passed to ``target.tick_params`` as ``labelsize``
    and ``labelcolor`` when set.
    """
    tick_kws = _filter_kwargs({"labelsize": size, "labelcolor": color}, None, artist_kws)
    target.tick_params(axis=axis, **tick_kws)
def remove_ticks(target, *, axis="y"):
    """Interface to matplotlib for removing ticks from a plot.

    Parameters
    ----------
    target : object
        Plotting target with ``xaxis`` and ``yaxis`` attributes.
    axis : {"y", "x", "both"}, default "y"
        Axis whose ticks are replaced by an empty list. Any other value
        leaves the target untouched.
    """
    if axis in ("x", "both"):
        target.xaxis.set_ticks([])
    if axis in ("y", "both"):
        target.yaxis.set_ticks([])
def remove_axis(target, axis="y"):
    """Interface to matplotlib for removing axis from a plot.

    The top and right spines are always hidden. Then, depending on `axis`:

    * "y": y ticks and the left spine are removed, the bottom spine is kept.
    * "x": x ticks and the bottom spine are removed, the left spine is kept.
    * "both": the whole axis is switched off.

    Raises
    ------
    ValueError
        If `axis` is not one of "x", "y" or "both". The top/right spines and
        tick marks have already been modified when the error is raised.
    """
    for side in ("top", "right"):
        target.spines[side].set_visible(False)

    keep_left = axis == "x"
    keep_bottom = axis == "y"
    target.tick_params(
        axis="both", which="both", left=keep_left, top=False, right=False, bottom=keep_bottom
    )

    if axis == "both":
        target.set_axis_off()
        return
    if not (keep_left or keep_bottom):
        raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")

    if keep_bottom:
        removed_axis, kept_name = target.yaxis, "x"
    else:
        removed_axis, kept_name = target.xaxis, "y"
    removed_axis.set_ticks([])
    target.spines["left"].set_visible(keep_left)
    target.spines["bottom"].set_visible(keep_bottom)
    # NOTE(review): the x axis tick position is set for both "x" and "y" — confirm intended
    target.xaxis.set_ticks_position("bottom")
    target.tick_params(axis=kept_name, direction="out", width=1, length=3)
def set_y_scale(target, scale):
    """Interface to matplotlib for setting the y scale of a plot.

    `scale` is handed over unchanged to ``target.set_yscale``.
    """
    target.set_yscale(scale)


def grid(target, axis, color):
    """Interface to matplotlib for setting a grid in any axis.

    `axis` and `color` are passed as keyword arguments to ``target.grid``.
    """
    grid_kws = {"axis": axis, "color": color}
    target.grid(**grid_kws)