# Source code for arviz_plots.plots.pair_plot

"""Pair plot code."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import rcParams, xarray_sel_iter
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_matrix import PlotMatrix
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import (
    filter_aes,
    get_group,
    process_group_variables_coords,
    set_grid_layout,
)
from arviz_plots.visuals import (
    label_plot,
    labelled_x,
    labelled_y,
    remove_matrix_axis,
    scatter_couple,
    set_ticklabel_visibility,
)


def _get_visual_kwargs(visuals, key, default):
    """Return the keyword arguments stored in ``visuals`` under ``key``.

    Parameters
    ----------
    visuals : mapping
        The ``visuals`` argument received by :func:`plot_pair`.
    key : str
        Name of the visual element to look up.
    default : dict or False
        Value used when `key` is not present in `visuals`.

    Returns
    -------
    dict or False
        ``False`` when the visual is disabled. Otherwise a shallow copy of its keyword
        arguments, so that later ``setdefault`` calls never mutate the caller's objects.
        ``True`` is accepted as shorthand for "enabled with default keyword arguments"
        and is converted to an empty dict.
    """
    kwargs = visuals.get(key, default)
    if kwargs is False:
        return False
    if kwargs is True:
        return {}
    return copy(kwargs)


def plot_pair(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    marginal=True,
    marginal_kind=None,
    triangle="lower",
    plot_matrix=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "scatter",
            "divergence",
            "dist",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "label",
            "xlabel",
            "ylabel",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "scatter",
            "divergence",
            "dist",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "label",
            "xlabel",
            "ylabel",
            "remove_axis",
        ],
        Mapping[str, Any] | Literal[False],
    ] = None,
    stats: Mapping[
        Literal[
            "dist",
            "credible_interval",
            "point_estimate",
        ],
        Mapping[str, Any] | xr.Dataset,
    ] = None,
    **pc_kwargs,
):
    """Plot all variables against each other in the dataset.

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to use for plotting.
    coords : mapping, optional
        Coordinates to use for plotting.
    sample_dims : iterable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    marginal : bool, default True
        Whether to plot marginal distributions on the diagonal.
    marginal_kind : {"kde", "hist", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    triangle : {"both", "upper", "lower"}, default "lower"
        Which triangle of the pair plot to plot.
    plot_matrix : PlotMatrix, optional
    backend : {"matplotlib", "bokeh", "plotly", "none"}, optional
        Plotting backend to use. Defaults to ``rcParams["plot.backend"]``
    labeller : labeller, optional
    aes_by_visuals : mapping, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_matrix`
        when plotted. Valid keys are the same as for `visuals`.
        By default ``scatter`` uses every aesthetic defined in `plot_matrix`, ``dist``
        uses all of them except ``overlay`` and the remaining visuals use none.
        ``overlay`` is always added to ``scatter`` and ``divergence``.
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * scatter -> passed to :func:`~.visuals.scatter_couple`
        * divergence -> passed to :func:`~.visuals.scatter_couple`. Defaults to False.
        * dist -> depending on the value of `marginal_kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * credible_interval -> passed to :func:`~arviz_plots.visuals.line_x`
        * point_estimate -> passed to :func:`~arviz_plots.visuals.scatter_x`
        * point_estimate_text -> passed to :func:`~arviz_plots.visuals.point_estimate_text`
        * label -> Keyword arguments passed to :func:`~arviz_plots.visuals.label_plot`.
          Used to customize the variable name labels on the diagonal.
          Applied only if ``marginal=False``.
        * xlabel -> passed to :func:`~.visuals.labelled_x`. Used to customize the
          x-axis labels on the bottom-most plots or on the diagonal plots, depending
          on the value of ``triangle``. If ``triangle`` is "lower" or "both", it is
          used to label the bottom-most row of plots via
          :meth:`arviz_plots.PlotMatrix.map_row`. If ``triangle`` is "upper", it is
          used to label the diagonal plots via :meth:`arviz_plots.PlotMatrix.map`.
          It is applied only if ``marginal=True``, because in that case the diagonal
          plots have no variable-name labels.
        * ylabel -> passed to :func:`~.visuals.labelled_y`. Used to customize the
          y-axis labels on the left-most plots. It is applied via
          :meth:`arviz_plots.PlotMatrix.map_col` only if ``triangle`` is "lower" or
          "both" and ``marginal=True``.
        * remove_axis -> not passed anywhere. It can only be set to ``False``, which
          disables the default removal of the ``x`` and ``y`` axes from the plots in
          the unused half of the matrix. If ``triangle`` is "upper", the lower
          triangle axes are removed. If ``triangle`` is "lower", the upper triangle
          axes are removed. Nothing is removed when ``triangle`` is "both".

        A value of ``False`` disables the corresponding visual. A value of ``True``
        enables it with default keyword arguments.
    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf, ...
        * credible_interval -> passed to eti or hdi
        * point_estimate -> passed to mean, median or mode

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotMatrix`

    Returns
    -------
    PlotMatrix

    Examples
    --------
    plot_pair with ``triangle`` set to "upper", ``marginal=True`` and ``marginal_kind``
    set to "ecdf". Because ``triangle`` is "upper", the ``xlabels`` are mapped to the
    diagonal plots. The marginals are plotted on the diagonal, and ``point_estimate`` and
    ``credible_interval`` are disabled by default. Because ``marginal=True``, ``sharex``
    is set to "col", while ``sharey`` is not set by default.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_pair, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_pair(
        >>>     dt,
        >>>     var_names=["mu", "tau"],
        >>>     visuals={"divergence": True},
        >>>     marginal=True,
        >>>     marginal_kind="ecdf",
        >>>     triangle="upper",
        >>> )

    plot_pair with ``triangle`` set to "both". In this case the ``xlabels`` are mapped to
    the bottom-most plots and the ``ylabels`` to the left-most plots. Here ``color`` is set
    to "red" for ``credible_interval`` and ``point_estimate``, which also enables both of
    them. By default ``marginal`` is ``True`` and ``marginal_kind`` is
    ``rcParams["plot.density_kind"]``.

    .. plot::
        :context: close-figs

        >>> visuals = {"credible_interval":{"color":"red"},"point_estimate":{"color":"red"}}
        >>> plot_pair(
        >>>     dt,
        >>>     var_names=["mu", "tau"],
        >>>     visuals=visuals,
        >>>     triangle="both",
        >>> )

    plot_pair with ``marginal=False`` and ``triangle`` set to "upper". Because
    ``marginal=False``, ``xlabel`` and ``ylabel`` are disabled by default, and the diagonal
    plots show the variable names as labels. The diagonal plots also get ``xticks`` and
    ``yticks`` with their tick labels, so that ticks can be matched to rows and columns.
    Because ``marginal=False``, ``sharex`` is set to "col" and ``sharey`` to "row" by
    default.

    .. plot::
        :context: close-figs

        >>> plot_pair(
        >>>     dt,
        >>>     coords = {"school":"Choate"},
        >>>     visuals={"divergence": True},
        >>>     marginal=False,
        >>>     triangle="upper",
        >>> )

    .. minigallery:: plot_pair
    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if visuals is None:
        visuals = {}
    # ``**pc_kwargs`` is always a fresh dict owned by this call, so it can be mutated
    # directly (nested dicts are still copied below before being modified).
    if labeller is None:
        labeller = BaseLabeller()
    if backend is None:
        if plot_matrix is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_matrix.backend

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )
    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    if plot_matrix is None:
        pc_kwargs.setdefault(
            "facet_dims",
            ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims],
        )
        pc_kwargs["figure_kwargs"] = dict(pc_kwargs.get("figure_kwargs", {}))
        # One row and one column per variable/coordinate combination left after
        # excluding the sample dims.
        pairs = tuple(
            xarray_sel_iter(
                distribution, skip_dims={dim for dim in distribution.dims if dim in sample_dims}
            )
        )
        n_pairs = len(pairs)
        pc_kwargs = set_grid_layout(
            pc_kwargs, plot_bknd, distribution, num_rows=n_pairs, num_cols=n_pairs
        )
        pc_kwargs["aes"] = dict(pc_kwargs.get("aes", {}))
        if "chain" in distribution:
            pc_kwargs["aes"].setdefault("overlay", ["chain"])
        pc_kwargs["figure_kwargs"].setdefault("sharex", "col")
        if not marginal:
            # Without marginals every diagonal cell shows the same variable on both axes,
            # so rows can share the y axis too.
            pc_kwargs["figure_kwargs"].setdefault("sharey", "row")
        plot_matrix = PlotMatrix(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    # ``dict(...)`` (rather than ``.copy()``) accepts any Mapping, as annotated.
    aes_by_visuals = {} if aes_by_visuals is None else dict(aes_by_visuals)
    aes_by_visuals["scatter"] = {"overlay"}.union(
        aes_by_visuals.get("scatter", plot_matrix.aes_set)
    )
    aes_by_visuals["divergence"] = {"overlay"}.union(aes_by_visuals.get("divergence", {}))
    aes_by_visuals["dist"] = aes_by_visuals.get(
        "dist", plot_matrix.aes_set.difference({"overlay"})
    )
    aes_by_visuals["credible_interval"] = aes_by_visuals.get("credible_interval", {})
    aes_by_visuals["point_estimate"] = aes_by_visuals.get("point_estimate", {})
    aes_by_visuals["point_estimate_text"] = aes_by_visuals.get("point_estimate_text", {})
    # Default colors: the first is used for the scatter, the second for divergences.
    colors = plot_bknd.get_default_aes("color", 2, {})

    # scatter
    scatter_kwargs = _get_visual_kwargs(visuals, "scatter", {})
    if scatter_kwargs is not False:
        _, scatter_aes, scatter_ignore = filter_aes(
            plot_matrix, aes_by_visuals, "scatter", sample_dims
        )
        # Only apply fixed defaults for properties not already driven by an aesthetic.
        if "color" not in scatter_aes:
            scatter_kwargs.setdefault("color", colors[0])
        if "width" not in scatter_aes:
            scatter_kwargs.setdefault("width", 0)
        if "alpha" not in scatter_aes:
            scatter_kwargs.setdefault("alpha", 0.5)
        plot_matrix.map_triangle(
            scatter_couple,
            "scatter",
            triangle=triangle,
            data=distribution,
            ignore_aes=scatter_ignore,
            **scatter_kwargs,
        )

    if marginal is not False:
        # marginal distributions on the diagonal, delegated to plot_dist
        stats = {} if stats is None else stats
        dist_plot_visuals = {
            "dist": _get_visual_kwargs(visuals, "dist", {}),
            "credible_interval": _get_visual_kwargs(visuals, "credible_interval", False),
            "point_estimate": _get_visual_kwargs(visuals, "point_estimate", False),
            "point_estimate_text": _get_visual_kwargs(visuals, "point_estimate_text", False),
            "title": False,
            "remove_axis": False,
            "rug": False,
        }
        # These keys were all filled in on ``aes_by_visuals`` above.
        dist_plot_aes_by_visuals = {
            key: aes_by_visuals[key]
            for key in ("dist", "credible_interval", "point_estimate", "point_estimate_text")
        }
        dist_plot_stats = {
            key: stats.get(key, {}) for key in ("dist", "credible_interval", "point_estimate")
        }
        plot_matrix = plot_dist(
            distribution,
            sample_dims=sample_dims,
            kind=marginal_kind,
            plot_collection=plot_matrix,
            backend=backend,
            labeller=labeller,
            aes_by_visuals=dist_plot_aes_by_visuals,
            visuals=dist_plot_visuals,
            stats=dist_plot_stats,
        )
    else:
        # diagonal labels of rows and cols, centered within the data range of each variable
        label_kwargs = _get_visual_kwargs(visuals, "label", {})
        if label_kwargs is not False:
            lim_low = distribution.min(dim=sample_dims)
            lim_high = distribution.max(dim=sample_dims)
            text_center = (lim_high + lim_low) / 2
            _, _, label_ignore = filter_aes(plot_matrix, aes_by_visuals, "label", sample_dims)
            plot_matrix.map(
                label_plot,
                "label",
                subset_info=True,
                labeller=labeller,
                x=text_center,
                y=text_center,
                lim_low=lim_low,
                lim_high=lim_high,
                ignore_aes=label_ignore,
                **label_kwargs,
            )

    # divergence
    div_kwargs = _get_visual_kwargs(visuals, "divergence", False)
    sample_stats = get_group(dt, "sample_stats", allow_missing=True)
    if (
        div_kwargs is not False
        and sample_stats is not None
        and "diverging" in sample_stats.data_vars
        and np.any(sample_stats.diverging)
    ):
        # Use the group the guard above validated instead of re-reading ``dt.sample_stats``.
        # NOTE(review): the mask is not subset with ``coords``; confirm scatter_couple
        # aligns it with ``distribution`` when sample dims are selected.
        divergence_mask = sample_stats.diverging
        _, div_aes, div_ignore = filter_aes(plot_matrix, aes_by_visuals, "divergence", sample_dims)
        if "color" not in div_aes:
            div_kwargs.setdefault("color", colors[1])
        if "alpha" not in div_aes:
            div_kwargs.setdefault("alpha", 0.5)
        plot_matrix.map_triangle(
            scatter_couple,
            "divergence",
            triangle=triangle,
            data=distribution,
            ignore_aes=div_ignore,
            mask=divergence_mask,
            **div_kwargs,
        )

    # bottom plots xlabel and left plots ylabel
    if marginal and triangle in {"both", "lower"}:
        xlabel_kwargs = _get_visual_kwargs(visuals, "xlabel", {})
        if xlabel_kwargs is not False:
            _, _, xlabel_ignore = filter_aes(plot_matrix, aes_by_visuals, "xlabel", sample_dims)
            plot_matrix.map_row(
                labelled_x,
                "xlabel",
                index=-1,
                data=distribution,
                ignore_aes=xlabel_ignore,
                labeller=labeller,
                subset_info=True,
                **xlabel_kwargs,
            )
        ylabel_kwargs = _get_visual_kwargs(visuals, "ylabel", {})
        if ylabel_kwargs is not False:
            _, _, ylabel_ignore = filter_aes(plot_matrix, aes_by_visuals, "ylabel", sample_dims)
            plot_matrix.map_col(
                labelled_y,
                "ylabel",
                index=0,
                data=distribution,
                ignore_aes=ylabel_ignore,
                labeller=labeller,
                subset_info=True,
                **ylabel_kwargs,
            )
    elif marginal and triangle == "upper":
        # with an upper triangle the diagonal is the bottom-most plot of each column
        xlabel_kwargs = _get_visual_kwargs(visuals, "xlabel", {})
        if xlabel_kwargs is not False:
            _, _, xlabel_ignore = filter_aes(plot_matrix, aes_by_visuals, "xlabel", sample_dims)
            plot_matrix.map(
                labelled_x,
                "xlabel",
                subset_info=True,
                ignore_aes=xlabel_ignore,
                labeller=labeller,
                **xlabel_kwargs,
            )

    # set ticklabel visibility on the diagonal for upper-triangle layouts
    if triangle == "upper":
        _, _, set_ticklabel_visibility_ignore = filter_aes(
            plot_matrix, aes_by_visuals, "set_ticklabel_visibility", sample_dims
        )
        plot_matrix.map(
            set_ticklabel_visibility,
            "set_ticklabel_visibility",
            axis="x",
            visible=True,
            ignore_aes=set_ticklabel_visibility_ignore,
        )
        if not marginal:
            plot_matrix.map(
                set_ticklabel_visibility,
                "set_ticklabel_visibility",
                axis="y",
                visible=True,
                ignore_aes=set_ticklabel_visibility_ignore,
            )

    # removal of axis from the unused half of the matrix for better visualization
    if visuals.get("remove_axis", True):
        _, _, remove_axis_ignore = filter_aes(
            plot_matrix, aes_by_visuals, "remove_axis", sample_dims
        )
        if triangle == "upper":
            # plotting the upper triangle: remove the lower triangle axes
            plot_matrix.map_triangle(
                remove_matrix_axis,
                "remove_axis",
                triangle="lower",
                axis="both",
                ignore_aes=remove_axis_ignore,
            )
        elif triangle == "lower":
            # plotting the lower triangle: remove the upper triangle axes
            plot_matrix.map_triangle(
                remove_matrix_axis,
                "remove_axis",
                triangle="upper",
                axis="both",
                ignore_aes=remove_axis_ignore,
            )

    return plot_matrix