"""Pair focus plot code."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal
import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
filter_aes,
get_group,
process_group_variables_coords,
set_wrap_layout,
)
from arviz_plots.visuals import labelled_x, labelled_y, scatter_x, scatter_xy
# [docs]  (leftover Sphinx "view source" link marker; not part of the program)
def plot_pair_focus(
    dt,
    focus_var,
    focus_var_coords=None,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "scatter",
            "divergence",
            "xlabel",
            "ylabel",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "scatter",
            "divergence",
            "xlabel",
            "ylabel",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    **pc_kwargs,
):
    """Plot a fixed variable against other variables in the dataset.

    Parameters
    ----------
    dt : DataTree
        Input data
    focus_var : str or DataArray
        Name of the variable or DataArray to be plotted against all other variables.
        A name is looked up in `group`; a DataArray is used as given.
    focus_var_coords : mapping, optional
        Coordinates to use for the target variable. Only applied when `focus_var`
        is given as a variable name.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to use for plotting. Defaults to "posterior".
    coords : mapping, optional
        Coordinates to use for plotting.
    sample_dims : iterable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
        Existing collection to draw on. When omitted, a new one is created with
        :meth:`arviz_plots.PlotCollection.wrap`.
    backend : {"matplotlib", "bokeh", "plotly", "none"}, optional
        Plotting backend to use. Defaults to the backend of `plot_collection` when
        one is given, otherwise to ``rcParams["plot.backend"]``
    labeller : labeller, optional
        Defaults to a ``BaseLabeller`` instance.
    aes_by_visuals : mapping, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
        The "overlay" aesthetic is always added for "scatter" and "divergence".
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * scatter -> passed to :func:`~.visuals.scatter_x`
        * divergence -> passed to :func:`~.visuals.scatter_xy`. Defaults to False.
        * xlabel -> :func:`~.visuals.labelled_x`
        * ylabel -> :func:`~.visuals.labelled_y`

    **pc_kwargs
        Passed to :meth:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    TypeError
        If `focus_var` is neither a string nor a DataArray.
    ValueError
        If the focus variable has dimensions other than `sample_dims`.

    Examples
    --------
    Default plot_pair_focus

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_pair_focus, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_pair_focus(
        >>>     dt,
        >>>     var_names=["mu", "tau"],
        >>>     focus_var="theta",
        >>>     focus_var_coords={"school": "Choate"},
        >>> )

    .. minigallery:: plot_pair_focus

    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if visuals is None:
        visuals = {}
    # ``**pc_kwargs`` is always a fresh dict owned by this call (never None), so it
    # can be modified below without affecting the caller.
    # dict() works for any Mapping and yields a mutable copy we can assign into.
    aes_by_visuals = {} if aes_by_visuals is None else dict(aes_by_visuals)

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    # resolve the focus variable, it is used as the y values of every plot
    if isinstance(focus_var, str):
        y = get_group(dt, group)[focus_var]
        if focus_var_coords:
            y = y.sel(focus_var_coords)
    elif isinstance(focus_var, xr.DataArray):
        y = focus_var
    else:
        raise TypeError(
            f"focus_var should be a string or DataArray, got {type(focus_var)} instead."
        )

    # after subsetting, only sample dimensions may remain in the focus variable
    dims_y = [dim for dim in y.dims if dim not in sample_dims]
    if len(dims_y) > 0:
        raise ValueError(f"focus variable has unexpected dimensions: {dims_y}.")

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    colors = plot_bknd.get_default_aes("color", 2, {})

    if plot_collection is None:
        pc_kwargs.setdefault(
            "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims]
        )
        # copy nested mappings before setting defaults so user provided ones are not modified
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        if "chain" in distribution:
            pc_kwargs["aes"].setdefault("overlay", ["chain"])
        pc_kwargs["figure_kwargs"].setdefault("sharey", True)
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    # scatter
    aes_by_visuals["scatter"] = {"overlay"}.union(aes_by_visuals.get("scatter", {}))
    scatter_kwargs = copy(visuals.get("scatter", {}))
    if scatter_kwargs is not False:
        _, scatter_aes, scatter_ignore = filter_aes(
            plot_collection, aes_by_visuals, "scatter", sample_dims
        )
        if "color" not in scatter_aes:
            scatter_kwargs.setdefault("color", colors[0])
        if "width" not in scatter_aes:
            scatter_kwargs.setdefault("width", 0)
        if "alpha" not in scatter_aes:
            scatter_kwargs.setdefault("alpha", 0.5)
        plot_collection.map(
            scatter_x,
            "scatter",
            ignore_aes=scatter_ignore,
            y=y,
            **scatter_kwargs,
        )

    # divergence
    aes_by_visuals["divergence"] = {"overlay"}.union(aes_by_visuals.get("divergence", {}))
    div_kwargs = copy(visuals.get("divergence", False))
    if div_kwargs is True:
        div_kwargs = {}
    sample_stats = get_group(dt, "sample_stats", allow_missing=True)
    if (
        div_kwargs is not False
        and sample_stats is not None
        and "diverging" in sample_stats.data_vars
        and np.any(sample_stats["diverging"])  # nothing to draw without divergences
    ):
        # take the mask from the group that was just validated instead of
        # accessing ``dt.sample_stats`` again through attribute lookup
        divergence_mask = sample_stats["diverging"]
        _, div_aes, div_ignore = filter_aes(
            plot_collection, aes_by_visuals, "divergence", sample_dims
        )
        if "color" not in div_aes:
            div_kwargs.setdefault("color", colors[1])
        if "alpha" not in div_aes:
            div_kwargs.setdefault("alpha", 0.4)
        plot_collection.map(
            scatter_xy,
            "divergence",
            ignore_aes=div_ignore,
            y=y,
            mask=divergence_mask,
            **div_kwargs,
        )

    if labeller is None:
        labeller = BaseLabeller()

    # xlabel of plots
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        _, _, xlabel_ignore = filter_aes(plot_collection, aes_by_visuals, "xlabel", sample_dims)
        plot_collection.map(
            labelled_x,
            "xlabel",
            subset_info=True,
            ignore_aes=xlabel_ignore,
            labeller=labeller,
            **xlabel_kwargs,
        )

    # ylabel of plots
    ylabel_kwargs = copy(visuals.get("ylabel", {}))
    if ylabel_kwargs is not False:
        _, _, ylabel_ignore = filter_aes(plot_collection, aes_by_visuals, "ylabel", sample_dims)

        # generate y label text using labeller, from the single-valued coordinates of y;
        # empty coordinates are skipped because ``.item()`` needs exactly one element
        y_sel = {key: value.item() for key, value in y.coords.items() if value.size == 1}
        y_label_text = labeller.make_label_vert(y.name, y_sel, dict.fromkeys(y_sel, 0))

        plot_collection.map(
            labelled_y,
            "ylabel",
            ignore_aes=ylabel_ignore,
            text=y_label_text,
            **ylabel_kwargs,
        )

    return plot_collection