"""psense quantities plot code."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal
from arviz_base import extract, rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.psense import power_scale_dataset
from xarray import concat
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
filter_aes,
get_contrast_colors,
process_group_variables_coords,
set_grid_layout,
)
from arviz_plots.visuals import hline, labelled_title, labelled_x, line_xy, scatter_xy, set_xticks
# (Sphinx "[docs]" viewcode link label removed: it is page markup, not part of the module.)
def plot_psense_quantities(
    dt,
    alphas=None,
    quantities=None,
    mcse=True,
    var_names=None,
    filter_vars=None,
    prior_var_names=None,
    likelihood_var_names=None,
    prior_coords=None,
    likelihood_coords=None,
    coords=None,
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "prior_markers",
            "prior_lines",
            "likelihood_markers",
            "likelihood_lines",
            "mcse",
            "ticks",
            "xlabel",
            "title",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "prior_markers",
            "prior_lines",
            "likelihood_markers",
            "likelihood_lines",
            "mcse",
            "ticks",
            "xlabel",
            "title",
            "legend",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    **pc_kwargs,
):
    """Plot power scaled posterior quantities.

    The posterior quantities are computed by power-scaling the prior or likelihood and
    visualizing the resulting changes, using Pareto-smoothed importance sampling to
    avoid refitting as explained in [1]_.

    Parameters
    ----------
    dt : DataTree
        Input data
    alphas : tuple of float
        Lower and upper alpha values for power scaling. Defaults to (0.8, 1.25).
        The unperturbed value ``1`` is always added between them as an x tick.
    quantities : str or list of str
        Quantities to plot. Options are 'mean', 'sd', 'median'. For quantiles, use
        '0.25', '0.5', etc. Defaults to ['mean', 'sd'].
        A single string is treated as a one-element list.
    mcse : bool
        Whether to plot the Monte Carlo standard error for each quantity. Defaults to True.
        When enabled, horizontal lines are drawn at the baseline (prior group, alpha=1)
        value plus and minus twice the MCSE computed on the posterior group.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    prior_var_names : str, optional.
        Name of the log-prior variables to include in the power scaling sensitivity diagnostic
    likelihood_var_names : str, optional.
        Name of the log-likelihood variables to include in the power scaling sensitivity diagnostic
    prior_coords : dict, optional.
        Coordinates defining a subset over the group element for which to
        compute the log-prior sensitivity diagnostic
    likelihood_coords : dict, optional
        Coordinates defining a subset over the group element for which to
        compute the log-likelihood sensitivity diagnostic
    coords : dict, optional
        Forwarded as ``coords`` when selecting the variables and coordinates of the
        power-scaled draws that are summarized and plotted.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
        Existing collection to draw on. If None, a new one is created with
        :meth:`arviz_plots.PlotCollection.grid`.
    backend : {"matplotlib", "bokeh", "plotly"}, optional
        Defaults to ``plot_collection.backend`` when `plot_collection` is given,
        otherwise to ``rcParams["plot.backend"]``.
    labeller : labeller, optional
        Labeller passed to the title visual. Defaults to a ``BaseLabeller`` instance.
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * prior_markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
        * prior_lines -> passed to :func:`~arviz_plots.visuals.line_xy`
        * likelihood_markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
        * likelihood_lines -> passed to :func:`~arviz_plots.visuals.line_xy`
        * mcse -> passed to :func:`~arviz_plots.visuals.hline`
        * ticks -> passed to :func:`~arviz_plots.visuals.set_xticks`
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * legend -> passed to :class:`arviz_plots.PlotCollection.add_legend`

        Use ``False`` as the value of a key to skip drawing that element.

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.grid`

    Returns
    -------
    PlotCollection

    Examples
    --------
    Select a single parameter, one of the two likelihoods, and plot the mean, standard deviation,
    and 25th percentile.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_psense_quantities, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> rugby = load_arviz_data('rugby')
        >>> plot_psense_quantities(rugby,
        >>>                 var_names=["sd_att"],
        >>>                 likelihood_var_names=["home_points"],
        >>>                 quantities=["mean", "sd", "0.25"])

    .. minigallery:: plot_psense_quantities

    References
    ----------
    .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
       power-scaling*, Stat Comput 34, 57 (2024), https://doi.org/10.1007/s11222-023-10366-5
    """
    # ---- normalize arguments -------------------------------------------------------------
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)
    if visuals is None:
        visuals = {}
    else:
        # shallow copy so the caller's mapping is never modified
        visuals = visuals.copy()
    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    if alphas is None:
        alphas = (0.8, 1.25)
    # alpha=1 (no perturbation) is always shown between the lower and upper values
    alphas_p1 = (alphas[0], 1, alphas[1])
    alphas_p1_labels = [str(val) for val in alphas_p1]
    if quantities is None:
        quantities = ["mean", "sd"]
    if isinstance(quantities, str):
        quantities = [quantities]
    # Only fall back to the default labeller when the caller did not provide one;
    # rebinding unconditionally would silently discard the `labeller` argument.
    if labeller is None:
        labeller = BaseLabeller()

    # ---- build the power-scaled draws ----------------------------------------------------
    # unperturbed posterior, only used for the MCSE estimates
    ds_posterior = extract(
        dt,
        var_names=var_names,
        filter_vars=filter_vars,
        group="posterior",
        combined=False,
        keep_dataset=True,
    )
    ds_prior = power_scale_dataset(
        dt,
        group="prior",
        alphas=alphas,
        sample_dims=sample_dims,
        group_var_names=prior_var_names,
        group_coords=prior_coords,
    )
    ds_likelihood = power_scale_dataset(
        dt,
        group="likelihood",
        alphas=alphas,
        sample_dims=sample_dims,
        group_var_names=likelihood_var_names,
        group_coords=likelihood_coords,
    )
    distribution = concat([ds_prior, ds_likelihood], dim="component_group").assign_coords(
        {"component_group": ["prior", "likelihood"]}
    )
    distribution = process_group_variables_coords(
        distribution, group=None, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    if len(sample_dims) > 1:
        # sample dims will have been stacked and renamed by `power_scale_dataset`
        sample_dims = ["sample"]

    # ---- compute the requested summaries -------------------------------------------------
    # The three lists are appended to in lockstep so that position i of each one
    # refers to the same quantity (named quantities first, then quantiles).
    to_concat_quantities = []
    to_concat_mcse = []
    name_quantities = []
    if "mean" in quantities:
        to_concat_quantities.append(distribution.mean(sample_dims))
        if mcse:
            to_concat_mcse.append(ds_posterior.azstats.mcse(method="mean"))
        name_quantities.append("mean")
    if "sd" in quantities:
        to_concat_quantities.append(distribution.std(sample_dims))
        if mcse:
            to_concat_mcse.append(ds_posterior.azstats.mcse(method="sd"))
        name_quantities.append("sd")
    if "median" in quantities:
        to_concat_quantities.append(distribution.median(sample_dims))
        if mcse:
            to_concat_mcse.append(ds_posterior.azstats.mcse(method="median"))
        name_quantities.append("median")
    for val in quantities:
        # entries made only of digits and dots (e.g. "0.25") are treated as quantile levels
        if val.replace(".", "").isnumeric():
            q = float(val)
            to_concat_quantities.append(
                distribution.quantile(q, sample_dims).rename_vars({"quantile": f"q={val}"})
            )
            if mcse:
                to_concat_mcse.append(ds_posterior.azstats.mcse(method="quantile", prob=q))
            name_quantities.append(f"q={val}")

    ds_quantities = concat(to_concat_quantities, "quantities").assign_coords(
        quantities=name_quantities
    )
    if mcse:
        mcse_quantities = concat(to_concat_mcse, "quantities").assign_coords(
            quantities=name_quantities
        )
        # reference value: the prior group evaluated at alpha=1
        baseline_quantities = ds_quantities.sel(component_group="prior", alpha=1).drop_vars(
            ["alpha", "component_group"]
        )
        # band of +/- 2 MCSE around the reference value
        min_ = baseline_quantities - mcse_quantities * 2
        max_ = baseline_quantities + mcse_quantities * 2

    # ---- backend defaults ----------------------------------------------------------------
    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color = get_contrast_colors(bg_color=bg_color)
    colors = plot_bknd.get_default_aes("color", 2, {})
    markers = plot_bknd.get_default_aes("marker", 6, {})
    lines = plot_bknd.get_default_aes("linestyle", 2, {})

    # ---- create the plot collection if needed --------------------------------------------
    if plot_collection is None:
        # copy nested dicts before adding defaults so caller-owned dicts stay untouched
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["figure_kwargs"].setdefault("sharex", True)
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs["aes"].setdefault("color", ["component_group"])
        pc_kwargs.setdefault("cols", ["quantities"])
        # one row per variable and per remaining (non sample/group/alpha) dimension
        pc_kwargs.setdefault(
            "rows",
            ["__variable__"]
            + [
                dim
                for dim in distribution.dims
                if dim not in sample_dims + ["component_group", "alpha"]
            ],
        )
        pc_kwargs = set_grid_layout(pc_kwargs, plot_bknd, ds_quantities)

        plot_collection = PlotCollection.grid(
            ds_quantities,
            backend=backend,
            **pc_kwargs,
        )

    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()
    # NOTE(review): no visual name passed to `filter_aes` below matches these two keys,
    # so within this function the defaults are never looked up. Kept as-is to preserve
    # behaviour; confirm whether they were meant for the `*_markers`/`*_lines` visuals.
    aes_by_visuals.setdefault("quantities_marker", ["color", "marker"])
    aes_by_visuals.setdefault("quantities", ["color"])

    # `copy(...)` is used for every visual so that `False` (element disabled) passes
    # through unchanged while dict values are copied before defaults are added.

    # plot quantities for prior-perturbations
    ## markers
    prior_ms_kwargs = copy(visuals.get("prior_markers", {}))
    if prior_ms_kwargs is not False:
        _, _, prior_ms_ignore = filter_aes(
            plot_collection, aes_by_visuals, "prior_markers", sample_dims
        )
        prior_ms_kwargs.setdefault("marker", markers[0])
        prior_ms_kwargs.setdefault("color", colors[0])

        plot_collection.map(
            scatter_xy,
            "prior_markers",
            data=ds_quantities.sel(component_group="prior"),
            x=ds_quantities.alpha,
            ignore_aes=prior_ms_ignore,
            **prior_ms_kwargs,
        )

    ## lines
    prior_ls_kwargs = copy(visuals.get("prior_lines", {}))
    if prior_ls_kwargs is not False:
        _, _, prior_ls_ignore = filter_aes(
            plot_collection, aes_by_visuals, "prior_lines", sample_dims
        )
        prior_ls_kwargs.setdefault("color", colors[0])

        plot_collection.map(
            line_xy,
            "prior_lines",
            data=ds_quantities.sel(component_group="prior"),
            x=ds_quantities.alpha,
            ignore_aes=prior_ls_ignore,
            **prior_ls_kwargs,
        )

    # plot quantities for likelihood-perturbations
    ## markers
    likelihood_ms_kwargs = copy(visuals.get("likelihood_markers", {}))
    if likelihood_ms_kwargs is not False:
        _, _, likelihood_ms_ignore = filter_aes(
            plot_collection, aes_by_visuals, "likelihood_markers", sample_dims
        )
        likelihood_ms_kwargs.setdefault("marker", markers[5])
        likelihood_ms_kwargs.setdefault("color", colors[1])

        plot_collection.map(
            scatter_xy,
            "likelihood_markers",
            data=ds_quantities.sel(component_group="likelihood"),
            x=ds_quantities.alpha,
            ignore_aes=likelihood_ms_ignore,
            **likelihood_ms_kwargs,
        )

    ## lines
    likelihood_ls_kwargs = copy(visuals.get("likelihood_lines", {}))
    if likelihood_ls_kwargs is not False:
        _, _, likelihood_ls_ignore = filter_aes(
            plot_collection, aes_by_visuals, "likelihood_lines", sample_dims
        )
        likelihood_ls_kwargs.setdefault("color", colors[1])

        plot_collection.map(
            line_xy,
            "likelihood_lines",
            data=ds_quantities.sel(component_group="likelihood"),
            x=ds_quantities.alpha,
            ignore_aes=likelihood_ls_ignore,
            **likelihood_ls_kwargs,
        )

    # plot mcse
    if mcse:
        mcse_kwargs = copy(visuals.get("mcse", {}))
        _, _, mcse_ignore = filter_aes(plot_collection, aes_by_visuals, "mcse", sample_dims)
        if mcse_kwargs is not False:
            mcse_kwargs.setdefault("color", "grey")
            mcse_kwargs.setdefault("linestyle", lines[1])

            # lower and upper bounds of the MCSE band are drawn as two horizontal lines
            plot_collection.map(hline, "mcse", data=min_, ignore_aes=mcse_ignore, **mcse_kwargs)
            plot_collection.map(hline, "mcse", data=max_, ignore_aes=mcse_ignore, **mcse_kwargs)

    # set ticks
    ticks_kwargs = copy(visuals.get("ticks", {}))
    _, _, ticks_ignore = filter_aes(plot_collection, aes_by_visuals, "ticks", sample_dims)
    if ticks_kwargs is not False:
        plot_collection.map(
            set_xticks,
            "ticks",
            values=alphas_p1,
            labels=alphas_p1_labels,
            ignore_aes=ticks_ignore,
            store_artist=backend == "none",
            **ticks_kwargs,
        )

    # set xlabel
    _, xlabels_aes, xlabels_ignore = filter_aes(
        plot_collection, aes_by_visuals, "xlabel", sample_dims
    )
    # `copy(...)` instead of `.copy()`: `visuals={"xlabel": False}` must disable the
    # label like every other visual rather than fail on `False.copy()`.
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        if "color" not in xlabels_aes:
            xlabel_kwargs.setdefault("color", contrast_color)
        xlabel_kwargs.setdefault("text", "Power-scaling α")
        plot_collection.map(
            labelled_x,
            "xlabel",
            ignore_aes=xlabels_ignore,
            subset_info=True,
            **xlabel_kwargs,
        )

    # title
    title_kwargs = copy(visuals.get("title", {}))
    _, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims)
    if title_kwargs is not False:
        plot_collection.map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    # legend
    legend_kwargs = copy(visuals.get("legend", {}))
    if legend_kwargs is not False:
        legend_kwargs.setdefault("dim", ["component_group"])
        plot_collection.add_legend(**legend_kwargs)

    return plot_collection