"""Elements to combine multiple batteries-included plots into a single figure."""
import re
from importlib import import_module
from arviz_base import rcParams
from arviz_plots import PlotCollection
from arviz_plots.plot_collection import backend_from_object
from arviz_plots.plots.utils import process_group_variables_coords, set_grid_layout
def get_valid_arg(key, value, backend):
    """Translate a none-backend aesthetic placeholder into a value usable by ``backend``.

    Placeholders produced by the none backend have the form ``"<key>_<index>"``,
    for example ``"color_0"`` or ``"linestyle_3"``. The ``facecolor`` and ``edgecolor``
    keywords share the ``"color_<index>"`` form. When `value` is such a placeholder,
    entry ``index`` of the backend's default values for `key` is returned. Any other
    value is assumed to already be valid and is returned untouched.

    Parameters
    ----------
    key : str
        Keyword from the :ref:`backend-interface-arguments` that `value` is meant for.
    value : any
        Candidate value for `key`. It may be a none-backend placeholder string.
    backend : str
        Name of the plotting backend, i.e. the module ``arviz_plots.backend.<backend>``.

    Returns
    -------
    valid_value : any

    Notes
    -----
    The backend module is imported even when `value` is not a placeholder. Only the
    start of the string is matched, so ``"color_12x"`` is treated like ``"color_12"``.
    """
    backend_module = import_module(f"arviz_plots.backend.{backend}")
    # facecolor/edgecolor reuse the generic "color_<n>" placeholders.
    prefix = "color" if key in {"facecolor", "edgecolor"} else key
    if not isinstance(value, str):
        return value
    found = re.match(prefix + "_([0-9]+)", value)
    if found is None:
        return value
    position = int(found.group(1))
    # Ask for position + 1 defaults so that index `position` exists.
    defaults = backend_module.get_default_aes(key, position + 1)
    return defaults[position]
def backendize_kwargs(kwargs, backend):
    """Translate a none-backend visual description into backend-ready keyword arguments.

    The ``"function"`` entry is dropped because it names the plotting function rather
    than being an argument to it. Every other value is passed through
    :func:`get_valid_arg` for `backend`. A new dict is returned and `kwargs` is left
    unchanged.
    """
    processed = {}
    for name, raw_value in kwargs.items():
        if name != "function":
            processed[name] = get_valid_arg(name, raw_value, backend)
    return processed
def render(da, target, **kwargs):
    """Draw a none-backend visual description on `target` with the target's own backend.

    The backend is detected from `target`. ``da.item()`` is expected to be a dict-like
    description whose ``"function"`` entry names the function to call in
    ``arviz_plots.backend.<backend>``. Placeholders in both the stored description and
    the extra `kwargs` are converted with :func:`backendize_kwargs`. Explicit `kwargs`
    take precedence over values from the description.
    """
    backend_name = backend_from_object(target, return_module=False)
    backend_module = import_module(f"arviz_plots.backend.{backend_name}")
    description = da.item()
    function_name = description["function"]
    call_kwargs = backendize_kwargs(description, backend_name)
    # Later update wins: caller-supplied kwargs override the stored description.
    call_kwargs.update(backendize_kwargs(kwargs, backend_name))
    plot_function = getattr(backend_module, function_name)
    return plot_function(target=target, **call_kwargs)
