"""Utilities for batteries included plots."""
from copy import copy
from importlib import import_module
import numpy as np
import xarray as xr
from arviz_base import references_to_dataset
from arviz_base.utils import _var_names
from arviz_plots.plot_collection import concat_model_dict, process_facet_dims
from arviz_plots.visuals import hline, hspan, vline, vspan
def get_group(data, group, allow_missing=False):
    """Extract `group` from a container of groups and return it as a Dataset.

    Works with DataTree, Dataset, InferenceData or dictionaries of Datasets.

    Parameters
    ----------
    data : DataTree, Dataset, InferenceData or mapping of {str : Dataset}
        Object from which to extract `group`. If it already is a Dataset it is
        returned unchanged and `group` is not checked.
    group : hashable
        Id to be extracted. It is first compared against the ``name`` attribute
        of `data` (when it has one) and otherwise used as key to get the
        `group` item from `data`.
    allow_missing : bool, default False
        Return ``None`` if `group` can't be extracted instead of raising an error.

    Returns
    -------
    Dataset or None
        ``None`` is only returned when the lookup fails and ``allow_missing=True``.

    Raises
    ------
    KeyError
        If unable to access `group` from `data` and ``allow_missing=False``.
    """
    # Nothing to extract from a plain Dataset
    if isinstance(data, xr.Dataset):
        return data
    # `data` itself is the requested node: expose its dataset view
    if hasattr(data, "name") and data.name == group:
        return data.ds
    try:
        selected = data[group]
    except KeyError:
        if allow_missing:
            return None
        raise
    # Mappings of Datasets give a Dataset directly, tree-like containers give a node
    return selected if isinstance(selected, xr.Dataset) else selected.ds
def process_group_variables_coords(dt, group, var_names, filter_vars, coords, allow_dict=True):
    """Process main input arguments of batteries included plotting functions.

    Parameters
    ----------
    dt : DataTree, Dataset, InferenceData or dict
        Input data. A dictionary is interpreted as one entry per model, each value
        being anything accepted by :func:`get_group`.
    group : hashable
        Group to extract from `dt` (or from each of its values if a dict).
    var_names : str, sequence of str or None
        Variable selection, resolved through ``arviz_base.utils._var_names``.
    filter_vars : {None, "like", "regex"}
        How `var_names` should be interpreted, forwarded to ``_var_names``.
    coords : mapping or None
        Coordinate selection applied with ``.sel``. ``None`` means no subsetting.
    allow_dict : bool, default True
        Whether dictionary input is accepted.

    Returns
    -------
    Dataset
        Subset of the requested group. For dictionary input, the per-model
        subsets combined by ``concat_model_dict``.

    Raises
    ------
    ValueError
        If `dt` is a dictionary and ``allow_dict=False``.
    """
    if coords is None:
        coords = {}
    if isinstance(dt, dict):
        if not allow_dict:
            raise ValueError("Input data as dictionary not supported")
        distribution = {}
        for key, value in dt.items():
            model_ds = get_group(value, group)
            # Resolve the *original* `var_names` against every model independently.
            # Reusing the names resolved for a previous model would apply
            # `filter_vars` (or "~" negation) twice and leak one model's
            # variables into the selection of the next one.
            model_var_names = _var_names(var_names, model_ds, filter_vars)
            if model_var_names is not None:
                model_ds = model_ds[model_var_names]
            distribution[key] = model_ds.sel(coords)
        return concat_model_dict(distribution)

    distribution = get_group(dt, group)
    var_names = _var_names(var_names, distribution, filter_vars)
    if var_names is not None:
        distribution = distribution[var_names]
    return distribution.sel(coords)
