# pylint: disable=unused-argument
"""Intermediate level visuals elements.
The visuals module provides backend-agnostic functionality.
That is, the functions in this module take a set of arguments,
take care of backend-agnostic processing of those arguments
and eventually they call the requested plotting backend.
"""
import numpy as np
import xarray as xr
from arviz_base.labels import BaseLabeller
from arviz_plots.plot_collection import backend_from_object
[docs]
def hist(da, target, **kwargs):
"""Plot a histogram bins(as two arrays of left and right bin edges) vs bin_height('y').
The input argument `da` is split into l_e, r_e and y using the dimension ``plot_axis``.
"""
plot_backend = backend_from_object(target)
return plot_backend.hist(
da.sel(plot_axis="histogram"),
da.sel(plot_axis="left_edges"),
da.sel(plot_axis="right_edges"),
target,
**kwargs,
)
def step_hist(da, target, **kwargs):
"""Plot step histogram."""
l_e = da.sel(plot_axis="left_edges").values
r_e = da.sel(plot_axis="right_edges").values
y = da.sel(plot_axis="histogram").values
bottom = kwargs.pop("bottom", 0)
if np.any(bottom != 0):
height = y - bottom
else:
height = y
x_coords = np.concatenate((l_e, [r_e[-1]]))
y_coords = np.concatenate((height, [height[-1]]))
plot_backend = backend_from_object(target)
return plot_backend.step(x_coords, y_coords, target, step_mode="after", **kwargs)
[docs]
def line_xy(da, target, x=None, y=None, **kwargs):
"""Plot a line x vs y.
The input argument `da` is split into x and y using the dimension ``plot_axis``.
If additional x and y arguments are provided, x and y are added to the values
in the `da` dataset sliced along plot_axis='x' and plot_axis='y'.
"""
plot_backend = backend_from_object(target)
x, y = _process_da_x_y(da, x, y)
return plot_backend.line(x, y, target, **kwargs)
[docs]
def ci_line_y(values, target, **kwargs):
"""Plot a line from y_bottom to y_top at given value of x."""
plot_backend = backend_from_object(target)
return plot_backend.ciliney(
values.sel(plot_axis="x"),
values.sel(plot_axis="y_bottom"),
values.sel(plot_axis="y_top"),
target,
**kwargs,
)
[docs]
def line_x(da, target, y=None, **kwargs):
"""Plot a line along the x axis (y constant)."""
if y is None:
y = np.zeros_like(da)
if np.asarray(y).size == 1:
y = np.zeros_like(da) + (y.item() if hasattr(y, "item") else y)
plot_backend = backend_from_object(target)
return plot_backend.line(da, y, target, **kwargs)
[docs]
def line(da, target, xname=None, **kwargs):
"""Plot a line along the y axis with x being the range of len(y)."""
if len(da.shape) != 1:
raise ValueError(f"Expected unidimensional data but got {da.sizes}")
yvalues = da.values
xvalues = np.arange(len(yvalues)) if xname is None else da[xname]
plot_backend = backend_from_object(target)
return plot_backend.line(xvalues, yvalues, target, **kwargs)
[docs]
def multiple_lines(da, target, x_dim, xvalues=None, **kwargs):
"""Plot multiple lines together.
Parameters
----------
da : DataArray
2d DataArray with `x_dim` as one of its dimensions.
target : Any
Object representing the target :term:`plot`
x_dim : hashable
Dimension of `da` to be encoded along the x axis of the plot.
xvalues : array-like, optional
Specific values for the positions of the data along the x axis.
Defaults to ``da.coords[x_dim].values``
**kwargs
Passed to the backend function :func:`~arviz_plots.backend.none.multiple_lines`
Returns
-------
Any
Object representing the generated :term:`visual`
"""
if da.ndim != 2:
raise ValueError(f"DataArray must be 2D, but has dims: {da.dims}")
if x_dim not in da.dims:
raise ValueError(f"overlay_dim '{x_dim}' not found in DataArray dims {da.dims}")
da = da.transpose(x_dim, ...)
yvalues = da.values
if xvalues is None:
xvalues = da.coords[x_dim].values
if len(xvalues) != yvalues.shape[0]:
raise ValueError(
f"xvalues length ({len(xvalues)}) does not match x-dim size ({yvalues.shape[0]})."
)
plot_backend = backend_from_object(target)
return plot_backend.multiple_lines(xvalues, yvalues, target, **kwargs)
[docs]
def trace_rug(da, target, mask, xname=None, y=None, **kwargs):
"""Create a rug plot with the subset of `da` indicated by `mask`."""
xname = xname.item() if hasattr(xname, "item") else xname
if xname is False:
xvalues = da
else:
if xname is None:
if len(da.shape) != 1:
raise ValueError(f"Expected unidimensional data but got {da.sizes}")
xvalues = np.arange(len(da))
else:
xvalues = da[xname]
if y is None:
y = da.min().item()
if len(xvalues.shape) != 1:
raise ValueError(f"Expected unidimensional data but got {xvalues.sizes}")
return scatter_x(xvalues[mask], target=target, y=y, **kwargs)
[docs]
def scatter_x(da, target, y=None, **kwargs):
"""Plot a dot/rug/scatter along the x axis (y constant)."""
if y is None:
y = np.zeros_like(da)
if np.asarray(y).size == 1:
y = np.zeros_like(da) + (y.item() if hasattr(y, "item") else y)
plot_backend = backend_from_object(target)
return plot_backend.scatter(da, y, target, **kwargs)
[docs]
def scatter_xy(da, target, x=None, y=None, mask=None, **kwargs):
"""Plot a scatter plot x vs y.
The input argument `da` is split into x and y using the dimension ``plot_axis``.
If additional x and y arguments are provided, x and y are added to the values
in the `da` dataset sliced along plot_axis='x' and plot_axis='y'.
If a mask is provided, it is applied to both x and y values.
"""
plot_backend = backend_from_object(target)
x, y = _process_da_x_y(da, x, y, mask)
return plot_backend.scatter(x, y, target, **kwargs)
[docs]
def scatter_couple(da_x, da_y, target, mask=None, **kwargs):
"""Plot a scatter plot for a pairplot couple."""
plot_backend = backend_from_object(target)
if mask is not None:
da_x = da_x[mask]
da_y = da_y[mask]
return plot_backend.scatter(da_x.values, da_y.values, target, **kwargs)
[docs]
def ecdf_line(values, target, **kwargs):
"""Plot a step line."""
plot_backend = backend_from_object(target)
return plot_backend.step(
values.sel(plot_axis="x"), values.sel(plot_axis="y"), target, step_mode="before", **kwargs
)
[docs]
def vline(values, target, **kwargs):
"""Plot a vertical line that spans the whole figure independently of zoom."""
plot_backend = backend_from_object(target)
return plot_backend.vline(values.item(), target, **kwargs)
[docs]
def hline(values, target, **kwargs):
"""Plot a horizontal line that spans the whole figure independently of zoom."""
plot_backend = backend_from_object(target)
return plot_backend.hline(values.item(), target, **kwargs)
[docs]
def vspan(da, target, **kwargs):
"""Plot a vertical shaded region that spans the whole figure."""
plot_backend = backend_from_object(target)
return plot_backend.vspan(da.values[0], da.values[1], target, **kwargs)
[docs]
def hspan(da, target, **kwargs):
"""Plot a vertical shaded region that spans the whole figure."""
plot_backend = backend_from_object(target)
return plot_backend.hspan(da.values[0], da.values[1], target, **kwargs)
def dline(da, target, x=None, y=None, **kwargs):
"""Plot a diagonal line across the x-y range."""
plot_backend = backend_from_object(target)
if x is None:
x = y
if y is None:
y = x
xy_min = min(np.min(x), np.min(y))
xy_max = max(np.max(x), np.max(y))
return plot_backend.line([xy_min, xy_max], [xy_min, xy_max], target, **kwargs)
[docs]
def fill_between_y(da, target, *, x=None, y_bottom=None, y=None, y_top=None, **kwargs):
"""Fill the region between to given y values."""
if "kwarg" in da.dims:
if "x" in da.kwarg:
x = da.sel(kwarg="x") if x is None else da.sel(kwarg="x") + x
if "y_bottom" in da.kwarg:
y_bottom = (
da.sel(kwarg="y_bottom")
if y_bottom is None
else da.sel(kwarg="y_bottom") + y_bottom
)
if "y_top" in da.kwarg:
y_top = da.sel(kwarg="y_top") if y_top is None else da.sel(kwarg="y_top") + y_top
if y is not None:
y_top += y
y_bottom += y
if np.ndim(np.squeeze(y_top)) == 0:
y_top = np.full_like(x, y_top)
if np.ndim(np.squeeze(y_bottom)) == 0:
y_bottom = np.full_like(x, y_bottom)
plot_backend = backend_from_object(target)
return plot_backend.fill_between_y(x, y_bottom, y_top, target, **kwargs)
def _process_da_x_y(da, x, y, mask=None):
"""Process da, x and y arguments into x and y values and apply mask if it is not None."""
da_has_x = "plot_axis" in da.dims and "x" in da.plot_axis
da_has_y = "plot_axis" in da.dims and "y" in da.plot_axis
if da_has_x:
x = da.sel(plot_axis="x") if x is None else da.sel(plot_axis="x") + x
if da_has_y:
y = da.sel(plot_axis="y") if y is None else da.sel(plot_axis="y") + y
if x is None and y is None:
raise ValueError("Unable to find values for x and y.")
if x is None:
x = da
elif y is None:
y = da
if mask is not None:
x = x[mask]
y = y[mask]
return np.broadcast_arrays(x, y)
def _ensure_scalar(*args):
return tuple(arg.item() if hasattr(arg, "item") else arg for arg in args)
[docs]
def annotate_xy(
da,
target,
*,
text,
x=None,
y=None,
vertical_align=None,
horizontal_align=None,
**kwargs,
):
"""Annotate a point (x, y) in a plot."""
if vertical_align is not None:
kwargs["vertical_align"] = (
vertical_align.item() if hasattr(vertical_align, "item") else vertical_align
)
if horizontal_align is not None:
kwargs["horizontal_align"] = (
horizontal_align.item() if hasattr(horizontal_align, "item") else horizontal_align
)
x, y = _process_da_x_y(da, x, y)
plot_backend = backend_from_object(target)
return plot_backend.text(x, y, text, target, **kwargs)
[docs]
def point_estimate_text(da, target, *, point_estimate, x=None, y=None, point_label="x", **kwargs):
"""Annotate a point estimate."""
x, y = _ensure_scalar(*_process_da_x_y(da, x, y))
point = x if point_label == "x" else y
if np.size(point) != 1:
raise ValueError(
"Found non-scalar point estimate. Check aes mapping and sample_dims. "
f"The dimensions still left to reduce/facet are {point.dims}."
)
text = f"{point:.3g} {point_estimate}"
plot_backend = backend_from_object(target)
return plot_backend.text(
x,
y,
text,
target,
**kwargs,
)
[docs]
def annotate_label(
da, target, *, var_name, sel, isel, x=None, y=None, dim=None, labeller=None, **kwargs
):
"""Annotate a dimension or aesthetic property."""
x, y = _ensure_scalar(*_process_da_x_y(da, x, y))
if labeller is None:
labeller = BaseLabeller()
if dim is None:
text = labeller.make_label_flat(var_name, sel, isel)
else:
sel = {key: value for key, value in sel.items() if key == dim}
isel = {key: value for key, value in isel.items() if key == dim}
text = labeller.sel_to_str(sel, isel)
plot_backend = backend_from_object(target)
return plot_backend.text(
x,
y,
text,
target,
**kwargs,
)
[docs]
def label_plot(
da,
target,
text=None,
x=0.5,
y=0.5,
lim_low=0,
lim_high=1,
labeller=None,
var_name=None,
axis_to_remove=False,
sel=None,
isel=None,
**kwargs,
):
"""Add a label to a plot."""
if text is None:
if labeller is None:
labeller = BaseLabeller()
text = labeller.make_label_vert(var_name, sel, isel)
x, y = _ensure_scalar(x, y)
lim_low, lim_high = _ensure_scalar(lim_low, lim_high)
plot_backend = backend_from_object(target)
plot_backend.xlim((lim_low, lim_high), target)
plot_backend.ylim((lim_low, lim_high), target)
if axis_to_remove:
plot_backend.remove_axis(target, axis=axis_to_remove)
return plot_backend.text(
x,
y,
text,
target,
**kwargs,
)
def set_ticklabel_visibility(da, target, *, axis="both", visible=True, **kwargs):
"""Set the visibility of tick labels on a plot."""
plot_backend = backend_from_object(target)
return plot_backend.set_ticklabel_visibility(target, axis=axis, visible=visible, **kwargs)
[docs]
def labelled_title(
da, target, *, text=None, labeller=None, var_name=None, sel=None, isel=None, **kwargs
):
"""Add a title label to a plot using an ArviZ labeller."""
if text is not None and labeller is not None:
text = f"{labeller.make_label_vert(var_name, sel, isel)} ({text})"
elif labeller is not None:
text = labeller.make_label_vert(var_name, sel, isel)
plot_backend = backend_from_object(target)
return plot_backend.title(text, target, **kwargs)
[docs]
def labelled_y(
da, target, *, text=None, labeller=None, var_name=None, sel=None, isel=None, **kwargs
):
"""Add a y label to a plot using an ArviZ labeller."""
if text is None and labeller is None:
raise ValueError("Either text or labeller must be provided")
if text is not None and labeller is not None:
raise ValueError("Only text or labeller can be provided")
if labeller is not None:
text = labeller.make_label_vert(var_name, sel, isel)
plot_backend = backend_from_object(target)
return plot_backend.ylabel(text, target, **kwargs)
[docs]
def labelled_x(
da, target, *, text=None, labeller=None, var_name=None, sel=None, isel=None, **kwargs
):
"""Add a x label to a plot using an ArviZ labeller."""
if text is None and labeller is None:
raise ValueError("Either text or labeller must be provided")
if text is not None and labeller is not None:
raise ValueError("Only text or labeller can be provided")
if labeller is not None:
text = labeller.make_label_vert(var_name, sel, isel)
plot_backend = backend_from_object(target)
return plot_backend.xlabel(text, target, **kwargs)
[docs]
def ticklabel_props(da, target, **kwargs):
"""Set the size of ticks."""
plot_backend = backend_from_object(target)
return plot_backend.ticklabel_props(target, **kwargs)
[docs]
def remove_axis(da, target, **kwargs):
"""Dispatch to ``remove_axis`` function in backend."""
plot_backend = backend_from_object(target)
return plot_backend.remove_axis(target, **kwargs)
def remove_matrix_axis(da_x, da_y, target, **kwargs):
"""Dispatch to ``remove_axis`` function in backend."""
plot_backend = backend_from_object(target)
return plot_backend.remove_axis(target, **kwargs)
[docs]
def remove_ticks(da, target, **kwargs):
"""Dispatch to ``remove_axis`` function in backend."""
plot_backend = backend_from_object(target)
return plot_backend.remove_ticks(target, **kwargs)
[docs]
def set_xticks(da, target, values, labels, **kwargs):
"""Dispatch to ``set_xticks`` function in backend."""
plot_backend = backend_from_object(target)
return plot_backend.xticks(values, labels, target, **kwargs)
def set_y_scale(da, target, scale, **kwargs):
"""Set scale for y-axis."""
plot_backend = backend_from_object(target)
return plot_backend.set_y_scale(target, scale, **kwargs)
[docs]
def grid(da, target, **kwargs):
"""Dispatch to ``remove_axis`` function in backend."""
plot_backend = backend_from_object(target)
return plot_backend.grid(target, **kwargs)