# Source code for arviz_plots.plots.ppc_rootogram_plot

"""Plot ppc rootogram for discrete (count) data."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.helper_stats import point_interval_unique, point_unique

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
    filter_aes,
    get_contrast_colors,
    process_group_variables_coords,
    set_wrap_layout,
)
from arviz_plots.visuals import (
    ci_line_y,
    grid,
    labelled_title,
    labelled_x,
    labelled_y,
    scatter_xy,
    set_y_scale,
)


def plot_ppc_rootogram(
    dt,
    ci_prob=None,
    yscale="sqrt",
    data_pairs=None,
    var_names=None,
    filter_vars=None,
    group="posterior_predictive",
    coords=None,
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "predictive_markers",
            "observed_markers",
            "credible_interval",
            "xlabel",
            "ylabel",
            "grid",
            "title",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "predictive_markers",
            "observed_markers",
            "credible_interval",
            "xlabel",
            "ylabel",
            "grid",
            "title",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    **pc_kwargs,
):
    """Rootogram with confidence intervals per predicted count.

    Rootograms are useful to check the calibration of count models.
    A rootogram shows the difference between observed and predicted counts. The y-axis,
    showing frequencies, is on the square root scale. This makes it easier to compare
    observed and expected frequencies even for low frequencies [1]_ and [2]_.

    For more details on how to interpret this plot, see
    https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html

    Parameters
    ----------
    dt : DataTree
        If group is "posterior_predictive", it should contain the ``posterior_predictive``
        and ``observed_data`` groups. If group is "prior_predictive", it should contain
        the ``prior_predictive`` group.
    ci_prob : float, optional
        Probability for the credible interval. Defaults to ``rcParams["stats.ci_prob"]``.
    yscale : str, optional
        Scale for the y-axis. Defaults to "sqrt", pass "linear" for linear scale.
        Currently only "matplotlib" backend is supported. For "bokeh" and "plotly"
        the y-axis is linear.
    data_pairs : dict, optional
        Dictionary of keys prior/posterior predictive data and values observed data
        variable names. If None, it will assume that the observed data and the
        predictive data have the same variable name.
    var_names : str or list of str, optional
        One or more variables to be plotted. Currently only one variable is supported.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str,
        Group to be plotted. Defaults to "posterior_predictive".
        It could also be "prior_predictive".
    coords : dict, optional
        Coordinates to plot.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
        Labeller used for the title. A ``BaseLabeller`` instance is used when None.
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * predictive_markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
        * observed_markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`.
        * credible_interval -> passed to :func:`~arviz_plots.visuals.ci_line_y`
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
        * grid -> passed to :func:`~arviz_plots.visuals.grid`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`

        observed_markers defaults to False, no observed data is plotted,
        if group is "prior_predictive". Pass an (empty) mapping to plot the observed data.
    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If any of the selected predictive or observed variables has a floating point
        dtype, or if observed markers are requested but `dt` has no ``observed_data``
        group.

    Examples
    --------
    Plot the rootogram for the crabs dataset.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_ppc_rootogram, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('crabs_poisson')
        >>> plot_ppc_rootogram(dt)

    .. minigallery:: plot_ppc_rootogram

    References
    ----------
    .. [1] Kleiber C, Zeileis A. *Visualizing Count Data Regressions Using Rootograms*.
        The American Statistician, 70(3). (2016) https://doi.org/10.1080/00031305.2016.1173590
    .. [2] Säilynoja et al. *Recommendations for visual predictive checks in Bayesian workflow*.
        (2025) arXiv preprint https://arxiv.org/abs/2503.01509
    """
    if ci_prob is None:
        ci_prob = rcParams["stats.ci_prob"]
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)

    # Work on a shallow copy so the caller's mapping is never mutated.
    if visuals is None:
        visuals = {}
    else:
        visuals = visuals.copy()

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    # Only fall back to the default labeller when the caller did not provide one;
    # previously the argument was unconditionally overwritten and thus ignored.
    if labeller is None:
        labeller = BaseLabeller()

    if data_pairs is None:
        data_pairs = (var_names, var_names)
    else:
        data_pairs = (list(data_pairs.keys()), list(data_pairs.values()))

    predictive_dist = process_group_variables_coords(
        dt, group=group, var_names=data_pairs[0], filter_vars=filter_vars, coords=coords
    )
    predictive_types = [
        predictive_dist[var].values.dtype.kind == "f" for var in predictive_dist.data_vars
    ]

    if "observed_data" in dt:
        observed_dist = process_group_variables_coords(
            dt,
            group="observed_data",
            var_names=data_pairs[1],
            filter_vars=filter_vars,
            coords=coords,
        )
        observed_types = [
            observed_dist[var].values.dtype.kind == "f" for var in observed_dist.data_vars
        ]
        observed_ds = point_unique(dt, observed_dist.data_vars)
    else:
        observed_types = []
        # Explicit sentinel so that a missing group is reported clearly below
        # instead of surfacing later as an unbound local variable.
        observed_ds = None

    if any(predictive_types + observed_types):
        raise ValueError(
            "Detected at least one continuous variable.\n"
            "Use plot_ppc variants specific for continuous data, "
            "such as plot_ppc_dist.",
        )

    # Resolve the observed markers request up front so an impossible request fails
    # before any figure or artist is created.
    observed_ms_kwargs = copy(
        visuals.get("observed_markers", False if group == "prior_predictive" else {})
    )
    if observed_ms_kwargs is not False and observed_ds is None:
        raise ValueError(
            "Plotting observed markers requires an 'observed_data' group in `dt`. "
            "Add the group or disable the markers with "
            "visuals={'observed_markers': False}."
        )

    ds_predictive = point_interval_unique(dt, predictive_dist.data_vars, group, ci_prob)

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color = get_contrast_colors(bg_color=bg_color)
    colors = plot_bknd.get_default_aes("color", 1, {})
    markers = plot_bknd.get_default_aes("marker", 7, {})

    if plot_collection is None:
        # Copy nested kwargs so defaults set downstream do not leak into caller dicts.
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs.setdefault("cols", "__variable__")

        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, ds_predictive)

        plot_collection = PlotCollection.wrap(
            ds_predictive,
            backend=backend,
            **pc_kwargs,
        )

    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()
    aes_by_visuals.setdefault("predictive_markers", plot_collection.aes_set)
    aes_by_visuals.setdefault("credible_interval", plot_collection.aes_set)

    ## predictive_markers
    predictive_ms_kwargs = copy(visuals.get("predictive_markers", {}))

    if predictive_ms_kwargs is not False:
        _, predictive_ms_aes, predictive_ms_ignore = filter_aes(
            plot_collection, aes_by_visuals, "predictive_markers", sample_dims
        )
        if "color" not in predictive_ms_aes:
            predictive_ms_kwargs.setdefault("color", colors[0])
        predictive_ms_kwargs.setdefault("marker", markers[4])

        plot_collection.map(
            scatter_xy,
            "predictive_markers",
            data=ds_predictive,
            ignore_aes=predictive_ms_ignore,
            **predictive_ms_kwargs,
        )

    ## credible interval
    ci_kwargs = copy(visuals.get("credible_interval", {}))
    _, ci_aes, ci_ignore = filter_aes(
        plot_collection, aes_by_visuals, "credible_interval", sample_dims
    )
    if ci_kwargs is not False:
        if "color" not in ci_aes:
            ci_kwargs.setdefault("color", colors[0])
        ci_kwargs.setdefault("alpha", 0.3)
        ci_kwargs.setdefault("width", 3)

        plot_collection.map(
            ci_line_y,
            "credible_interval",
            data=ds_predictive,
            ignore_aes=ci_ignore,
            **ci_kwargs,
        )

    ## observed_markers (kwargs were resolved and validated above)
    if observed_ms_kwargs is not False:
        _, _, observed_ms_ignore = filter_aes(
            plot_collection, aes_by_visuals, "observed_markers", sample_dims
        )
        observed_ms_kwargs.setdefault("color", contrast_color)
        observed_ms_kwargs.setdefault("marker", markers[6])

        plot_collection.map(
            scatter_xy,
            "observed_markers",
            data=observed_ds,
            ignore_aes=observed_ms_ignore,
            **observed_ms_kwargs,
        )

    ## grid
    grid_kwargs = copy(visuals.get("grid", {}))

    if grid_kwargs is not False:
        _, _, grid_ignore = filter_aes(plot_collection, aes_by_visuals, "grid", sample_dims)
        grid_kwargs.setdefault("color", "#cccccc")
        grid_kwargs.setdefault("axis", "y")

        plot_collection.map(
            grid,
            "grid",
            ignore_aes=grid_ignore,
            **grid_kwargs,
        )

    # set xlabel
    _, xlabels_aes, xlabels_ignore = filter_aes(
        plot_collection, aes_by_visuals, "xlabel", sample_dims
    )
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        if "color" not in xlabels_aes:
            xlabel_kwargs.setdefault("color", contrast_color)

        xlabel_kwargs.setdefault("text", "counts")

        plot_collection.map(
            labelled_x,
            "xlabel",
            ignore_aes=xlabels_ignore,
            subset_info=True,
            **xlabel_kwargs,
        )

    # set ylabel
    _, ylabels_aes, ylabels_ignore = filter_aes(
        plot_collection, aes_by_visuals, "ylabel", sample_dims
    )
    ylabel_kwargs = copy(visuals.get("ylabel", {}))
    if ylabel_kwargs is not False:
        if "color" not in ylabels_aes:
            ylabel_kwargs.setdefault("color", contrast_color)

        ylabel_kwargs.setdefault("text", "frequency")

        plot_collection.map(
            labelled_y,
            "ylabel",
            ignore_aes=ylabels_ignore,
            subset_info=True,
            **ylabel_kwargs,
        )

    # title
    title_kwargs = copy(visuals.get("title", {}))
    _, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims)

    if title_kwargs is not False:
        plot_collection.map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    # Apply the requested y-axis scale to every plot, independently of any aesthetics.
    plot_collection.map(
        set_y_scale,
        store_artist=backend == "none",
        ignore_aes=plot_collection.aes_set,
        scale=yscale,
    )

    return plot_collection