def combine_plots(
    dt,
    plots,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    expand="column",
    plot_names=None,
    backend=None,
    **pc_kwargs,
):
    """Arrange multiple batteries-included plots in a customizable column or row layout.

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to aesthetics.
        Note that not all batteries-included functions accept dictionary input, so it will
        only work when all plotting functions requested in `plots` are compatible with it.
    plots : list of tuple of (callable, mapping)
        List of all the plotting functions to be combined. Each element in this list
        is a tuple with two elements. The first is the function to be called, the second
        is a dictionary with any keyword arguments that should be used when calling that function.
    var_names : str or sequence of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default None
        If None, interpret `var_names` as the real variables names.
        If "like", interpret `var_names` as substrings of the real variables names.
        If "regex", interpret `var_names` as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to be plotted.
    coords : dict, optional
        Coordinates passed both to the data processing that builds the combined grid
        and to every plotting function in `plots`.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    expand : {"column", "row"}, default "column"
        How to combine the different plotting functions. If "column", each plotting function
        will be added as a new column, if "row" it will be a new row instead.
    plot_names : list of str, optional
        List of the same length as `plots` with the plot names to use as coordinate values
        in the returned :class:`~arviz_plots.PlotCollection`. Names must be unique.
        Defaults to ``"<function name>_<position>"``.
    backend : {"matplotlib", "bokeh", "plotly"}, optional
        Plotting backend to use. Defaults to ``rcParams["plot.backend"]``.
    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.grid`

    Returns
    -------
    PlotCollection
        Its ``aes`` attribute currently holds only the aesthetics of the last
        plotting function in `plots`.

    Raises
    ------
    ValueError
        If `expand` is not "column" or "row", or if `plot_names` does not contain
        exactly one unique name per element of `plots`.

    Examples
    --------
    Customize the names of the plots in the returned :class:`PlotCollection`

    .. plot::
        :context: close-figs

        >>> import arviz_plots as azp
        >>> azp.style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> rugby = load_arviz_data('rugby')
        >>> pc = azp.combine_plots(
        ...     rugby,
        ...     plots=[
        ...         (azp.plot_ppc_pit, {}),
        ...         (azp.plot_ppc_rootogram, {}),
        ...     ],
        ...     group="posterior_predictive",
        ...     plot_names=["pit", "rootogram"],
        ... )

    Now if we inspect the ``pc.viz`` attribute, we can see it has a ``column`` dimension
    with the requested coordinate values:

    .. plot::
        :context: close-figs

        >>> pc.viz

    .. minigallery:: combine_plots
    """
    # Materialize once: `plots` is iterated for default names, sized with len()
    # and iterated again when rendering, which would break on a generator.
    plots = list(plots)
    if plot_names is None:
        # Callables such as functools.partial have no __name__; fall back to the
        # type name. The zero-padded position keeps default names unique.
        plot_names = [
            getattr(elem[0], "__name__", type(elem[0]).__name__) + f"_{idx:02d}"
            for idx, elem in enumerate(plots)
        ]
    else:
        plot_names = list(plot_names)
    if len(plot_names) != len(plots):
        raise ValueError(
            f"`plot_names` must have one entry per element of `plots` ({len(plots)}) "
            f"but got {len(plot_names)}"
        )
    # Duplicate names would make `pc.coords = {expand: name}` below select
    # several plots at once.
    if len(set(plot_names)) != len(plot_names):
        raise ValueError(f"`plot_names` must be unique but got {plot_names}")
    # Validate before doing any (potentially expensive) data processing.
    if expand not in ("column", "row"):
        raise ValueError(f"`expand` must be 'row' or 'column' but got '{expand}'")

    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if backend is None:
        backend = rcParams["plot.backend"]

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )
    # Predictive groups are faceted by variable only. Other groups also get a facet
    # per remaining (non-sample) dimension.
    facet_dims = ["__variable__"] + (
        []
        if "predictive" in group
        else [dim for dim in distribution.dims if dim not in sample_dims]
    )

    # Work on a copy so the caller's figure_kwargs dict is never mutated.
    pc_kwargs["figure_kwargs"] = dict(pc_kwargs.get("figure_kwargs") or {})
    if expand == "column":
        pc_kwargs.setdefault("cols", ["column"])
        pc_kwargs.setdefault("rows", facet_dims)
    else:
        pc_kwargs.setdefault("cols", facet_dims)
        pc_kwargs.setdefault("rows", ["row"])
    # New dimension named after `expand`, with one coordinate value per plot.
    distribution = distribution.expand_dims(**{expand: len(plots)}).assign_coords(
        {expand: plot_names}
    )

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    pc_kwargs = set_grid_layout(pc_kwargs, plot_bknd, distribution)
    pc = PlotCollection.grid(
        distribution,
        backend=backend,
        **pc_kwargs,
    )

    for name, (plot, kwargs) in zip(plot_names, plots):
        # Build the visual descriptions with the "none" backend. They are then drawn
        # on the real backend through `render`.
        pc_i = plot(
            dt,
            backend="none",
            group=group,
            var_names=var_names,
            filter_vars=filter_vars,
            coords=coords,
            sample_dims=sample_dims,
            **kwargs,
        )
        # NOTE(review): coords are cleared before assigning aes, presumably because
        # the aes setter depends on the active coords. Confirm before reordering.
        pc.coords = None
        pc.aes = pc_i.aes
        # Select this plot's slice along `expand` so the renders below land in its
        # column/row (relies on PlotCollection.coords restricting `map`).
        pc.coords = {expand: name}
        for viz_group, ds in pc_i.viz.children.items():
            # Skip these groups; judging by their names they hold layout info,
            # not visual elements.
            if viz_group in {"plot", "row_index", "col_index"}:
                continue
            pc.map(
                render,
                fun_label=f"{viz_group}_{name}",
                data=ds.dataset,
                ignore_aes=ds.attrs.get("ignore_aes", frozenset()),
            )
        pc.coords = None
    # TODO: at some point all `pc_i.aes` objects should be merged
    # and stored into the `pc.aes` attribute
    return pc