Source code for arviz_plots.plots.parallel_plot

"""Pair focus plot code."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import dataset_to_dataarray, rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import filter_aes, get_group, process_group_variables_coords
from arviz_plots.visuals import multiple_lines, set_xticks


[docs] def plot_parallel( dt, var_names=None, filter_vars=None, group="posterior", coords=None, sample_dims=None, norm_method=None, label_type="flat", plot_collection=None, backend=None, labeller=None, aes_by_visuals: Mapping[ Literal[ "line", "xticks", ], Sequence[str], ] = None, visuals: Mapping[ Literal[ "line", "xticks", ], Mapping[str, Any] | Literal[False], ] = None, **pc_kwargs, ): """Plot parallel coordinates plot showing posterior points with and without divergences. Parameters ---------- dt : DataTree Input data var_names: str or list of str, optional One or more variables to be plotted. Prefix the variables by ~ when you want to exclude them from the plot. filter_vars: {None, “like”, “regex”}, default None If None (default), interpret var_names as the real variables names. If “like”, interpret var_names as substrings of the real variables names. If “regex”, interpret var_names as regular expressions on the real variables names. group : str, default "posterior" Group to use for plotting. Defaults to "posterior". coords : mapping, optional Coordinates to use for plotting. sample_dims : iterable, optional Dimensions to reduce unless mapped to an aesthetic. Defaults to ``rcParams["data.sample_dims"]`` norm_method : {None, "normal", "minmax", "rank"}, default None Transformation to apply to the samples before plotting. label_type : {"flat", "vert"}, default "flat" Indicator of which `labeller` method to use when generating the labels of the x axis. plot_collection : PlotCollection, optional backend : {"matplotlib", "bokeh", "plotly", "none"}, optional Plotting backend to use. Defaults to ``rcParams["plot.backend"]`` labeller : labeller, optional aes_by_visuals : mapping, optional Mapping of visuals to aesthetics that should use their mapping in `plot_collection` when plotted. Valid keys are the same as for `visuals`. By default, there is a mapping from the value of `diverging` variable to color and alpha which is only active for the "line" visual. visuals : mapping of {str : mapping or False}, optional Valid keys are: * line -> passed to :func:`~.visuals.multiple_lines` * xticks -> passed to :func:`~.visuals.set_xticks`. Defaults to False. **pc_kwargs Passed to :meth:`arviz_plots.PlotCollection.wrap` Returns ------- PlotCollection Examples -------- Default plot_parallel without normalization and with default `label_type` as "flat". .. plot:: :context: close-figs >>> from arviz_plots import plot_parallel, style >>> style.use("arviz-variat") >>> from arviz_base import load_arviz_data >>> dt = load_arviz_data('centered_eight') >>> plot_parallel(dt) parallel_plot with `norm_method` set to "normal" and `label_type` set to "vert" and rotation of xaxis labels set to 30 degrees. .. plot:: :context: close-figs >>> plot_parallel( >>> dt, >>> norm_method="normal", >>> label_type="vert", >>> visuals={"xticks": {"rotation": 30}}, >>> ) .. minigallery:: plot_parallel """ if sample_dims is None: sample_dims = rcParams["data.sample_dims"] if isinstance(sample_dims, str): sample_dims = [sample_dims] if visuals is None: visuals = {} if backend is None: if plot_collection is None: backend = rcParams["plot.backend"] else: backend = plot_collection.backend if labeller is None: labeller = BaseLabeller() data = process_group_variables_coords( dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords, allow_dict=False, ) if len(sample_dims) > 1: data = data.stack(sample=sample_dims) combined_dim = "sample" else: combined_dim = sample_dims[0] # create divergence mask sample_stats = get_group(dt, "sample_stats", allow_missing=True) if ( sample_stats is not None and "diverging" in sample_stats.data_vars and np.any(sample_stats.diverging) ): div_mask = sample_stats.diverging if coords is not None: div_mask = div_mask.sel( {key: value for key, value in coords.items() if key in div_mask.dims} ) if len(sample_dims) > 1: div_mask = div_mask.stack(sample=sample_dims) else: div_mask = xr.zeros_like(data.coords[combined_dim], dtype=bool) data = data.assign_coords(diverging=div_mask) data = dataset_to_dataarray( data, labeller=labeller, sample_dims=[combined_dim], label_type=label_type ) if len(data.coords["label"].values) <= 1 or data.ndim != 2: raise ValueError( "Parallel plot requires at least two variables to plot. " "Please provide more than one variable." ) # get labels and x-values x_labels = data.coords["label"].values x_values = np.arange(len(x_labels)) # normalize data if norm_method is not None: if norm_method == "normal": mean = data.mean(dim=combined_dim) std_dev = data.std(dim=combined_dim) data = (data - mean) / std_dev elif norm_method == "minmax": min_val = data.min(dim=combined_dim) max_val = data.max(dim=combined_dim) data = (data - min_val) / (max_val - min_val) elif norm_method == "rank": data = data.azstats.compute_ranks(dim=combined_dim) else: raise ValueError(f"{norm_method} is not supported. Use normal, minmax or rank.") new_sample_dims = [combined_dim, "label"] plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") if plot_collection is None: pc_kwargs.setdefault("cols", None) pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() colors = plot_bknd.get_default_aes("color", 2, {}) pc_kwargs["aes"].setdefault("color", ["diverging"]) pc_kwargs["aes"].setdefault("alpha", ["diverging"]) pc_kwargs["color"] = pc_kwargs.get("color", [colors[0], colors[1]]) pc_kwargs["alpha"] = pc_kwargs.get("alpha", [0.03, 0.2]) figsize = pc_kwargs.get("figure_kwargs", {}).get("figsize", None) figsize_units = pc_kwargs.get("figure_kwargs", {}).get("figsize_units", "inches") if figsize is None: figsize = plot_bknd.scale_fig_size( figsize, rows=2, cols=1 + 0.2 * len(x_labels), figsize_units=figsize_units, ) label_max_len = np.vectorize(len)(x_labels).max() fontsize = 14 figsize = (figsize[0], figsize[1] + fontsize * label_max_len) figsize_units = "dots" pc_kwargs["figure_kwargs"]["figsize"] = figsize pc_kwargs["figure_kwargs"]["figsize_units"] = figsize_units plot_collection = PlotCollection.wrap( data.to_dataset(), backend=backend, **pc_kwargs, ) if aes_by_visuals is None: aes_by_visuals = {} else: aes_by_visuals = aes_by_visuals.copy() aes_by_visuals.setdefault("line", plot_collection.aes_set) # plot lines line_kwargs = copy(visuals.get("line", {})) if line_kwargs is not False: _, _, line_ignore = filter_aes(plot_collection, aes_by_visuals, "line", new_sample_dims) plot_collection.map( multiple_lines, "line", data=data, x_dim="label", xvalues=x_values, ignore_aes=line_ignore, **line_kwargs, ) # x-axis label xticks_kwargs = copy(visuals.get("xticks", {})) if xticks_kwargs is not False: _, _, xticks_ignore = filter_aes(plot_collection, aes_by_visuals, "xticks", new_sample_dims) xticks_kwargs.setdefault("rotation", 90) plot_collection.map( set_xticks, "xticks", labels=x_labels, values=x_values, ignore_aes=xticks_ignore, artist_dims={"labels": len(x_labels)}, **xticks_kwargs, ) return plot_collection