Source code for arviz_plots.plots.pava_calibration_plot

"""Plot ppc using PAV-adjusted calibration plot."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.helper_stats import isotonic_fit

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import filter_aes, get_contrast_colors, set_wrap_layout
from arviz_plots.visuals import (
    dline,
    fill_between_y,
    labelled_title,
    labelled_x,
    labelled_y,
    line_xy,
    scatter_xy,
)


def plot_ppc_pava(
    dt,
    data_type="binary",
    n_bootstaps=1000,
    ci_prob=None,
    var_names=None,
    filter_vars=None,  # pylint: disable=unused-argument
    group="posterior_predictive",
    coords=None,  # pylint: disable=unused-argument
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "lines",
            "markers",
            "reference_line",
            "credible_interval",
            "xlabel",
            "ylabel",
            "title",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "lines",
            "markers",
            "reference_line",
            "credible_interval",
            "xlabel",
            "ylabel",
            "title",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    **pc_kwargs,
):
    """PAV-adjusted calibration plot.

    Uses the pool adjacent violators (PAV) algorithm for isotonic regression.
    A 45-degree line corresponds to perfect calibration.
    Details are discussed in [1]_ and [2]_.

    Parameters
    ----------
    dt : DataTree
        Input data
    data_type : str
        Defaults to "binary". Other options are "categorical" and "ordinal".
        If "categorical", the plot will show the "one-vs-others" calibration and generate
        one plot per category. If "ordinal", the plot will display cumulative conditional
        event probabilities and generate (number of categories - 1) plots.
    n_bootstaps : int, optional
        Number of bootstrap samples to use for estimating the confidence intervals.
        Defaults to 1000.
    ci_prob : float, optional
        Probability for the credible interval. Defaults to ``rcParams["stats.ci_prob"]``.
    var_names : str or list of str, optional
        One or more variables to be plotted. Currently only one variable is supported.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
        Currently not used by this function.
    group : str, optional
        The group from which to get the unique values. Defaults to "posterior_predictive".
        It could also be "prior_predictive". Notice that this plot always uses the
        "observed_data" so use with extra care if you are using "prior_predictive".
    coords : dict, optional
        Coordinates to plot. CURRENTLY NOT IMPLEMENTED
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
        Labeller passed to the title visual. When None, a
        :class:`~arviz_base.labels.BaseLabeller` instance is used.
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * lines -> passed to :func:`~arviz_plots.visuals.line_xy`
        * markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
        * reference_line -> passed to :func:`~arviz_plots.visuals.dline`
        * credible_interval -> passed to :func:`~arviz_plots.visuals.fill_between_y`
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`

        markers defaults to False, no markers are plotted.
        Pass an (empty) mapping to plot markers.

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Examples
    --------
    Plot the PAVA calibration plot for the rugby dataset.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_ppc_pava, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('rugby')
        >>> plot_ppc_pava(dt, ci_prob=0.90)


    .. minigallery:: plot_ppc_pava

    References
    ----------
    .. [1] Säilynoja et al. *Recommendations for visual predictive checks in Bayesian workflow*.
        (2025) arXiv preprint https://arxiv.org/abs/2503.01509
    .. [2] Dimitriadis et al *Stable reliability diagrams for probabilistic classifiers*.
        PNAS, 118(8) (2021). https://doi.org/10.1073/pnas.2016191118
    """
    if ci_prob is None:
        ci_prob = rcParams["stats.ci_prob"]
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)

    # Work on a shallow copy so the caller's mapping is never mutated by setdefault below.
    if visuals is None:
        visuals = {}
    else:
        visuals = visuals.copy()

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    # Only fall back to the default labeller when the caller did not provide one.
    if labeller is None:
        labeller = BaseLabeller()

    # Markers are opt-in: they are only drawn when the caller passes a mapping for them.
    visuals.setdefault("markers", False)

    ds_calibration = isotonic_fit(dt, var_names, group, n_bootstaps, ci_prob, data_type)

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color = get_contrast_colors(bg_color=bg_color)
    colors = plot_bknd.get_default_aes("color", 1, {})
    markers = plot_bknd.get_default_aes("marker", 7, {})
    lines = plot_bknd.get_default_aes("linestyle", 2, {})

    if plot_collection is None:
        # Copy nested kwargs before they are handed to the layout helper.
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs.setdefault("cols", ["__variable__"])

        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, ds_calibration)

        plot_collection = PlotCollection.wrap(
            ds_calibration,
            backend=backend,
            **pc_kwargs,
        )

    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()

    # reference line
    reference_ls_kwargs = copy(visuals.get("reference_line", {}))

    if reference_ls_kwargs is not False:
        _, _, reference_ls_ignore = filter_aes(
            plot_collection, aes_by_visuals, "reference_line", sample_dims
        )
        reference_ls_kwargs.setdefault("color", "grey")
        reference_ls_kwargs.setdefault("linestyle", lines[1])

        plot_collection.map(
            dline,
            "reference_line",
            data=ds_calibration,
            x=ds_calibration.sel(plot_axis="x"),
            ignore_aes=reference_ls_ignore,
            **reference_ls_kwargs,
        )

    # markers
    calibration_ms_kwargs = copy(visuals.get("markers", {}))

    if calibration_ms_kwargs is not False:
        _, _, calibration_ms_ignore = filter_aes(
            plot_collection, aes_by_visuals, "markers", sample_dims
        )
        calibration_ms_kwargs.setdefault("color", colors[0])
        calibration_ms_kwargs.setdefault("marker", markers[6])

        plot_collection.map(
            scatter_xy,
            "markers",
            data=ds_calibration,
            ignore_aes=calibration_ms_ignore,
            **calibration_ms_kwargs,
        )

    # lines
    calibration_ls_kwargs = copy(visuals.get("lines", {}))

    if calibration_ls_kwargs is not False:
        _, _, calibration_ls_ignore = filter_aes(
            plot_collection, aes_by_visuals, "lines", sample_dims
        )
        calibration_ls_kwargs.setdefault("color", colors[0])

        plot_collection.map(
            line_xy,
            "lines",
            data=ds_calibration,
            ignore_aes=calibration_ls_ignore,
            **calibration_ls_kwargs,
        )

    # credible interval band
    ci_kwargs = copy(visuals.get("credible_interval", {}))
    _, _, ci_ignore = filter_aes(plot_collection, aes_by_visuals, "credible_interval", sample_dims)
    if ci_kwargs is not False:
        ci_kwargs.setdefault("color", colors[0])
        ci_kwargs.setdefault("alpha", 0.25)

        plot_collection.map(
            fill_between_y,
            "credible_interval",
            data=ds_calibration,
            x=ds_calibration.sel(plot_axis="x"),
            y_bottom=ds_calibration.sel(plot_axis="y_bottom"),
            y_top=ds_calibration.sel(plot_axis="y_top"),
            ignore_aes=ci_ignore,
            **ci_kwargs,
        )

    # set xlabel
    _, xlabels_aes, xlabels_ignore = filter_aes(
        plot_collection, aes_by_visuals, "xlabel", sample_dims
    )
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        # Use the contrast color only when color is not already an active aesthetic.
        if "color" not in xlabels_aes:
            xlabel_kwargs.setdefault("color", contrast_color)

        xlabel_kwargs.setdefault("text", "predicted value")

        plot_collection.map(
            labelled_x,
            "xlabel",
            ignore_aes=xlabels_ignore,
            subset_info=True,
            **xlabel_kwargs,
        )

    # set ylabel
    _, ylabels_aes, ylabels_ignore = filter_aes(
        plot_collection, aes_by_visuals, "ylabel", sample_dims
    )
    ylabel_kwargs = copy(visuals.get("ylabel", {}))
    if ylabel_kwargs is not False:
        if "color" not in ylabels_aes:
            ylabel_kwargs.setdefault("color", contrast_color)

        ylabel_kwargs.setdefault("text", "CEP")

        plot_collection.map(
            labelled_y,
            "ylabel",
            ignore_aes=ylabels_ignore,
            subset_info=True,
            **ylabel_kwargs,
        )

    # title
    title_kwargs = copy(visuals.get("title", {}))
    _, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims)

    if title_kwargs is not False:
        plot_collection.map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    return plot_collection