Source code for arviz_plots.plots.dist_plot

"""dist plot code."""

import warnings
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import arviz_stats  # pylint: disable=unused-import
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
    filter_aes,
    get_contrast_colors,
    process_group_variables_coords,
    set_wrap_layout,
)
from arviz_plots.visuals import (
    ecdf_line,
    fill_between_y,
    hist,
    labelled_title,
    line_x,
    line_xy,
    point_estimate_text,
    remove_axis,
    scatter_x,
    step_hist,
)


def plot_dist(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    kind=None,
    point_estimate=None,
    ci_kind=None,
    ci_prob=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "dist",
            "face",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "dist",
            "face",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
            "remove_axis",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[
        Literal["dist", "credible_interval", "point_estimate"], Mapping[str, Any] | xr.Dataset
    ] = None,
    **pc_kwargs,
):
    """Plot 1D marginal densities in the style of John K. Kruschke's book [1]_.

    Generate :term:`faceted` :term:`plots` with: a graphical representation of 1D marginal
    densities (as KDE, histogram, ECDF or dotplot), a credible interval and a point estimate.

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to aesthetics.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to be plotted.
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    point_estimate : {"mean", "median", "mode"}, optional
        Which point estimate to plot. Defaults to rcParam :data:`stats.point_estimate`
    ci_kind : {"eti", "hdi"}, optional
        Which credible interval to use. Defaults to ``rcParams["stats.ci_kind"]``
    ci_prob : float, optional
        Indicates the probability that should be contained within the plotted credible interval.
        Defaults to ``rcParams["stats.ci_prob"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.

        With a single model, no aesthetic mappings are generated by default,
        each variable+coord combination gets a :term:`plot` but they all look the same,
        unless there are user provided aesthetic mappings.
        With multiple models, ``plot_dist`` maps "color" and "y" to the "model" dimension.

        By default, all aesthetics but "y" are mapped to the density representation,
        and if multiple models are present, "color" and "y" are mapped to the
        credible interval and the point estimate.

        When "point_estimate" key is provided but "point_estimate_text" isn't,
        the values assigned to the first are also used for the second.
    visuals : mapping of {str : mapping or False}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.step_hist`

        * face -> :term:`visual` that fills the area under the marginal distribution
          representation. Defaults to False. ``True`` is accepted as a shorthand
          for an empty mapping. Depending on the value of `kind` it is passed to:

          * "kde" or "ecdf" -> passed to :func:`~arviz_plots.visuals.fill_between_y`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * credible_interval -> passed to :func:`~arviz_plots.visuals.line_x`
        * point_estimate -> passed to :func:`~arviz_plots.visuals.scatter_x`
        * point_estimate_text -> passed to :func:`~arviz_plots.visuals.point_estimate_text`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * rug -> passed to :func:`~arviz_plots.visuals.scatter_x`. Defaults to False.
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf, ...
        * credible_interval -> passed to eti or hdi
        * point_estimate -> passed to mean, median or mode

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `ci_kind` is not one of "hdi", "eti" or None.
    NotImplementedError
        If the "dist" or "face" visual is requested with a `kind` other than
        "kde", "ecdf" or "hist", or if a point estimate other than "mean" or
        "median" is requested.

    See Also
    --------
    :ref:`plots_intro` :
        General introduction to batteries-included plotting functions, common use and logic overview

    Examples
    --------
    Map the color to the variable, and have the mapping apply to the title too instead
    of only the density representation:

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_dist, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> non_centered = load_arviz_data('non_centered_eight')
        >>> pc = plot_dist(
        >>>     non_centered,
        >>>     coords={"school": ["Choate", "Deerfield", "Hotchkiss"]},
        >>>     aes={"color": ["__variable__"]},
        >>>     aes_by_visuals={"title": ["color"]},
        >>> )

    .. minigallery:: plot_dist

    References
    ----------
    .. [1] Kruschke. Doing Bayesian Data Analysis, Second Edition: A Tutorial with R,
        JAGS, and Stan. Academic Press, 2014. ISBN 978-0-12-405888-0.
        https://www.sciencedirect.com/book/9780124058880
    """
    if ci_kind not in ("hdi", "eti", None):
        raise ValueError("ci_kind must be either 'hdi' or 'eti'")

    # Resolve every unset argument from rcParams.
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if ci_prob is None:
        ci_prob = rcParams["stats.ci_prob"]
    if ci_kind is None:
        ci_kind = rcParams["stats.ci_kind"] if "stats.ci_kind" in rcParams else "eti"
    if point_estimate is None:
        point_estimate = rcParams["stats.point_estimate"]
    if kind is None:
        kind = rcParams["plot.density_kind"]

    # Work on shallow copies so the caller's mappings are never mutated.
    visuals = {} if visuals is None else visuals.copy()
    if kind in ("hist", "ecdf"):
        # the y axis is kept by default for these two representations
        visuals.setdefault("remove_axis", False)
    stats = {} if stats is None else stats.copy()

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    bg_color = plot_bknd.get_background_color()
    contrast_color, contrast_gray_color = get_contrast_colors(bg_color=bg_color, gray_flag=True)

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        non_facet_dims = {"model"}.union(sample_dims)
        pc_kwargs.setdefault(
            "cols",
            ["__variable__"] + [dim for dim in distribution.dims if dim not in non_facet_dims],
        )
        if "model" in distribution:
            pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
            pc_kwargs["aes"].setdefault("color", ["model"])
            pc_kwargs["aes"].setdefault("y", ["model"])
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    face_kwargs = copy(visuals.get("face", False))
    if face_kwargs is True:
        # same convenience the "rug" visual offers: True means "on, with defaults"
        face_kwargs = {}
    density_kwargs = copy(visuals.get("dist", {}))
    # The density is only computed when at least one visual needs it.
    density_computed = density_kwargs is not False or face_kwargs is not False

    aes_by_visuals = {} if aes_by_visuals is None else aes_by_visuals.copy()
    aes_by_visuals.setdefault("dist", plot_collection.aes_set.difference("y"))
    if face_kwargs is not False:
        aes_by_visuals.setdefault("face", set(aes_by_visuals["dist"]).difference({"linestyle"}))
    if "model" in distribution:
        aes_by_visuals.setdefault("credible_interval", ["color", "y"])
        aes_by_visuals.setdefault("point_estimate", ["color", "y"])
    if "point_estimate" in aes_by_visuals and "point_estimate_text" not in aes_by_visuals:
        aes_by_visuals["point_estimate_text"] = aes_by_visuals["point_estimate"]

    if labeller is None:
        labeller = BaseLabeller()

    density = distribution
    default_color = plot_bknd.get_default_aes("color", 1, {})[0]

    if density_computed:
        density_dims, _, _ = filter_aes(plot_collection, aes_by_visuals, "dist", sample_dims)
        if kind == "kde":
            with warnings.catch_warnings():
                if "model" in distribution:
                    warnings.filterwarnings("ignore", message="Your data appears to have a single")
                density = distribution.azstats.kde(dim=density_dims, **stats.get("dist", {}))
        elif kind == "ecdf":
            density = distribution.azstats.ecdf(dim=density_dims, **stats.get("dist", {}))
        elif kind == "hist":
            hist_kwargs = dict(stats.get("dist", {}))
            hist_kwargs.setdefault("density", True)
            density = distribution.azstats.histogram(dim=density_dims, **hist_kwargs)

    # density
    if density_kwargs is not False:
        _, density_aes, density_ignore = filter_aes(
            plot_collection, aes_by_visuals, "dist", sample_dims
        )
        if "color" not in density_aes:
            density_kwargs.setdefault("color", default_color)

        if kind == "kde":
            plot_collection.map(
                line_xy, "dist", data=density, ignore_aes=density_ignore, **density_kwargs
            )
        elif kind == "ecdf":
            plot_collection.map(
                ecdf_line,
                "dist",
                data=density,
                ignore_aes=density_ignore,
                **density_kwargs,
            )
        elif kind == "hist":
            plot_collection.map(
                step_hist,
                "dist",
                data=density,
                ignore_aes=density_ignore,
                **density_kwargs,
            )
        else:
            raise NotImplementedError("coming soon")

    # filled face
    if face_kwargs is not False:
        _, face_aes, face_ignore = filter_aes(plot_collection, aes_by_visuals, "face", sample_dims)
        if "color" not in face_aes:
            face_kwargs.setdefault("color", default_color)
        if "alpha" not in face_aes:
            face_kwargs.setdefault("alpha", 0.4)

        if kind in ("kde", "ecdf"):
            # fill_between_y needs x, y_top and y_bottom: pad a constant 0 as the bottom curve
            face_density = (
                density.rename(plot_axis="kwarg")
                .sel(kwarg=["x", "y"])
                .pad(kwarg=(0, 1), constant_values=0)
                .assign_coords(kwarg=["x", "y_top", "y_bottom"])
            )
            plot_collection.map(
                fill_between_y,
                "face",
                data=face_density,
                ignore_aes=face_ignore,
                **face_kwargs,
            )
        elif kind == "hist":
            plot_collection.map(
                hist,
                "face",
                data=density,
                ignore_aes=face_ignore,
                **face_kwargs,
            )
        else:
            raise NotImplementedError("coming soon")

    # rug
    rug_kwargs = copy(visuals.get("rug", False))
    if rug_kwargs is not False:
        if not isinstance(rug_kwargs, dict):
            rug_kwargs = {}
        _, rug_aes, rug_ignore = filter_aes(plot_collection, aes_by_visuals, "rug", sample_dims)
        if "color" not in rug_aes:
            rug_kwargs.setdefault("color", contrast_color)
        if "marker" not in rug_aes:
            rug_kwargs.setdefault("marker", "|")
        if "size" not in rug_aes:
            rug_kwargs.setdefault("size", 15)
        plot_collection.map(
            scatter_x,
            "rug",
            data=distribution,
            ignore_aes=rug_ignore,
            **rug_kwargs,
        )

    # With several models, rescale the "y" aesthetic relative to the density height.
    # This is only possible when a density was actually computed; otherwise
    # `density` still holds the raw samples and has no "plot_axis" dimension.
    if density_computed and ("model" in distribution) and (plot_collection.coords is None):
        if kind == "hist":
            # histogram heights live under plot_axis="histogram" and each variable
            # has its own bin dimension (same layout used for the text position below)
            height_axis = "histogram"
            height_dims = [f"hist_dim_{var_name}" for var_name in density.data_vars]
        else:
            height_axis = "y"
            height_dims = [{"kde": "kde_dim", "ecdf": "quantile"}[kind]]
        y_ds = plot_collection.get_aes_as_dataset("y")["mapping"]
        y_ds = (
            0.15 * y_ds * density.sel(plot_axis=height_axis, drop=True).max(height_dims + ["model"])
        )
        plot_collection.update_aes_from_dataset("y", y_ds)

    # credible interval
    ci_kwargs = copy(visuals.get("credible_interval", {}))
    if ci_kwargs is not False:
        ci_dims, ci_aes, ci_ignore = filter_aes(
            plot_collection, aes_by_visuals, "credible_interval", sample_dims
        )
        if ci_kind == "eti":
            ci = distribution.azstats.eti(
                prob=ci_prob, dim=ci_dims, **stats.get("credible_interval", {})
            )
        elif ci_kind == "hdi":
            ci = distribution.azstats.hdi(
                prob=ci_prob, dim=ci_dims, **stats.get("credible_interval", {})
            )

        if "color" not in ci_aes:
            ci_kwargs.setdefault("color", contrast_gray_color)
        plot_collection.map(line_x, "credible_interval", data=ci, ignore_aes=ci_ignore, **ci_kwargs)

    # point estimate
    pe_kwargs = copy(visuals.get("point_estimate", {}))
    pet_kwargs = copy(visuals.get("point_estimate_text", {}))
    if (pe_kwargs is not False) or (pet_kwargs is not False):
        pe_dims, pe_aes, pe_ignore = filter_aes(
            plot_collection, aes_by_visuals, "point_estimate", sample_dims
        )
        if point_estimate == "median":
            point = distribution.median(dim=pe_dims, **stats.get("point_estimate", {}))
        elif point_estimate == "mean":
            point = distribution.mean(dim=pe_dims, **stats.get("point_estimate", {}))
        else:
            raise NotImplementedError("coming soon")

    if pe_kwargs is not False:
        if "color" not in pe_aes:
            pe_kwargs.setdefault("color", contrast_gray_color)
        plot_collection.map(
            scatter_x,
            "point_estimate",
            data=point,
            ignore_aes=pe_ignore,
            **pe_kwargs,
        )

    # point estimate text
    if pet_kwargs is not False:
        # The label height is a fraction of the tallest point of the density.
        if not density_computed:
            point_y = xr.full_like(point, 0.02)
        elif kind == "kde":
            point_density_diff = [
                dim for dim in density.sel(plot_axis="y").dims if dim not in point.dims
            ]
            point_density_diff = ["kde_dim"] + point_density_diff
            point_y = 0.1 * density.sel(plot_axis="y", drop=True).max(dim=point_density_diff)
        elif kind == "ecdf":
            # ecdf max is always 1
            point_y = xr.full_like(point, 0.1)
        elif kind == "hist":
            point_density_diff = [
                dim for dim in density.sel(plot_axis="histogram").dims if dim not in point.dims
            ]
            point_density_diff = [
                f"hist_dim_{var_name}" for var_name in density.data_vars
            ] + point_density_diff
            point_y = 0.1 * density.sel(plot_axis="histogram", drop=True).max(
                dim=point_density_diff
            )

        point = xr.concat((point, point_y), dim="plot_axis").assign_coords(plot_axis=["x", "y"])

        _, pet_aes, pet_ignore = filter_aes(
            plot_collection, aes_by_visuals, "point_estimate_text", sample_dims
        )
        if "color" not in pet_aes:
            pet_kwargs.setdefault("color", contrast_gray_color)
        pet_kwargs.setdefault("horizontal_align", "center")
        pet_kwargs.setdefault("point_label", "x")

        plot_collection.map(
            point_estimate_text,
            "point_estimate_text",
            data=point,
            point_estimate=point_estimate,
            ignore_aes=pet_ignore,
            **pet_kwargs,
        )

    # title
    title_kwargs = copy(visuals.get("title", {}))
    if title_kwargs is not False:
        _, title_aes, title_ignore = filter_aes(
            plot_collection, aes_by_visuals, "title", sample_dims
        )
        if "color" not in title_aes:
            title_kwargs.setdefault("color", contrast_color)
        plot_collection.map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    # y axis removal, skipped only when explicitly set to False
    if visuals.get("remove_axis", True) is not False:
        plot_collection.map(
            remove_axis,
            store_artist=backend == "none",
            axis="y",
            ignore_aes=plot_collection.aes_set,
        )

    return plot_collection