def filter_aes(pc, aes_by_visuals, visual, sample_dims):
    """Work out which aesthetics and dimensions apply to a single visual.

    Parameters
    ----------
    pc : PlotCollection
        Object providing ``aes_set`` and ``update_aes``.
    aes_by_visuals : mapping of {str : iterable of str}
        Aesthetics requested for each visual.
    visual : str
        Key of the visual being processed.
    sample_dims : iterable of str
        Candidate dimensions to reduce.

    Returns
    -------
    artist_dims : list
        Dimensions that should be reduced for this visual.
        That is, all dimensions in `sample_dims` that are not among the
        loop dimensions reported by ``pc.update_aes``.
    artist_aes : iterable
        Aesthetics requested for `visual`; an empty dict if `visual` has no entry.
    ignore_aes : set
        Aesthetics defined in `pc` that `visual` does not use.
    """
    requested = aes_by_visuals.get(visual, {})
    # Every aesthetic known to the collection but not requested is ignored
    unused = set(pc.aes_set).difference(requested)
    loop_dims = pc.update_aes(ignore_aes=unused)[1]
    reduce_dims = []
    for dim in sample_dims:
        if dim not in loop_dims:
            reduce_dims.append(dim)
    return reduce_dims, requested, unused
def set_wrap_layout(pc_kwargs, plot_bknd, ds):
    """Fill in the figure size for a column-wrapped layout.

    `pc_kwargs` is modified in place: ``col_wrap`` defaults to 4 and the
    ``figsize``/``figsize_units`` entries of ``figure_kwargs`` are always set.

    Parameters
    ----------
    pc_kwargs : dict
        Plot collection kwargs. Must contain a ``figure_kwargs`` dict and,
        when no ``figsize`` is given, a ``cols`` entry.
    plot_bknd : module
        Plotting backend, only used through its ``scale_fig_size`` function
        when no ``figsize`` is given.
    ds : Dataset
        Dataset to be plotted

    Returns
    -------
    dict
        The same `pc_kwargs` object that was passed in.
    """
    fig_kwargs = pc_kwargs["figure_kwargs"]
    figsize = fig_kwargs.get("figsize", None)
    figsize_units = fig_kwargs.get("figsize_units", "inches")
    col_wrap = pc_kwargs.setdefault("col_wrap", 4)

    if figsize is None:
        num_plots = process_facet_dims(ds, pc_kwargs["cols"])[0]
        if num_plots < col_wrap:
            # Everything fits in a single, possibly incomplete, row
            rows, cols = 1, num_plots
        else:
            full_rows, leftover = divmod(num_plots, col_wrap)
            # A partially filled last row still takes up a whole row
            rows = full_rows + 1 if leftover else full_rows
            cols = col_wrap
        figsize = plot_bknd.scale_fig_size(
            None, rows=rows, cols=cols, figsize_units=figsize_units
        )
        # The computed size is stored with "dots" as units
        figsize_units = "dots"

    fig_kwargs["figsize"] = figsize
    fig_kwargs["figsize_units"] = figsize_units
    return pc_kwargs
def set_grid_layout(pc_kwargs, plot_bknd, ds, num_rows=None, num_cols=None):
    """Fill in the figure size for a rows x columns grid layout.

    `pc_kwargs` is modified in place: the ``figsize``/``figsize_units``
    entries of ``figure_kwargs`` are always set.

    Parameters
    ----------
    pc_kwargs : dict
        Plot collection kwargs. Must contain a ``figure_kwargs`` dict and,
        when they have to be inferred, ``rows`` and/or ``cols`` entries.
    plot_bknd : module
        Plotting backend, only used through its ``scale_fig_size`` function
        when no ``figsize`` is given.
    ds : Dataset
        Dataset to be plotted
    num_rows, num_cols : int, optional
        Take the number of rows or columns as the provided one irrespective
        of pc_kwargs

    Returns
    -------
    dict
        The same `pc_kwargs` object that was passed in.
    """
    fig_kwargs = pc_kwargs["figure_kwargs"]
    figsize = fig_kwargs.get("figsize", None)
    figsize_units = fig_kwargs.get("figsize_units", "inches")

    if figsize is None:
        # Only infer the grid shape for the side that was not given explicitly
        if num_cols is None:
            num_cols = process_facet_dims(ds, pc_kwargs["cols"])[0]
        if num_rows is None:
            num_rows = process_facet_dims(ds, pc_kwargs["rows"])[0]
        figsize = plot_bknd.scale_fig_size(
            None, rows=num_rows, cols=num_cols, figsize_units=figsize_units
        )
        # The computed size is stored with "dots" as units
        figsize_units = "dots"

    fig_kwargs["figsize"] = figsize
    fig_kwargs["figsize_units"] = figsize_units
    return pc_kwargs
