Source code for arviz_plots.plots.trace_plot

"""Trace plot code."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import numpy as np
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
    filter_aes,
    get_contrast_colors,
    get_group,
    process_group_variables_coords,
    set_wrap_layout,
)
from arviz_plots.visuals import labelled_title, labelled_x, line, ticklabel_props, trace_rug


def plot_trace(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "trace",
            "divergence",
            "title",
            "xlabel",
            "ticklabels",
        ],
        Sequence[str],
    ]
    | None = None,
    visuals: Mapping[
        Literal[
            "trace",
            "divergence",
            "title",
            "xlabel",
            "ticklabels",
        ],
        Mapping[str, Any] | Literal[False],
    ]
    | None = None,
    **pc_kwargs,
):
    """Plot iteration versus sampled values.

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names: str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior"
        Group of `dt` whose variables are plotted. Forwarded to
        ``process_group_variables_coords``.
    coords : mapping, optional
        Coordinate selection forwarded to ``process_group_variables_coords``.
    sample_dims : iterable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
        Existing collection to draw on. When omitted, a new one is created with
        :meth:`PlotCollection.wrap`.
    backend : {"matplotlib", "bokeh"}, optional
        Defaults to the backend of `plot_collection` if one is given, otherwise to
        ``rcParams["plot.backend"]``.
    labeller : labeller, optional
        Defaults to a :class:`~arviz_base.labels.BaseLabeller` instance.
    aes_by_visuals : mapping, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Defaults to only mapping properties to the trace lines
        (and "overlay" to the divergence rug).
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * trace -> passed to :func:`~.visuals.line`
        * divergence -> passed to :func:`~.visuals.trace_rug`
        * title -> :func:`~.visuals.labelled_title`
        * xlabel -> :func:`~.visuals.labelled_x`
        * ticklabels -> :func:`~.visuals.ticklabel_props`

        Setting a key to ``False`` skips that visual. The divergence rug is also skipped
        when `dt` has no "sample_stats" group, the group has no "diverging" variable,
        or no sample is flagged as divergent.

    pc_kwargs : mapping
        Passed to :class:`arviz_plots.PlotCollection`

    Returns
    -------
    PlotCollection

    Examples
    --------
    The following examples focus on behaviour specific to ``plot_trace``.
    For a general introduction to batteries-included functions like this one and common
    usage examples see :ref:`plots_intro`

    Default plot_trace

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_trace, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> centered = load_arviz_data('centered_eight')
        >>> plot_trace(centered)

    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if visuals is None:
        visuals = {}

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color = get_contrast_colors(bg_color=bg_color)

    if plot_collection is None:
        # copy nested containers so the caller's objects are not mutated by setdefault
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        if "chain" in distribution:
            pc_kwargs["aes"].setdefault("color", ["chain"])
            pc_kwargs["aes"].setdefault("overlay", ["chain"])
        pc_kwargs.setdefault(
            "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims]
        )
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        aux_dim_list = [dim for dim in pc_kwargs["cols"] if dim != "__variable__"]
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        pc_kwargs["figure_kwargs"].setdefault("sharex", True)
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )
    else:
        aux_dim_list = list(plot_collection.viz["plot"].dims)

    if aes_by_visuals is None:
        aes_by_visuals = {}
    else:
        aes_by_visuals = aes_by_visuals.copy()
    aes_by_visuals.setdefault("trace", plot_collection.aes_set)
    aes_by_visuals.setdefault("divergence", {"overlay"})
    if labeller is None:
        labeller = BaseLabeller()

    # trace
    trace_kwargs = copy(visuals.get("trace", {}))
    if trace_kwargs is False:
        xname = None
    else:
        default_xname = sample_dims[0] if len(sample_dims) == 1 else "draw"
        # only use the dimension as x values when it exists and has numeric coordinates
        if (default_xname not in distribution.dims) or (
            not np.issubdtype(distribution[default_xname].dtype, np.number)
        ):
            default_xname = None
        xname = trace_kwargs.get("xname", default_xname)
        trace_kwargs["xname"] = xname
        _, _, trace_ignore = filter_aes(plot_collection, aes_by_visuals, "trace", sample_dims)
        plot_collection.map(
            line,
            "trace",
            data=distribution,
            ignore_aes=trace_ignore,
            **trace_kwargs,
        )

    # divergences
    divergence_kwargs = copy(visuals.get("divergence", {}))
    # check the cheap "disabled" flag first so the group lookup and the reduction
    # over the diverging variable are skipped when the rug is turned off
    if divergence_kwargs is not False:
        sample_stats = get_group(dt, "sample_stats", allow_missing=True)
        if (
            sample_stats is not None
            and "diverging" in sample_stats.data_vars
            and np.any(sample_stats.diverging)
        ):
            # reuse the group returned by get_group instead of re-reading it through
            # attribute access on `dt`, so the guard and the mask use the same object
            divergence_mask = sample_stats.diverging
            _, div_aes, div_ignore = filter_aes(
                plot_collection, aes_by_visuals, "divergence", sample_dims
            )
            if "color" not in div_aes:
                divergence_kwargs.setdefault("color", contrast_color)
            if "marker" not in div_aes:
                divergence_kwargs.setdefault("marker", "|")
            if "size" not in div_aes:
                divergence_kwargs.setdefault("size", 30)
            # the rug is placed at the minimum over all dims that are not faceted
            div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list]

            plot_collection.map(
                trace_rug,
                "divergence",
                data=distribution,
                ignore_aes=div_ignore,
                xname=xname,
                y=distribution.min(div_reduce_dims),
                mask=divergence_mask,
                **divergence_kwargs,
            )

    # aesthetics
    title_kwargs = copy(visuals.get("title", {}))
    if title_kwargs is not False:
        _, title_aes, title_ignore = filter_aes(
            plot_collection, aes_by_visuals, "title", sample_dims
        )
        if "color" not in title_aes:
            title_kwargs.setdefault("color", contrast_color)
        plot_collection.map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    # Label the x axis with the capitalized x dimension name, or "Steps" when there is none
    xlabel_kwargs = copy(visuals.get("xlabel", {}))
    if xlabel_kwargs is not False:
        _, xlabel_aes, xlabel_ignore = filter_aes(
            plot_collection, aes_by_visuals, "xlabel", sample_dims
        )
        if "color" not in xlabel_aes:
            xlabel_kwargs.setdefault("color", contrast_color)
        plot_collection.map(
            labelled_x,
            "xlabel",
            ignore_aes=xlabel_ignore,
            text="Steps" if xname is None else xname.capitalize(),
            **xlabel_kwargs,
        )

    # Adjust tick labels
    ticklabels_kwargs = copy(visuals.get("ticklabels", {}))
    if ticklabels_kwargs is not False:
        _, _, ticklabels_ignore = filter_aes(
            plot_collection, aes_by_visuals, "ticklabels", sample_dims
        )
        plot_collection.map(
            ticklabel_props,
            "ticklabels",
            ignore_aes=ticklabels_ignore,
            axis="both",
            # artists are only stored for the "none" backend
            store_artist=backend == "none",
            **ticklabels_kwargs,
        )

    return plot_collection