# Source code for arviz_plots.plots.bf_plot

"""Contain functions for Bayes Factor plotting."""

from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import xarray as xr
from arviz_base import rcParams
from arviz_stats.bayes_factor import bayes_factor

from arviz_plots.plots.prior_posterior_plot import plot_prior_posterior
from arviz_plots.plots.utils import add_lines, filter_aes, get_contrast_colors


def plot_bf(
    dt,
    var_names,
    ref_val=0,
    kind=None,
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "dist",
            "ref_line",
            "title",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "dist",
            "ref_line",
            "title",
            "legend",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[Literal["dist"], Mapping[str, Any] | xr.Dataset] = None,
    **pc_kwargs,
):
    r"""Bayes Factor for comparing hypotheses of two nested models.

    The Bayes factor is estimated by comparing a model (H1) against a model in which
    the parameter of interest has been restricted to be a point-null (H0). This
    computation assumes H0 is a special case of H1. For more details see
    https://arviz-devs.github.io/EABM/Chapters/Model_comparison.html#savagedickey-ratio

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to
        aesthetics.
    var_names : str or list of str
        Variables for which the bayes factor will be computed and the prior and
        posterior will be plotted.
    ref_val : int, float or False, default 0
        Reference (point-null) value for Bayes factor estimation. It is also drawn
        as a vertical reference line; use ``False`` to omit that line.
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
        Defaults to ``rcParams["plot.backend"]``, or to the backend of
        `plot_collection` when one is given.
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in
        `plot_collection` when plotted. Valid keys are "dist", "ref_line" and "title".
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * ref_line -> passed to :func:`~arviz_plots.visuals.vline`. Can't be False,
          use ``ref_val=False`` to remove the reference line instead.
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * legend -> passed to :class:`arviz_plots.PlotCollection.add_legend`.
          The legend is currently only added with the matplotlib backend.

    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf, ...

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If ``visuals["ref_line"]`` is False.

    Examples
    --------
    Select one variable.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_bf, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_bf(dt, var_names="mu", kind="hist")

    .. minigallery:: plot_bf
    """
    # dict(...) only needs the Mapping protocol (``Mapping`` has no ``.copy``) and
    # keeps the caller's objects untouched.
    visuals = {} if visuals is None else dict(visuals)
    aes_by_visuals = {} if aes_by_visuals is None else dict(aes_by_visuals)

    # Validate before doing any (potentially expensive) computation or plotting.
    ref_line_kwargs = copy(visuals.get("ref_line", {}))
    if ref_line_kwargs is False:
        raise ValueError(
            "visuals['ref_line'] can't be False, use ref_val=False to remove this element"
        )

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color = get_contrast_colors(bg_color=bg_color)

    # NOTE(review): when ref_val is False it is still forwarded here; since False == 0
    # in Python, the BF is presumably evaluated at 0 in that case — confirm.
    bf, _ = bayes_factor(dt, var_names, ref_val, return_ref_vals=True)

    if isinstance(var_names, str):
        var_names = [var_names]

    # One single-entry "BF_type" coordinate per variable whose label carries the BF01
    # value; it is registered as the "bf_aes" aesthetic and listed in the default
    # legend dims below.
    bf_aes_ds = xr.Dataset(
        {
            var: xr.DataArray(
                None,
                coords={"BF_type": [f"BF01:{bf[var]['BF01']:.2f}"]},
                dims=["BF_type"],
            )
            for var in var_names
        }
    )

    # Forward the user's aesthetic mappings for the visuals plot_prior_posterior draws
    # (these were previously ignored). Pass None when neither was given, so the call
    # is then the same as the one that did not pass this argument at all.
    prior_posterior_aes = {
        key: aes_by_visuals[key] for key in ("dist", "title") if key in aes_by_visuals
    }

    plot_collection = plot_prior_posterior(
        dt,
        var_names=var_names,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=prior_posterior_aes or None,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )

    plot_collection.update_aes_from_dataset("bf_aes", bf_aes_ds)

    if ref_val is not False:
        # NOTE(review): "sample" is passed as the sample-dims argument; only the aes
        # part of the result is used here, so this presumably does not matter — confirm.
        _, ref_aes, _ = filter_aes(plot_collection, aes_by_visuals, "ref_line", "sample")
        # Only apply default styling for properties not driven by an aesthetic mapping.
        if "color" not in ref_aes:
            ref_line_kwargs.setdefault("color", contrast_color)
        if "alpha" not in ref_aes:
            ref_line_kwargs.setdefault("alpha", 0.5)

        add_lines(
            plot_collection,
            ref_val,
            aes_by_visuals=aes_by_visuals,
            visuals={"ref_line": ref_line_kwargs},
        )

    # TODO: drop the backend check once there is a better way to handle legends.
    if backend == "matplotlib":
        legend_kwargs = copy(visuals.get("legend", {}))
        if legend_kwargs is not False:
            legend_kwargs.setdefault("dim", ["__variable__", "BF_type"])
            legend_kwargs.setdefault("loc", "upper left")
            legend_kwargs.setdefault("fontsize", 10)
            legend_kwargs.setdefault("text_only", True)
            plot_collection.add_legend(**legend_kwargs)

    return plot_collection