Source code for arviz_plots.plots.evolution_plot

"""evolution ess plot."""

from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import arviz_stats  # pylint: disable=unused-import
import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
    filter_aes,
    get_contrast_colors,
    process_group_variables_coords,
    set_wrap_layout,
)
from arviz_plots.visuals import (
    annotate_xy,
    labelled_title,
    labelled_x,
    labelled_y,
    line_xy,
    scatter_xy,
)


def plot_ess_evolution(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    relative=False,
    n_points=20,
    extra_methods=False,
    min_ess=400,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "ess_bulk",
            "ess_bulk_line",
            "ess_tail",
            "ess_tail_line",
            "title",
            "xlabel",
            "ylabel",
            "mean",
            "mean_text",
            "sd",
            "sd_text",
            "min_ess",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "ess_bulk",
            "ess_bulk_line",
            "ess_tail",
            "ess_tail_line",
            "title",
            "xlabel",
            "ylabel",
            "mean",
            "mean_text",
            "sd",
            "sd_text",
            "min_ess",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[
        Literal["ess_bulk", "ess_tail", "mean", "sd"], Mapping[str, Any] | xr.Dataset
    ] = None,
    **pc_kwargs,
):
    """Plot estimated effective sample size plots for increasing number of iterations.

    Roughly speaking, the effective sample size of a quantity of interest captures how
    many independent draws contain the same amount of information as the dependent sample
    obtained by the MCMC algorithm. The higher the ESS the better.
    See [1]_ for more details.

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to aesthetics.
    var_names : str or sequence of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default None
        If None, interpret `var_names` as the real variables names.
        If "like", interpret `var_names` as substrings of the real variables names.
        If "regex", interpret `var_names` as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to be plotted.
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``.
        The last element is the dimension along which the data are truncated
        to build the evolution subsets.
    relative : bool, default False
        Show relative ess in plot ``ress = ess / N``.
    n_points : int, default 20
        Number of subsets in the evolution plot.
    extra_methods : bool, default False
        Plot mean and sd ESS as horizontal lines.
    min_ess : int, default 400
        Minimum number of ESS desired. If ``relative=True`` the threshold is plotted
        as a curve following the ``min_ess / n`` dependency, where ``n`` is the
        total number of draws in each subset.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str or False}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * ess_bulk -> passed to :func:`~arviz_plots.visuals.scatter_xy`
        * ess_bulk_line -> passed to :func:`~arviz_plots.visuals.line_xy`
        * ess_tail -> passed to :func:`~arviz_plots.visuals.scatter_xy`
        * ess_tail_line -> passed to :func:`~arviz_plots.visuals.line_xy`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
        * mean -> passed to :func:`~arviz_plots.visuals.line_xy`
        * sd -> passed to :func:`~arviz_plots.visuals.line_xy`
        * mean_text -> passed to :func:`~arviz_plots.visuals.annotate_xy`
        * sd_text -> passed to :func:`~arviz_plots.visuals.annotate_xy`
        * min_ess -> passed to :func:`~arviz_plots.visuals.line_xy`

    stats : mapping, optional
        Valid keys are:

        * ess_bulk -> passed to ess, method = 'bulk'
        * ess_tail -> passed to ess, method = 'tail'
        * mean -> passed to ess, method='mean'
        * sd -> passed to ess, method='sd'

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    See Also
    --------
    :ref:`plots_intro` :
        General introduction to batteries-included plotting functions, common use and logic overview

    Examples
    --------
    When adding a mapping for color across variables, the same color for a variable gets
    applied to both the 'bulk' and 'tail' ess. In such a case, if separate linestyles for
    'bulk' and 'tail' are desired to distinguish them instead of colors (which is what is
    used by default), then this can be implemented:

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_ess_evolution, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> non_centered = load_arviz_data('non_centered_eight')
        >>> pc = plot_ess_evolution(
        >>>     non_centered,
        >>>     var_names=["mu", "tau"],
        >>>     extra_methods=True,
        >>>     visuals={
        >>>         "ess_bulk_line": {"linestyle": "-."},
        >>>         "ess_tail_line": {"linestyle": ":"},
        >>>         "ess_bulk": False,
        >>>         "ess_tail": False,
        >>>     },
        >>>     aes= {"color": ["__variable__"]},
        >>>     aes_by_visuals={"title": ["color"]},
        >>> )

    The points and lines for ess 'bulk' and 'tail' can be individually switched on and off.
    If only the points are desired, and a situation like the previous example occurs,
    markers can be used to distinguish between points for 'bulk' and 'tail':

    .. plot::
        :context: close-figs

        >>> pc = plot_ess_evolution(
        >>>     non_centered,
        >>>     var_names=["mu", "tau"],
        >>>     extra_methods=True,
        >>>     visuals={
        >>>         "ess_bulk": {"marker": "x"},
        >>>         "ess_tail": {"marker": "_"},
        >>>     },
        >>>     aes={"color": ["__variable__"]},
        >>>     aes_by_visuals={"title": ["color"]},
        >>> )

    We can add extra methods to plot the mean and standard deviation as lines, and adjust
    the minimum ess baseline as well:

    .. plot::
        :context: close-figs

        >>> pc = plot_ess_evolution(
        >>>     non_centered,
        >>>     coords={"school": ["Choate", "Deerfield", "Hotchkiss"]},
        >>>     extra_methods=True,
        >>>     min_ess=200,
        >>> )

    Relative ESS can be plotted instead of absolute:

    .. plot::
        :context: close-figs

        >>> pc = plot_ess_evolution(
        >>>     non_centered,
        >>>     coords={"school": ["Choate", "Deerfield", "Hotchkiss"]},
        >>>     relative=True,
        >>> )

    We can also adjust the number of points:

    .. plot::
        :context: close-figs

        >>> pc = plot_ess_evolution(
        >>>     non_centered,
        >>>     coords={"school": ["Choate", "Deerfield", "Hotchkiss"]},
        >>>     n_points=10,
        >>> )

    .. minigallery:: plot_ess_evolution

    References
    ----------
    .. [1] Vehtari et al. *Rank-normalization, folding, and localization: An improved Rhat for
        assessing convergence of MCMC*. Bayesian Analysis. 16(2) (2021)
        https://doi.org/10.1214/20-BA1221. arXiv preprint https://arxiv.org/abs/1903.08008
    """
    # initial defaults
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]

    # mutable inputs
    if visuals is None:
        visuals = {}
    if stats is None:
        stats = {}

    # processing dt/group/coords/filtering
    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color, contrast_gray_color = get_contrast_colors(bg_color=bg_color, gray_flag=True)

    # set plot collection initialization defaults if it doesn't exist
    if plot_collection is None:
        # copy nested containers so the caller's dicts are never mutated
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs.setdefault(
            "cols",
            ["__variable__"]
            + [dim for dim in distribution.dims if dim not in {"model"}.union(sample_dims)],
        )
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        pc_kwargs["figure_kwargs"].setdefault("sharex", True)

        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    # set plot collection dependent defaults (like aesthetics mappings for each visual)
    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()
    aes_by_visuals.setdefault("ess_bulk", plot_collection.aes_set)
    aes_by_visuals.setdefault("ess_bulk_line", plot_collection.aes_set)
    aes_by_visuals.setdefault("ess_tail", plot_collection.aes_set)
    aes_by_visuals.setdefault("ess_tail_line", plot_collection.aes_set)
    # text annotations follow the aesthetics of the line they label unless overridden
    if "mean" in aes_by_visuals and "mean_text" not in aes_by_visuals:
        aes_by_visuals["mean_text"] = aes_by_visuals["mean"]
    if "sd" in aes_by_visuals and "sd_text" not in aes_by_visuals:
        aes_by_visuals["sd_text"] = aes_by_visuals["sd"]

    if labeller is None:
        labeller = BaseLabeller()

    # compute and add ess evolution subplots for 'bulk' and 'tail'

    # The last sample dim is the one the subsets are sliced along, so the draw
    # count must be read from that very same dimension.
    draw_dim = sample_dims[-1]
    n_samples = 1
    for dim in sample_dims:
        # check against dims (not variables/coords) so that dimensions
        # without an index coordinate are still counted
        if dim in distribution.dims:
            n_samples *= distribution.sizes[dim]
    n_draws = distribution.sizes[draw_dim]

    # setting xdata and draw_divisions for later ess computing and plotting
    xdata = np.linspace(n_samples / n_points, n_samples, n_points)
    draw_divisions = np.linspace(n_draws // n_points, n_draws, n_points, dtype=int)

    default_bulk_color, default_tail_color = plot_bknd.get_default_aes("color", 2, {})

    def compute_ess_dataset(method, method_dims):
        """Return a dataset with x (number of draws) and y (ess) along ``plot_axis``.

        ``method`` is "bulk" or "tail"; ``method_dims`` are the dims to reduce.
        """
        ess_y_dataset = xr.concat(
            [
                distribution.isel({draw_dim: slice(None, draw_div)}).azstats.ess(
                    sample_dims=method_dims,
                    method=method,
                    relative=relative,
                    **stats.get(f"ess_{method}", {}),
                )
                for draw_div in draw_divisions
            ],
            dim="ess_dim",
        )

        # converting xdata into a xr dataarray
        xdata_da = xr.DataArray(xdata, dims="ess_dim")

        # broadcasting xdata_da to match shape of each variable in ess_y_dataset and
        # creating a new dataset from dict of broadcasted xdata
        xdata_dataset = xr.Dataset(
            {var_name: xdata_da.broadcast_like(da) for var_name, da in ess_y_dataset.items()}
        )

        # concatenating xdata_dataset and ess_y_dataset along plot_axis
        return xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis").assign_coords(
            plot_axis=["x", "y"]
        )

    ess_bulk_dataset = None
    bulk_kwargs = copy(visuals.get("ess_bulk", {}))
    if bulk_kwargs is not False:
        bulk_dims, bulk_aes, bulk_ignore = filter_aes(
            plot_collection, aes_by_visuals, "ess_bulk", sample_dims
        )
        ess_bulk_dataset = compute_ess_dataset("bulk", bulk_dims)

        if "color" not in bulk_aes:
            bulk_kwargs.setdefault("color", default_bulk_color)

        plot_collection.map(
            scatter_xy, "ess_bulk", data=ess_bulk_dataset, ignore_aes=bulk_ignore, **bulk_kwargs
        )

    bulk_line_kwargs = copy(visuals.get("ess_bulk_line", {}))
    if bulk_line_kwargs is not False:
        bulk_line_dims, bulk_line_aes, bulk_line_ignore = filter_aes(
            plot_collection, aes_by_visuals, "ess_bulk_line", sample_dims
        )
        # reuse the dataset computed for the scatter visual when available
        if ess_bulk_dataset is None:
            ess_bulk_dataset = compute_ess_dataset("bulk", bulk_line_dims)

        if "color" not in bulk_line_aes:
            bulk_line_kwargs.setdefault("color", default_bulk_color)

        plot_collection.map(
            line_xy,
            "ess_bulk_line",
            data=ess_bulk_dataset,
            ignore_aes=bulk_line_ignore,
            **bulk_line_kwargs,
        )

    ess_tail_dataset = None
    tail_kwargs = copy(visuals.get("ess_tail", {}))
    if tail_kwargs is not False:
        tail_dims, tail_aes, tail_ignore = filter_aes(
            plot_collection, aes_by_visuals, "ess_tail", sample_dims
        )
        ess_tail_dataset = compute_ess_dataset("tail", tail_dims)

        if "color" not in tail_aes:
            tail_kwargs.setdefault("color", default_tail_color)

        plot_collection.map(
            scatter_xy, "ess_tail", data=ess_tail_dataset, ignore_aes=tail_ignore, **tail_kwargs
        )

    tail_line_kwargs = copy(visuals.get("ess_tail_line", {}))
    if tail_line_kwargs is not False:
        tail_line_dims, tail_line_aes, tail_line_ignore = filter_aes(
            plot_collection, aes_by_visuals, "ess_tail_line", sample_dims
        )
        # reuse the dataset computed for the scatter visual when available
        if ess_tail_dataset is None:
            ess_tail_dataset = compute_ess_dataset("tail", tail_line_dims)

        if "color" not in tail_line_aes:
            tail_line_kwargs.setdefault("color", default_tail_color)

        plot_collection.map(
            line_xy,
            "ess_tail_line",
            data=ess_tail_dataset,
            ignore_aes=tail_line_ignore,
            **tail_line_kwargs,
        )

    # getting backend specific linestyles
    linestyles = plot_bknd.get_default_aes("linestyle", 4, {})
    # and default color
    default_color = plot_bknd.get_default_aes("color", 1, {})[0]

    # plot mean and sd and annotate them
    if extra_methods is not False:
        # computing mean_ess
        mean_dims, mean_aes, mean_ignore = filter_aes(
            plot_collection, aes_by_visuals, "mean", sample_dims
        )
        mean_ess = distribution.azstats.ess(
            sample_dims=mean_dims, method="mean", relative=relative, **stats.get("mean", {})
        )

        # computing sd_ess
        sd_dims, sd_aes, sd_ignore = filter_aes(plot_collection, aes_by_visuals, "sd", sample_dims)
        sd_ess = distribution.azstats.ess(
            sample_dims=sd_dims, method="sd", relative=relative, **stats.get("sd", {})
        )

        mean_kwargs = copy(visuals.get("mean", {}))
        if mean_kwargs is not False:
            # getting 2nd default linestyle for chosen backend and assigning it by default
            mean_kwargs.setdefault("linestyle", linestyles[1])

            if "color" not in mean_aes:
                mean_kwargs.setdefault("color", default_color)

            plot_collection.map(
                line_xy,
                "mean",
                data=mean_ess,
                x=xdata,
                ignore_aes=mean_ignore,
                **mean_kwargs,
            )

        sd_kwargs = copy(visuals.get("sd", {}))
        if sd_kwargs is not False:
            sd_kwargs.setdefault("linestyle", linestyles[2])

            if "color" not in sd_aes:
                sd_kwargs.setdefault("color", default_color)

            plot_collection.map(
                line_xy, "sd", data=sd_ess, ignore_aes=sd_ignore, x=xdata, **sd_kwargs
            )

        # place each label on the outer side of its line so the two never overlap:
        # the lower line gets its text below ("top" aligned), the upper one above
        sd_va_align = xr.where(mean_ess < sd_ess, "bottom", "top")
        mean_va_align = xr.where(mean_ess < sd_ess, "top", "bottom")

        mean_text_kwargs = copy(visuals.get("mean_text", {}))
        if mean_text_kwargs is not False:
            _, mean_text_aes, mean_text_ignore = filter_aes(
                plot_collection, aes_by_visuals, "mean_text", sample_dims
            )

            if "color" not in mean_text_aes:
                mean_text_kwargs.setdefault("color", contrast_color)

            mean_text_kwargs.setdefault("x", max(xdata))
            mean_text_kwargs.setdefault("horizontal_align", "right")
            mean_text_kwargs.setdefault("vertical_align", mean_va_align)

            plot_collection.map(
                annotate_xy,
                "mean_text",
                text="mean",
                data=mean_ess,
                ignore_aes=mean_text_ignore,
                **mean_text_kwargs,
            )

        sd_text_kwargs = copy(visuals.get("sd_text", {}))
        if sd_text_kwargs is not False:
            _, sd_text_aes, sd_text_ignore = filter_aes(
                plot_collection, aes_by_visuals, "sd_text", sample_dims
            )

            if "color" not in sd_text_aes:
                sd_text_kwargs.setdefault("color", contrast_color)

            sd_text_kwargs.setdefault("x", max(xdata))
            sd_text_kwargs.setdefault("horizontal_align", "right")
            sd_text_kwargs.setdefault("vertical_align", sd_va_align)

            plot_collection.map(
                annotate_xy,
                "sd_text",
                text="sd",
                data=sd_ess,
                ignore_aes=sd_text_ignore,
                **sd_text_kwargs,
            )

    # plot minimum ess
    min_ess_kwargs = copy(visuals.get("min_ess", {}))
    if min_ess_kwargs is not False:
        _, min_ess_aes, min_ess_ignore = filter_aes(
            plot_collection, aes_by_visuals, "min_ess", sample_dims
        )

        # Relative ess is ess / n, with n the number of draws of each subset, so the
        # threshold becomes the curve min_ess / n evaluated at the plotted x values
        # (dividing by n_points, the number of subsets, would be meaningless).
        min_ess_y = min_ess / xdata if relative else min_ess

        min_ess_kwargs.setdefault("linestyle", linestyles[3])

        if "color" not in min_ess_aes:
            min_ess_kwargs.setdefault("color", contrast_gray_color)

        plot_collection.map(
            line_xy,
            "min_ess",
            data=distribution,
            ignore_aes=min_ess_ignore,
            x=xdata,
            y=min_ess_y,
            **min_ess_kwargs,
        )

    # plot titles for each faceted subplot
    title_kwargs = copy(visuals.get("title", {}))
    if title_kwargs is not False:
        _, title_aes, title_ignore = filter_aes(
            plot_collection, aes_by_visuals, "title", sample_dims
        )

        if "color" not in title_aes:
            title_kwargs.setdefault("color", contrast_color)

        plot_collection.map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    # plot x and y axis labels
    _, xlabels_aes, xlabels_ignore = filter_aes(
        plot_collection, aes_by_visuals, "xlabel", sample_dims
    )
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        if "color" not in xlabels_aes:
            xlabel_kwargs.setdefault("color", contrast_color)

        # with a single sample dim the x axis is that dim; otherwise it is the
        # total number of draws pooled over all sample dims
        xlabel_kwargs.setdefault(
            "text", sample_dims[0] if len(sample_dims) == 1 else "Total Number of Draws"
        )

        plot_collection.map(
            labelled_x,
            "xlabel",
            ignore_aes=xlabels_ignore,
            subset_info=True,
            **xlabel_kwargs,
        )

    _, ylabels_aes, ylabels_ignore = filter_aes(
        plot_collection, aes_by_visuals, "ylabel", sample_dims
    )
    ylabel_kwargs = copy(visuals.get("ylabel", {}))
    if ylabel_kwargs is not False:
        if "color" not in ylabels_aes:
            ylabel_kwargs.setdefault("color", contrast_color)

        ylabel_kwargs.setdefault("text", "Relative ESS" if relative is not False else "ESS")

        plot_collection.map(
            labelled_y,
            "ylabel",
            ignore_aes=ylabels_ignore,
            subset_info=True,
            **ylabel_kwargs,
        )

    return plot_collection