Source code for arviz_plots.plots.energy_plot

"""Energy plot code."""
from collections.abc import Mapping, Sequence
from copy import copy
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import convert_to_dataset, rcParams

from arviz_plots.plots.dist_plot import plot_dist


def plot_energy(
    dt,
    bfmi=False,
    kind=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "dist",
            "title",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "dist",
            "title",
            "legend",
            "remove_axis",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[Literal["dist"], Mapping[str, Any] | xr.Dataset] = None,
    **pc_kwargs,
):
    r"""Plot transition distribution and marginal energy distribution in HMC algorithms.

    This may help to diagnose poor exploration by gradient-based algorithms like HMC or NUTS.
    The energy function in HMC can identify posteriors with heavy tailed distributions, that
    in practice are challenging for sampling.

    This plot is in the style of the one used in [1]_.

    Parameters
    ----------
    dt : DataTree
        ``sample_stats`` group with an ``energy`` variable is mandatory.
    bfmi : bool
        Whether to plot the value of the estimated Bayesian fraction of missing
        information. Defaults to False. Not implemented yet.
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * legend -> passed to :class:`arviz_plots.PlotCollection.add_legend`
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf, ...

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    NotImplementedError
        If `bfmi` is truthy. The check runs before any plotting happens so that a
        `plot_collection` passed by the caller is left untouched.
    ValueError
        If the energy dataset lacks the ``chain`` or ``draw`` dimension.

    Examples
    --------
    Plot a default energy plot

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_energy, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> schools = load_arviz_data('centered_eight')
        >>> plot_energy(schools)

    .. minigallery:: plot_energy

    References
    ----------
    .. [1] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in
       Hamiltonian Monte Carlo. (2016) https://arxiv.org/abs/1604.00695
    """
    # Fail fast: reject the unsupported option before building any figure or
    # drawing on a plot_collection supplied by the caller.
    if bfmi:
        raise NotImplementedError("BFMI is not implemented yet")

    if kind is None:
        kind = rcParams["plot.density_kind"]

    # Work on a copy so the defaults set below never leak into the caller's mapping.
    visuals = {} if visuals is None else visuals.copy()

    new_ds = _get_energy_ds(dt)

    sample_dims = ["chain", "draw"]
    if not all(dim in new_ds.dims for dim in sample_dims):
        raise ValueError("Both 'chain' and 'draw' dimensions must be present in the dataset")

    pc_kwargs.setdefault("cols", None)
    # Copy the aesthetics mapping for the same reason as `visuals`, then colour
    # by the "energy" dimension (marginal vs transition) unless told otherwise.
    pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
    pc_kwargs["aes"].setdefault("color", ["energy"])

    # These visuals default to False (off) here; explicit user values are kept.
    visuals.setdefault("credible_interval", False)
    visuals.setdefault("point_estimate", False)
    visuals.setdefault("point_estimate_text", False)
    visuals.setdefault("title", False)

    plot_collection = plot_dist(
        new_ds,
        var_names=None,
        filter_vars=None,
        group=None,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        point_estimate=None,
        ci_kind=None,
        ci_prob=None,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=aes_by_visuals,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )

    # legend: ``visuals["legend"] = False`` disables it, a mapping customizes it
    legend_kwargs = copy(visuals.get("legend", {}))
    if legend_kwargs is not False:
        legend_kwargs.setdefault("dim", ["energy"])
        plot_collection.add_legend(**legend_kwargs)

    return plot_collection


def _get_energy_ds(dt):
    """Assemble the dataset with marginal and transition energies.

    Parameters
    ----------
    dt : DataTree
        Must expose ``dt["sample_stats"].energy``.
        Presumably the energy array is laid out as (chain, draw) -- confirm with caller.

    Returns
    -------
    Dataset
        Output of ``convert_to_dataset`` holding a single variable ``energy_`` whose
        trailing dimension is renamed to ``energy`` and labelled with the coordinate
        values ``"marginal"`` and ``"transition"``.
    """
    raw_energy = dt["sample_stats"].energy.values

    # Marginal energy: centered with the mean taken over every value in the array.
    marginal = raw_energy - raw_energy.mean()
    # Transition energy: first differences along the last axis; a trailing NaN is
    # appended so the result keeps the same shape as ``marginal``.
    transition = np.diff(raw_energy, append=np.nan)

    stacked = np.dstack([marginal, transition])
    dataset = convert_to_dataset(
        {"energy_": stacked},
        coords={"energy__dim_0": ["marginal", "transition"]},
    )
    return dataset.rename({"energy__dim_0": "energy"})