Source code for arviz_plots.plots.ppc_tstat

"""ppc t-stat plot code."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import (
    get_contrast_colors,
    process_group_variables_coords,
    set_wrap_layout,
)
from arviz_plots.visuals import scatter_x


def plot_ppc_tstat(
    dt,
    var_names=None,
    group="posterior_predictive",
    filter_vars=None,
    sample_dims=None,
    t_stat="median",
    kind=None,
    point_estimate=None,
    ci_kind=None,
    ci_prob=None,
    plot_collection=None,
    coords=None,
    backend=None,
    data_pairs=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "dist",
            "observed_tstat",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "dist",
            "observed_tstat",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
            "remove_axis",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[
        Literal["dist", "credible_interval", "point_estimate"], Mapping[str, Any] | xr.Dataset
    ] = None,
    **pc_kwargs,
):
    """
    Plot Bayesian t-stat for observed data and posterior/prior predictive.

    Parameters
    ----------
    dt : DataTree
        If group is "posterior_predictive", it should contain the ``posterior_predictive`` and
        ``observed_data`` groups. If group is "prior_predictive", it should contain the
        ``prior_predictive`` group.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
        Only used when `data_pairs` is None.
    group : str,
        Group to be plotted. Defaults to "posterior_predictive".
        It could also be "prior_predictive".
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    t_stat : str, float, or callable() default "median"
        Test statistics to compute from the observations and predictive distributions.
        Allowed strings are "mean", "median", "std", "var", "min", "max", "iqr"
        (interquartile range) and "mad" (median absolute deviation).
        Alternatively, a quantile can be passed as a float (or str) in the interval (0, 1).
        Finally, a user defined function is also accepted.
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    point_estimate : {"mean", "median", "mode"}, optional
        Which point estimate to plot. Defaults to rcParam :data:`stats.point_estimate`
    ci_kind : {"eti", "hdi"}, optional
        Which credible interval to use. Defaults to ``rcParams["stats.ci_kind"]``
    ci_prob : float, optional
        Indicates the probability that should be contained within the plotted credible interval.
        Defaults to ``rcParams["stats.ci_prob"]``
    plot_collection : PlotCollection, optional
    coords : dict, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    data_pairs : dict, optional
        Dictionary of keys prior/posterior predictive data and values observed data variable names.
        If None, it will assume that the observed data and the predictive data have
        the same variable name.
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals` except for "remove_axis"
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * observed_tstat -> passed to :func:`~arviz_plots.visuals.scatter_x`.
        * credible_interval -> passed to :func:`~arviz_plots.visuals.line_x`. Defaults to False.
        * point_estimate -> passed to :func:`~arviz_plots.visuals.scatter_x`. Defaults to False.
        * point_estimate_text -> passed to :func:`~arviz_plots.visuals.point_estimate_text`.
          Defaults to False.
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * rug -> passed to :func:`~arviz_plots.visuals.scatter_x`. Defaults to False.
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

        observed_tstat defaults to False, no observed data is plotted, if group is
        "prior_predictive". Pass an (empty) mapping to plot the observed tstats.

    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf, ...

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    TypeError
        If `group` is neither "posterior_predictive" nor "prior_predictive".
    ValueError
        If `t_stat` is not one of the supported names, a callable, or a value that
        converts to a float in the open interval (0, 1); or if the observed t-statistic
        is to be plotted but `dt` has no ``observed_data`` group.

    Examples
    --------
    Use 25th percentile (quantile 0.25) as t-statistic

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_ppc_tstat, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('radon')
        >>> plot_ppc_tstat(dt, t_stat="0.25")

    Define custom t-statistic function and plot histogram

    .. plot::
        :context: close-figs

        >>> import numpy as np
        >>> def cv(x):
        >>>     return np.std(x, axis=0) / np.mean(x, axis=0)
        >>> plot_ppc_tstat(dt, t_stat=cv, kind="hist")

    Use median as t-statistic and plot point-interval

    .. plot::
        :context: close-figs

        >>> plot_ppc_tstat(
        >>>     dt,
        >>>     visuals={
        >>>         "dist": False,
        >>>         "credible_interval": {},
        >>>         "point_estimate": {},
        >>>     }
        >>> )

    .. minigallery:: plot_ppc_tstat
    """
    if group not in ("posterior_predictive", "prior_predictive"):
        raise TypeError(
            "`group` argument must be either `posterior_predictive` or `prior_predictive`"
        )
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)

    # Shallow-copy the user supplied mappings: defaults are injected below with
    # ``setdefault`` and must not leak back into the caller's objects.
    if stats is None:
        stats = {}
    else:
        stats = stats.copy()
    if visuals is None:
        visuals = {}
    else:
        visuals = visuals.copy()

    if labeller is None:
        labeller = BaseLabeller()

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color = get_contrast_colors(bg_color=bg_color)

    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()

    # data_pairs becomes a (predictive variable names, observed variable names) pair
    if data_pairs is None:
        data_pairs = (var_names, var_names)
    else:
        data_pairs = (list(data_pairs.keys()), list(data_pairs.values()))

    predictive_dist = process_group_variables_coords(
        dt, group=group, var_names=data_pairs[0], filter_vars=filter_vars, coords=coords
    )

    # we use observed_tstat_kwargs as a flag to indicate if
    # we should compute and plot the observed t-statistics
    observed_tstat_kwargs = copy(
        visuals.get("observed_tstat", False if group == "prior_predictive" else {})
    )

    observed_dist = None
    if "observed_data" in dt:
        observed_dist = process_group_variables_coords(
            dt,
            group="observed_data",
            var_names=data_pairs[1],
            filter_vars=filter_vars,
            coords=coords,
        )
    elif observed_tstat_kwargs is not False:
        # Fail early with an actionable message instead of hitting an unbound
        # ``observed_dist`` further down.
        raise ValueError(
            "The observed t-statistic can not be plotted because `dt` has no `observed_data` "
            "group. Add the group or pass `visuals={'observed_tstat': False}`."
        )

    # Collapse all sample dimensions into a single "sample" dimension; every other
    # dimension is reduced when computing the statistic of each predictive draw.
    predictive_dist = predictive_dist.stack(sample=sample_dims)
    reduce_dim = [dim for dim in predictive_dist.dims if dim != "sample"]

    if t_stat in ["mean", "median", "std", "var", "min", "max"]:
        predictive_dist = getattr(predictive_dist, t_stat)(dim=reduce_dim)
        if observed_tstat_kwargs is not False:
            # no ``dim`` argument: reduce over every dimension of the observed data
            observed_dist = getattr(observed_dist, t_stat)()
        visuals.setdefault("title", {"text": t_stat})
    elif t_stat == "iqr":

        def iqr(data, dim):
            q25 = data.quantile(q=0.25, dim=dim)
            q75 = data.quantile(q=0.75, dim=dim)
            return q75 - q25

        predictive_dist = iqr(predictive_dist, dim=reduce_dim)
        if observed_tstat_kwargs is not False:
            observed_dist = iqr(observed_dist, dim=None)
        visuals.setdefault("title", {"text": "IQR"})
    elif t_stat == "mad":

        def mad(data, dim):
            median = data.median(dim=dim)
            return np.abs((data - median)).median(dim=dim)

        predictive_dist = mad(predictive_dist, dim=reduce_dim)
        if observed_tstat_kwargs is not False:
            observed_dist = mad(observed_dist, dim=None)
        visuals.setdefault("title", {"text": "MAD"})
    elif callable(t_stat):
        predictive_dist = predictive_dist.map(t_stat)
        if observed_tstat_kwargs is not False:
            observed_dist = observed_dist.map(t_stat)
        # Not every callable carries ``__name__`` (``functools.partial`` objects
        # do not), so fall back to the name of its type for the default title.
        t_stat_name = getattr(t_stat, "__name__", type(t_stat).__name__)
        visuals.setdefault("title", {"text": t_stat_name})
    else:
        try:
            t_stat_float = float(t_stat)
        except (TypeError, ValueError) as err:
            # ``float`` raises TypeError for unsupported types (None, list, ...)
            # and ValueError for strings that are not numbers.
            raise ValueError(f"T statistics '{t_stat}' not implemented") from err
        if 0 < t_stat_float < 1:
            predictive_dist = predictive_dist.quantile(q=t_stat_float, dim=reduce_dim).rename(
                {"quantile": "t_stat"}
            )
            if observed_tstat_kwargs is not False:
                observed_dist = observed_dist.quantile(q=t_stat_float).rename(
                    {"quantile": "t_stat"}
                )
            visuals.setdefault("title", {"text": f"q={t_stat}"})
        else:
            raise ValueError(f"T statistic '{t_stat}' not in valid range (0, 1).")

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs.setdefault("cols", "__variable__")
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, predictive_dist)

        plot_collection = PlotCollection.wrap(
            predictive_dist,
            backend=backend,
            **pc_kwargs,
        )

    # Plot predictive data
    visuals.setdefault("credible_interval", False)
    visuals.setdefault("point_estimate", False)
    visuals.setdefault("point_estimate_text", False)

    plot_dist(
        predictive_dist,
        var_names=None,
        group=None,
        coords=None,
        sample_dims=["sample"],
        kind=kind,
        point_estimate=point_estimate,
        ci_kind=ci_kind,
        ci_prob=ci_prob,
        plot_collection=plot_collection,
        aes_by_visuals=aes_by_visuals,
        backend=backend,
        labeller=labeller,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )

    # Plot the observed data
    if observed_tstat_kwargs is not False:
        observed_tstat_kwargs.setdefault("color", contrast_color)
        plot_collection.map(
            scatter_x, "observed_tstat", data=observed_dist.mean(), **observed_tstat_kwargs
        )

    return plot_collection