Source code for arviz_plots.plots.ppc_dist_plot

"""Posterior/prior predictive check using densities."""

import warnings
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import (
    filter_aes,
    get_contrast_colors,
    process_group_variables_coords,
    set_wrap_layout,
)
from arviz_plots.visuals import ecdf_line, hist, line_xy


def plot_ppc_dist(
    dt,
    data_pairs=None,
    var_names=None,
    filter_vars=None,
    group="posterior_predictive",
    coords=None,
    sample_dims=None,
    kind=None,
    num_samples=50,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal["predictive_dist", "observed_dist", "title"], Sequence[str]
    ] = None,
    visuals: Mapping[
        Literal["predictive_dist", "observed_dist", "title", "remove_axis"],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[
        Literal["predictive_dist", "observed_dist"], Mapping[str, Any] | xr.Dataset
    ] = None,
    **pc_kwargs,
):
    """
    Plot 1D marginals for the posterior/prior predictive distribution and the observed data.

    Parameters
    ----------
    dt : DataTree
        If group is "posterior_predictive", it should contain the ``posterior_predictive`` and
        ``observed_data`` groups. If group is "prior_predictive", it should contain the
        ``prior_predictive`` group.
    data_pairs : dict, optional
        Dictionary of keys prior/posterior predictive data and values observed data variable names.
        If None, it will assume that the observed data and the predictive data have
        the same variable name.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
        Only used when `data_pairs` is None.
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str,
        Group to be plotted. Defaults to "posterior_predictive".
        It could also be "prior_predictive".
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    num_samples : int, optional
        Number of samples to plot. Defaults to 50. If it is larger than the number of
        available predictive samples, all of them are used and a warning is emitted.
        The subset is drawn without replacement using a fixed seed.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals` except "remove_axis".

        When no `plot_collection` is provided, the one created here maps the
        "overlay_ppc" aesthetic to the "sample" dimension unless the user
        provides their own mapping for it through ``aes`` in `pc_kwargs`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * predictive_dist, observed_dist -> passed to a function that depends on
          the `kind` argument.

          - `kind="kde"` -> passed to :func:`~arviz_plots.visuals.line_xy`
          - `kind="ecdf"` -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          - `kind="hist"` -> passed to :func:`~arviz_plots.visuals.hist`

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

        observed_dist defaults to False, no observed data is plotted,
        if group is "prior_predictive". Pass an (empty) mapping to plot the observed data.
    stats : mapping, optional
        Valid keys are:

        * predictive_dist, observed_dist -> passed to kde, ecdf, ...

    **pc_kwargs
        Passed to :meth:`~arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If the observed distribution is requested but `dt` has no
        ``observed_data`` group.

    Warns
    -----
    UserWarning
        If any selected variable has an integer dtype, or if `num_samples`
        exceeds the number of available predictive samples.

    See Also
    --------
    :ref:`plots_intro` :
        General introduction to batteries-included plotting functions, common use and logic overview

    Examples
    --------
    Make a plot of the posterior predictive distribution vs the observed data.
    We use an ECDF representation and customize the colors.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_ppc_dist, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> radon = load_arviz_data('radon')
        >>> pc = plot_ppc_dist(
        >>>     radon,
        >>>     kind="ecdf",
        >>>     visuals={
        >>>         "predictive_dist": {"color":"C1"},
        >>>         "observed_dist": {"color":"C3"}
        >>>     },
        >>> )

    .. minigallery:: plot_ppc_dist
    """
    # ---- Resolve defaults -------------------------------------------------
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)
    if kind is None:
        kind = rcParams["plot.density_kind"]

    # Shallow copies so that the defaults set below do not leak into the
    # caller's top-level mappings.
    stats = {} if stats is None else stats.copy()
    visuals = {} if visuals is None else visuals.copy()

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color = get_contrast_colors(bg_color=bg_color)

    # Fixed seed: the random subset of predictive samples is reproducible.
    rng = np.random.default_rng(4214)

    # Dimensions to reduce when computing densities: the sample dims first,
    # followed by every other dimension of the predictive group.
    pp_dims = list(sample_dims) + [dims for dims in dt[group].dims if dims not in sample_dims]

    # ---- Select predictive / observed variables ---------------------------
    if data_pairs is None:
        data_pairs = (var_names, var_names)
    else:
        data_pairs = (list(data_pairs.keys()), list(data_pairs.values()))

    predictive_dist = process_group_variables_coords(
        dt, group=group, var_names=data_pairs[0], filter_vars=filter_vars, coords=coords
    )
    predictive_types = [
        predictive_dist[var].values.dtype.kind == "i" for var in predictive_dist.data_vars
    ]

    if "observed_data" in dt:
        observed_dist = process_group_variables_coords(
            dt,
            group="observed_data",
            var_names=data_pairs[1],
            filter_vars=filter_vars,
            coords=coords,
        )
        observed_types = [
            observed_dist[var].values.dtype.kind == "i" for var in observed_dist.data_vars
        ]
    else:
        # Keep the name bound so a missing group can be reported clearly later.
        observed_dist = None
        observed_types = []

    if any(predictive_types + observed_types):
        warnings.warn(
            "Detected at least one discrete variable.\n"
            "Consider using plot_ppc variants specific for discrete data, "
            "such as plot_ppc_pava or plot_ppc_rootogram.",
            UserWarning,
            stacklevel=2,
        )

    # ---- Select a random subset of samples --------------------------------
    # int(...) because np.prod of an empty list is the float 1.0.
    n_pp_samples = int(
        np.prod([predictive_dist.sizes[dim] for dim in sample_dims if dim in predictive_dist.dims])
    )
    if num_samples > n_pp_samples:
        num_samples = n_pp_samples
        warnings.warn(
            "num_samples is larger than the number of predictive samples.",
            UserWarning,
            stacklevel=2,
        )

    pp_sample_ix = rng.choice(n_pp_samples, size=num_samples, replace=False)
    predictive_dist = predictive_dist.stack(sample=sample_dims).isel(sample=pp_sample_ix)

    # ---- Build the PlotCollection if none was given -----------------------
    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        # Each selected predictive sample is overlaid on the same plot.
        pc_kwargs["aes"].setdefault("overlay_ppc", ["sample"])
        pc_kwargs.setdefault("cols", "__variable__")

        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, predictive_dist)

        plot_collection = PlotCollection.wrap(
            predictive_dist,
            backend=backend,
            **pc_kwargs,
        )

    aes_by_visuals = {} if aes_by_visuals is None else aes_by_visuals.copy()

    # NOTE(review): `labeller` is not referenced again in this function.
    if labeller is None:
        labeller = BaseLabeller()

    # We don't want credible_interval or point_estimate to be mapped to the density representation
    visuals.setdefault("credible_interval", False)
    visuals.setdefault("point_estimate", False)
    visuals.setdefault("point_estimate_text", False)
    visuals.setdefault("rug_plot", False)

    # ---- Plot the predictive density --------------------------------------
    pred_density_kwargs = copy(visuals.get("predictive_dist", {}))

    if pred_density_kwargs is not False:
        # plot_dist draws the density under the "dist" visual; it is renamed
        # to "predictive_dist" afterwards.
        dist_kwargs = visuals.setdefault("dist", pred_density_kwargs)
        # Guard both defaults: a user-provided ``"dist": False`` has no
        # ``setdefault`` method.
        if dist_kwargs is not False:
            dist_kwargs.setdefault("alpha", 0.3)
            if kind == "hist":
                dist_kwargs.setdefault("color", None)

        plot_collection = plot_dist(
            predictive_dist,
            group=group,
            sample_dims=pp_dims,
            kind=kind,
            visuals=visuals,
            aes_by_visuals=aes_by_visuals,
            pc_kwargs=pc_kwargs,
            plot_collection=plot_collection,
            stats={"dist": stats.get("predictive_dist", {})},
        )
        plot_collection.rename_visuals(dist="predictive_dist")

    # ---- Plot the observed density ----------------------------------------
    # Observed data is skipped by default for prior predictive checks.
    observed_density_kwargs = copy(
        visuals.get("observed_dist", False if group == "prior_predictive" else {})
    )

    if observed_density_kwargs is not False:
        if observed_dist is None:
            raise ValueError(
                "Plotting the observed distribution requires an 'observed_data' group in `dt`. "
                "Set visuals={'observed_dist': False} to skip plotting the observed data."
            )

        observed_stats_kwargs = stats.get("observed_dist", {}).copy()
        observed_density_kwargs.setdefault("color", contrast_color)
        if kind == "hist":
            observed_density_kwargs.setdefault("alpha", 0.3)
            observed_density_kwargs.setdefault("edgecolor", None)

        _, _, observed_ignore = filter_aes(
            plot_collection, aes_by_visuals, "observed_dist", sample_dims
        )

        if kind == "kde":
            dt_observed = observed_dist.azstats.kde(dim=pp_dims, **observed_stats_kwargs)
            plot_collection.map(
                line_xy,
                "observed_dist",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )
        elif kind == "hist":
            observed_stats_kwargs.setdefault("density", True)
            dt_observed = observed_dist.azstats.histogram(dim=pp_dims, **observed_stats_kwargs)
            plot_collection.map(
                hist,
                "observed_dist",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )
        elif kind == "ecdf":
            # NOTE(review): unlike kde/hist, no ``dim=pp_dims`` is passed here;
            # confirm the default reduction dims of ``azstats.ecdf`` are intended.
            dt_observed = observed_dist.azstats.ecdf(**observed_stats_kwargs)
            plot_collection.map(
                ecdf_line,
                "observed_dist",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )

    return plot_collection