# ("[docs]" link marker left over from the rendered documentation page)
def add_lines(
    plot_collection,
    values,
    orientation="vertical",
    aes_by_visuals=None,
    visuals=None,
    sample_dims=None,
    ref_dim="ref_dim",
    **kwargs,
):
    """Add reference lines to an existing plot collection.

    The provided `values` are converted to a Dataset aligned with
    ``plot_collection.data`` and drawn as vertical or horizontal lines,
    depending on `orientation`.

    Parameters
    ----------
    plot_collection : PlotCollection
        Plot collection to which the lines will be added.
    values : int, float, tuple, list or dict
        Positions for the lines.
    orientation : str, default "vertical"
        ``"vertical"`` draws vertical lines; any other value draws horizontal lines.
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
        The default is to use an "overlay_ref" aesthetic for all elements.

        It is possible to request aesthetics without mappings defined in the
        provided `plot_collection`. In those cases, and provided `ref_dim` is a
        dimension of the generated reference Dataset, a mapping of `ref_dim`
        to the requested aesthetic will be automatically added.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * "ref_line" -> passed to :func:`~arviz_plots.visuals.vline` for vertical `orientation`
          and to :func:`~arviz_plots.visuals.hline` for horizontal `orientation`.
          Use ``False`` to skip drawing the lines.
        * "ref_text" -> TODO (no visual is drawn for this key in this function yet)

    sample_dims : str or list, optional
        Dimensions that should not be added to the Dataset generated from
        `values` via :func:`arviz_base.references_to_dataset`.
        Defaults to all dimensions in ``plot_collection.data`` that are not ``facet_dims``
    ref_dim : str, optional
        Specifies the name of the dimension for the line values.
        Defaults to "ref_dim".
    **kwargs : mapping of {str : sequence}, optional
        Mapping of aesthetic keys to the values to be used in their mapping.
        Forwarded to the backend ``get_default_aes`` function when the
        automatic aesthetic mappings are generated.

    Returns
    -------
    plot_collection : PlotCollection
        Plot collection with the lines added.

    Examples
    --------
    Add lines at values 0 and 5 for all variables.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_dist, add_lines, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> pc = plot_dist(
        >>>     dt,
        >>>     kind="ecdf",
        >>>     var_names=["mu"],
        >>> )
        >>> add_lines(pc, values=[0, 5])
    """
    visuals = {} if visuals is None else visuals
    # Work on a copy so the caller's mapping is never modified
    aes_by_visuals = {} if aes_by_visuals is None else aes_by_visuals.copy()
    for visual_key in ("ref_line", "ref_text"):
        aes_by_visuals[visual_key] = aes_by_visuals.get(visual_key, ["overlay_ref"])

    if sample_dims is None:
        sample_dims = list(set(plot_collection.data.dims).difference(plot_collection.facet_dims))
    elif isinstance(sample_dims, str):
        sample_dims = [sample_dims]

    plot_bknd = import_module(f".backend.{plot_collection.backend}", package="arviz_plots")
    # Default line color: the gray that contrasts with the backend's background
    contrast_gray_color = get_contrast_colors(
        bg_color=plot_bknd.get_background_color(), gray_flag=True
    )[1]

    if orientation == "vertical":
        plot_func = vline
    else:
        plot_func = hline

    ref_ds = references_to_dataset(
        values, plot_collection.data, sample_dims=sample_dims, ref_dim=ref_dim
    )

    # Aesthetics requested for the reference visuals that the collection lacks
    missing_aes = set(aes_by_visuals["ref_line"])
    missing_aes = missing_aes.union(aes_by_visuals["ref_text"])
    missing_aes = missing_aes.difference(plot_collection.aes_set)

    if ref_dim in ref_ds.dims:
        n_refs = ref_ds.sizes[ref_dim]
        for aes_key in missing_aes:
            # One aesthetic value per reference, shared by all variables
            aes_values = np.array(plot_bknd.get_default_aes(aes_key, n_refs, kwargs))
            aes_ds = xr.Dataset(
                {var_name: (ref_dim, aes_values) for var_name in plot_collection.data.data_vars},
                coords={ref_dim: ref_ds[ref_dim]},
            )
            plot_collection.update_aes_from_dataset(aes_key, aes_ds)

    _, ref_aes, ref_ignore = filter_aes(plot_collection, aes_by_visuals, "ref_line", sample_dims)
    ref_kwargs = copy(visuals.get("ref_line", {}))
    if ref_kwargs is False:
        # Drawing of the lines explicitly disabled
        return plot_collection

    # Only set defaults for properties not already driven by an aesthetic mapping
    if "color" not in ref_aes:
        ref_kwargs.setdefault("color", contrast_gray_color)
    if "linestyle" not in ref_aes:
        # Second default linestyle of the backend
        ref_kwargs.setdefault("linestyle", plot_bknd.get_default_aes("linestyle", 2)[1])
    plot_collection.map(
        plot_func,
        "ref_line",
        data=ref_ds,
        ignore_aes=ref_ignore,
        **ref_kwargs,
    )
    return plot_collection
