Source code for arviz_plots.plots.prior_posterior_plot

"""Contain functions for prior and posterior plotting."""

from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import extract, rcParams
from xarray import concat

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import process_group_variables_coords, set_wrap_layout


def plot_prior_posterior(
    dt,
    var_names=None,
    filter_vars=None,
    group=None,  # pylint: disable=unused-argument
    coords=None,
    sample_dims=None,
    kind=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "dist",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "dist",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
            "remove_axis",
            "legend",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[
        Literal["dist", "credible_interval", "point_estimate"], Mapping[str, Any] | xr.Dataset
    ] = None,
    **pc_kwargs,
):
    r"""Plot 1D marginal densities for prior and posterior.

    The ``prior`` and ``posterior`` groups of `dt` are subsampled to a common
    number of draws, concatenated along a new ``group`` dimension and plotted
    together through :func:`~arviz_plots.plot_dist`, so both distributions of
    each variable share the same plot.

    Parameters
    ----------
    dt : DataTree
        Input data. It must contain both a ``prior`` and a ``posterior`` group.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : None
        This argument is ignored. Have it here for compatibility with other
        plotting functions.
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. The prior and posterior groups are combined creating a new dimension
        "group". By default, there is an aesthetic mapping from group to color.
        Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * legend -> passed to :meth:`arviz_plots.PlotCollection.add_legend`

        The ``credible_interval``, ``point_estimate`` and ``point_estimate_text``
        visuals default to ``False`` in this function and ``remove_axis``
        defaults to ``True``.
    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf, ...

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Examples
    --------
    Select two variables and plot them with a ecdf.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_prior_posterior, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_prior_posterior(dt, var_names=["mu", "tau"], kind="ecdf")

    .. minigallery:: plot_prior_posterior

    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)
    if kind is None:
        kind = rcParams["plot.density_kind"]

    # Work on shallow copies so the defaults set below never leak into the
    # caller's objects. ``dict(...)`` accepts any Mapping, not only ``dict``.
    stats = {} if stats is None else dict(stats)
    aes_by_visuals = {} if aes_by_visuals is None else dict(aes_by_visuals)
    if visuals is None:
        visuals = {}
    elif isinstance(visuals, Mapping):
        visuals = dict(visuals)
    else:
        # Non-mapping input has always been replaced by an empty dict here;
        # keep that fallback instead of raising.
        visuals = {}

    if backend is None:
        backend = rcParams["plot.backend"] if plot_collection is None else plot_collection.backend
    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    # Both groups are subsampled to the size of the smaller one so they can be
    # concatenated along a new "group" dimension.
    prior_size = np.prod([dt.prior.sizes[dim] for dim in sample_dims])
    posterior_size = np.prod([dt.posterior.sizes[dim] for dim in sample_dims])
    num_samples = min(prior_size, posterior_size)

    group_names = ["prior", "posterior"]
    # The sample coordinates returned by `extract` are dropped and replaced by
    # a plain 0..num_samples-1 range so both groups carry identical `sample`
    # coordinate values when concatenated.
    group_datasets = [
        extract(dt, group=name, num_samples=num_samples, random_seed=0, keep_dataset=True)
        .drop_vars(sample_dims + ["sample"])
        .assign_coords(sample=("sample", np.arange(num_samples)))
        for name in group_names
    ]
    distribution = concat(group_datasets, dim="group").assign_coords({"group": group_names})

    distribution = process_group_variables_coords(
        distribution,
        group=None,
        var_names=var_names,
        filter_vars=filter_vars,
        coords=coords,
    )

    if len(sample_dims) > 1:
        # sample dims will have been stacked and renamed by `extract`
        sample_dims = ["sample"]

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        # Distinguish prior from posterior by color unless the caller chose otherwise.
        pc_kwargs["aes"].setdefault("color", ["group"])
        pc_kwargs.setdefault("col_wrap", 4)
        pc_kwargs.setdefault(
            "cols",
            ["__variable__"]
            + [dim for dim in distribution.dims if dim not in sample_dims + ["group"]],
        )
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    # Interval and point-estimate visuals are off by default in this plot.
    visuals.setdefault("credible_interval", False)
    visuals.setdefault("point_estimate", False)
    visuals.setdefault("point_estimate_text", False)

    if kind == "hist":
        visuals.setdefault("dist", {})

    visuals.setdefault("remove_axis", True)

    plot_collection = plot_dist(
        distribution,
        var_names=None,
        group=None,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        point_estimate=None,
        ci_kind=None,
        ci_prob=None,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=aes_by_visuals,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )

    # ``visuals["legend"]`` may be ``False`` to skip the legend entirely.
    legend_kwargs = copy(visuals.get("legend", {}))
    if legend_kwargs is not False:
        legend_kwargs.setdefault("dim", ["group"])
        plot_collection.add_legend(**legend_kwargs)

    return plot_collection