# ("[docs]" link marker left over from the rendered documentation page)
def add_bands(
    plot_collection,
    values,
    orientation="vertical",
    aes_by_visuals=None,
    visuals=None,
    sample_dims=None,
    ref_dim=None,
    **kwargs,
):
    """Add reference bands to an existing plot collection.

    The provided `values` are converted to a Dataset aligned with
    ``plot_collection.data`` and drawn as bands (shaded areas), vertical or
    horizontal depending on `orientation`.

    Parameters
    ----------
    plot_collection : PlotCollection
        Plot collection to which the bands will be added.
    values : tuple, list or dict
        Start and end values for the bands to be plotted.
    orientation : str, default "vertical"
        ``"vertical"`` draws vertical bands; any other value draws horizontal bands.
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
        The default is to use an "overlay_band" aesthetic for all elements.

        It is possible to request aesthetics without mappings defined in the
        provided `plot_collection`. In those cases, a mapping of the dimensions in
        `ref_dim` minus its last element to the requested aesthetic will be
        automatically added.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * "ref_band" -> passed to :func:`~arviz_plots.visuals.vspan` for vertical `orientation`
          and to :func:`~arviz_plots.visuals.hspan` for horizontal `orientation`.
          Use ``False`` to skip drawing the bands.

    sample_dims : str or list, optional
        Dimensions that should not be added to the Dataset generated from
        `values` via :func:`arviz_base.references_to_dataset`.
        Defaults to all dimensions in ``plot_collection.data`` that are not ``facet_dims``
    ref_dim : list, optional
        List of dimension names used for the band values. The last element is
        the dimension holding the start and end of each band and must have
        length 2 in the generated Dataset.
        Defaults to ["ref_dim", "band_dim"].
    **kwargs : mapping of {str : sequence}, optional
        Mapping of aesthetic keys to the values to be used in their mapping.
        See :func:`~arviz_plots.PlotCollection.generate_aes_dt` for more details.

    Returns
    -------
    plot_collection : PlotCollection
        Plot collection with the bands added.

    Raises
    ------
    ValueError
        If the band dimension of the generated Dataset does not have size 2.

    Examples
    --------
    Add two bands for the theta variable, one from -2 to 2 and the other from -5 to 5.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_dist, add_bands, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> pc = plot_dist(dt)
        >>> add_bands(pc, values=[[-2, 2], [-5, 5]])
    """
    visuals = {} if visuals is None else visuals
    # Work on a copy so the caller's mapping is never modified
    aes_by_visuals = {} if aes_by_visuals is None else aes_by_visuals.copy()
    aes_by_visuals["ref_band"] = aes_by_visuals.get("ref_band", ["overlay_band"])

    if sample_dims is None:
        sample_dims = list(set(plot_collection.data.dims).difference(plot_collection.facet_dims))
    elif isinstance(sample_dims, str):
        sample_dims = [sample_dims]

    if ref_dim is None:
        ref_dim = ["ref_dim", "band_dim"]

    if orientation == "vertical":
        plot_func = vspan
    else:
        plot_func = hspan

    ref_ds = references_to_dataset(
        values, plot_collection.data, sample_dims=sample_dims, ref_dim=ref_dim
    )

    plot_bknd = import_module(f".backend.{plot_collection.backend}", package="arviz_plots")
    # Default band color: the gray that contrasts with the backend's background
    contrast_gray_color = get_contrast_colors(
        bg_color=plot_bknd.get_background_color(), gray_flag=True
    )[1]

    # Aesthetics requested for the bands that the collection lacks
    missing_aes = set(aes_by_visuals["ref_band"]).difference(plot_collection.aes_set)

    # Last dimension holds (start, end); the remaining ones identify each band
    *aes_dims, band_dim = ref_dim
    if ref_ds.sizes[band_dim] != 2:
        raise ValueError(
            f"Expected dimension '{band_dim}' in reference dataset to have size 2, "
            f"but found size {ref_ds.sizes[band_dim]}"
        )

    aes_dt = plot_collection.generate_aes_dt(dict.fromkeys(missing_aes, aes_dims), ref_ds, **kwargs)
    for aes_key, aes_node in aes_dt.children.items():
        plot_collection.update_aes_from_dataset(aes_key, aes_node.dataset)

    _, ref_aes, ref_ignore = filter_aes(plot_collection, aes_by_visuals, "ref_band", sample_dims)
    ref_kwargs = copy(visuals.get("ref_band", {}))
    if ref_kwargs is False:
        # Drawing of the bands explicitly disabled
        return plot_collection

    # Only set defaults for properties not already driven by an aesthetic mapping
    if "color" not in ref_aes:
        ref_kwargs.setdefault("color", contrast_gray_color)
    if "alpha" not in ref_aes:
        ref_kwargs.setdefault("alpha", 0.25)
    plot_collection.map(plot_func, "ref_band", data=ref_ds, ignore_aes=ref_ignore, **ref_kwargs)
    return plot_collection
def get_contrast_colors(bg_color="#ffffff", gray_flag=False):
    """Get contrast colors based on the background color.

    Parameters
    ----------
    bg_color : str, default "#ffffff"
        Background color as a hex string, with or without the leading ``#``.
        Accepts the full ``rrggbb``/``rrggbbaa`` forms as well as the shorthand
        ``rgb``/``rgba`` forms (each digit is doubled). Only the red, green
        and blue components are used.
    gray_flag : bool, default False
        If True, also return a contrasting gray.

    Returns
    -------
    str or tuple of (str, str)
        With ``gray_flag=False``, ``"#000000"`` for light backgrounds and
        ``"#ffffff"`` for dark ones. With ``gray_flag=True``, a tuple with that
        same color followed by a gray (``"#4c4c4c"`` for light backgrounds,
        ``"#b3b3b3"`` for dark ones).

    Raises
    ------
    ValueError
        If the color components can't be parsed as hexadecimal numbers.
    """
    color = bg_color.lstrip("#")
    if len(color) in (3, 4):
        # Shorthand hex ("abc" == "aabbcc"): without expanding it the blue
        # component would be read from an empty slice and fail to parse
        color = "".join(2 * digit for digit in color)
    r, g, b = (int(color[start : start + 2], 16) for start in (0, 2, 4))
    # calculating the YIQ brightness value
    yiq = (r * 299 + g * 587 + b * 114) / 1000
    is_dark = yiq < 128
    if gray_flag:
        return ("#ffffff", "#b3b3b3") if is_dark else ("#000000", "#4c4c4c")
    return "#ffffff" if is_dark else "#000000"