# Source code for arviz_plots.plot_matrix

# pylint: disable=too-many-lines, too-many-public-methods
"""Plot matrix class."""
from importlib import import_module

import numpy as np
import xarray as xr
from arviz_base import rcParams, xarray_sel_iter

from arviz_plots.plot_collection import PlotCollection, concat_model_dict, process_kwargs_subset


def subset_matrix_da(
    da, var_name_x, selection_x, var_name_y=None, selection_y=None, return_dataarray=False
):
    """Get a subset of a matrix-like DataArray.

    This function assumes that `da` has two dimensions with the same base coordinate names each,
    the only difference being the coords of one dimension have ``_x`` as suffix whereas the
    coords of the other have ``_y`` as suffix. Consequently, the two dimensions are referred to
    as the x dimension and the y dimension.

    Parameters
    ----------
    da : DataArray
        Matrix-like object to subset. It must have ``var_name_x`` and ``var_name_y``
        coordinates, which are turned into indexes in order to select on them.
    var_name_x : hashable
        Variable name along the x dimension.
    selection_x : mapping
        Mapping defining the coordinate subset along the x dimension.
    var_name_y : hashable, optional
        Variable name along the y dimension.
    selection_y : mapping, optional
        Mapping defining the coordinate subset along the y dimension.
    return_dataarray : bool, default False
        If true, return the subset as a DataArray. Otherwise, the output is a numpy array
        for multidimensional subsets and the stored object itself if the output is a scalar.

    Returns
    -------
    DataArray, ndarray or object
        The requested subset, in the form described in `return_dataarray`.

    Raises
    ------
    ValueError
        If any of the variable names or selections is still ``None`` once defaults have
        been applied.

    Notes
    -----
    The y arguments default to their x counterparts only when *both* `var_name_y` and
    `selection_y` are omitted; providing only one of the two is an error.
    Keys in the selections for which `da` doesn't have both the ``{key}_x`` and ``{key}_y``
    coordinates are ignored.
    """
    # the y side mirrors the x side only when it has been omitted completely
    if var_name_y is None and selection_y is None:
        var_name_y, selection_y = var_name_x, selection_x
    for required in (var_name_x, var_name_y, selection_x, selection_y):
        if required is None:
            raise ValueError("Invalid values for subset arguments")

    # variable names are non-index coordinates, they need an index before using ``.sel``
    indexed = da.set_xindex("var_name_x").set_xindex("var_name_y")
    subset = indexed.sel(var_name_x=var_name_x, var_name_y=var_name_y)

    for base_dim in set(selection_x).union(selection_y):
        x_coord = f"{base_dim}_x"
        y_coord = f"{base_dim}_y"
        # skip keys that aren't encoded as a pair of coordinates in the matrix
        if x_coord not in da.coords or y_coord not in da.coords:
            continue
        indexers = {}
        if base_dim in selection_x:
            indexers[x_coord] = selection_x[base_dim]
            subset = subset.set_xindex(x_coord)
        if base_dim in selection_y:
            indexers[y_coord] = selection_y[base_dim]
            subset = subset.set_xindex(y_coord)
        if indexers:
            subset = subset.sel(indexers)

    if return_dataarray:
        return subset
    return subset.item() if subset.size == 1 else subset.values

class PlotMatrix(PlotCollection):
    """Low level base class for pairwise matrix arrangements of plots.

    Attributes
    ----------
    viz : DataTree
    aes : DataTree

    See Also
    --------
    arviz_plots.PlotCollection : Unidimensional facetting manager
    """

    def __init__(self, data, facet_dims, aes=None, backend=None, figure_kwargs=None, **kwargs):
        """Initialize a PlotMatrix.

        Parameters
        ----------
        data : Dataset
            Data for which to generate the requested matrix layout of plots.
        facet_dims : list of hashable
            List of dimensions to use for facetting. It also supports the ``__variable__``
            indicator to facet across variables.
        aes : mapping of {str : list of hashable}, optional
            Dictionary with :term:`aesthetics` as keys and as values a list of the dimensions
            it should be mapped to.
            See :meth:`~arviz_plots.PlotMatrix.generate_aes_dt` for more details.
        backend : str, optional
            Plotting backend. It will be stored and passed down to the plotting
            functions when using methods like :meth:`~arviz_plots.PlotMatrix.map`.
            Defaults to ``rcParams["plot.backend"]``.
        figure_kwargs : mapping, optional
            Keyword arguments forwarded to the ``create_plotting_grid`` function
            of the plotting backend.
        **kwargs : mapping, optional
            Dictionary with :term:`aesthetics` as keys and as values a list of the values
            that should be taken by that aesthetic.
        """
        self._data = concat_model_dict(data)
        self._facet_dims = facet_dims
        # transient state, only set while ``map_row``/``map_col`` are running
        self._orientation = None
        self._fixed_var_name = None
        self._fixed_selection = None
        if backend is None:
            backend = rcParams["plot.backend"]
        # needs to be set before calling ``_generate_viz_dt`` which reads ``self.backend``
        self.backend = backend
        if figure_kwargs is None:
            figure_kwargs = {}
        super().__init__(
            data=self._data,
            viz_dt=self._generate_viz_dt(**figure_kwargs),
            aes=aes,
            backend=backend,
            **kwargs,
        )

    @property
    def facet_dims(self):
        """Facetting dimensions, as a set and without the ``__variable__`` indicator."""
        return {dim for dim in self._facet_dims if dim != "__variable__"}

    def _generate_viz_dt(self, **figure_kwargs):
        """Generate ``.viz`` DataTree.

        A square grid of plots is created with one row and one column per element
        obtained when iterating over the data along the facetting dimensions.
        The facetting information is stored twice as coordinates: with ``_x`` suffix
        along ``col_index`` and with ``_y`` suffix along ``row_index``.

        Parameters
        ----------
        **figure_kwargs
            Passed to the ``create_plotting_grid`` function of the plotting backend.

        Returns
        -------
        DataTree
        """
        data = self._data
        facet_dims = self._facet_dims
        pairs = tuple(
            xarray_sel_iter(data, skip_dims={dim for dim in data.dims if dim not in facet_dims})
        )
        n_pairs = len(pairs)
        n_plots = n_pairs**2
        plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots")
        fig, ax_ary = plot_bknd.create_plotting_grid(
            n_plots, n_pairs, n_pairs, squeeze=False, **figure_kwargs
        )
        coords = {
            "col_index": np.arange(n_pairs),
            "row_index": np.arange(n_pairs),
        }
        for dim in facet_dims:
            if dim == "__variable__":
                coord_name = "var_name"
                coord_values = [pair[0] for pair in pairs]
            else:
                coord_name = dim
                # variables without this dimension get ``None`` as coordinate value
                coord_values = [pair[1].get(dim, None) for pair in pairs]
            coords[f"{coord_name}_x"] = (("col_index",), coord_values)
            coords[f"{coord_name}_y"] = (("row_index",), coord_values)
        return xr.DataTree(
            xr.Dataset(
                {
                    "figure": np.array(fig, dtype=object),
                    "plot": (("row_index", "col_index"), ax_ary),
                },
                coords=coords,
            )
        )

    def _resolve_pair(self, var_name, selection, var_name_y, selection_y):
        """Return ``(var_name_x, selection_x, var_name_y, selection_y)`` given the orientation.

        While ``map_row`` or ``map_col`` are running, one side of the pair is fixed to the
        requested row or column and the provided `var_name`+`selection` is used for the
        other side. Otherwise the arguments are returned unchanged.
        """
        if self._orientation == "row":
            return var_name, selection, self._fixed_var_name, self._fixed_selection
        if self._orientation == "col":
            return self._fixed_var_name, self._fixed_selection, var_name, selection
        return var_name, selection, var_name_y, selection_y

    def get_target(self, var_name, selection, var_name_y=None, selection_y=None):
        """Get the target that corresponds to the given variable and selection.

        Parameters
        ----------
        var_name : hashable
            Variable name corresponding to the x dimension.
        selection : mapping
            Mapping with coordinate subset along the x dimension.
        var_name_y : hashable, optional
            Variable name corresponding to the y dimension.
            If not provided it will be assumed as being `var_name`
        selection_y : mapping, optional
            Mapping with coordinate subset along the y dimension.
            If not provided it will be assumed as being `selection`

        Returns
        -------
        object or ndarray
            The matching element(s) of ``viz["plot"]``.

        Notes
        -----
        While :meth:`~arviz_plots.PlotMatrix.map_row` or
        :meth:`~arviz_plots.PlotMatrix.map_col` are running, `var_name_y` and `selection_y`
        are ignored and the fixed row or column is used instead.
        """
        var_name, selection, var_name_y, selection_y = self._resolve_pair(
            var_name, selection, var_name_y, selection_y
        )
        return subset_matrix_da(
            self.viz["plot"],
            var_name_x=var_name,
            selection_x=selection,
            var_name_y=var_name_y,
            selection_y=selection_y,
        )

    def allocate_artist(
        self, fun_label, data, all_loop_dims, dim_to_idx=None, artist_dims=None, ignore_aes=None
    ):
        """Allocate a visual in the ``viz`` DataTree.

        Parameters
        ----------
        fun_label : hashable
            Name of the variable to create within ``viz``.
        data : Dataset
            Object from which the sizes and coordinate values of the dimensions
            with aesthetics mapped to them are taken.
        all_loop_dims : iterable of hashable
            Dimensions that are looped over, either for facetting or aesthetic mappings.
        dim_to_idx : optional
            Not supported, a truthy value raises an error.
        artist_dims : mapping of {hashable : int}, optional
            Extra dimensions (and their sizes) to append to the allocated array.
        ignore_aes : optional
            If provided, it is stored in the ``ignore_aes`` attribute of the allocated array.

        Raises
        ------
        ValueError
            If `dim_to_idx` is provided.
        """
        if artist_dims is None:
            artist_dims = {}
        if dim_to_idx:
            raise ValueError("dim_to_idx not supported yet for PlotMatrix")
        attrs = None
        if ignore_aes is not None:
            attrs = {"ignore_aes": ignore_aes}
        matrix_sizes = self.viz["plot"].sizes
        aes_dims = [dim for dim in data.dims if dim not in self.facet_dims and dim in all_loop_dims]
        artist_shape = (
            list(matrix_sizes.values())
            + [data.sizes[dim] for dim in aes_dims]
            + list(artist_dims.values())
        )
        self._viz_dt[fun_label] = xr.DataArray(
            np.full(artist_shape, None, dtype=object),
            dims=list(matrix_sizes) + aes_dims + list(artist_dims.keys()),
            coords={dim: data[dim] for dim in aes_dims},
            attrs=attrs,
        )

    def store_in_artist_da(self, aux_artist, fun_label, var_name, sel, var_name_y=None, sel_y=None):
        """Store visual object or array into its preallocated DataArray within ``viz``.

        Parameters
        ----------
        aux_artist
            The plotting backend class representing a visual to be stored
            or an array-like of such objects.
        fun_label : hashable
            The identifier of the visual within the ``PlotMatrix``. It should be one of the
            values for which :meth:`~arviz_plots.PlotMatrix.allocate_artist`
            has already been called.
        var_name : hashable
            Variable name corresponding to the x dimension.
        sel : mapping
            Mapping with coordinate subset along the x dimension.
        var_name_y : hashable, optional
            Variable name corresponding to the y dimension.
            If not provided it will be assumed as being `var_name`
        sel_y : mapping, optional
            Mapping with coordinate subset along the y dimension.
            If not provided it will be assumed as being `sel`

        Notes
        -----
        The same orientation logic as in :meth:`~arviz_plots.PlotMatrix.get_target` is
        applied, so visuals generated through ``map_row`` or ``map_col`` are stored in the
        position of the plot they were drawn on.
        """
        var_name, sel, var_name_y, sel_y = self._resolve_pair(var_name, sel, var_name_y, sel_y)
        plot_da = subset_matrix_da(
            self.viz["plot"],
            var_name_x=var_name,
            selection_x=sel,
            var_name_y=var_name_y,
            selection_y=sel_y,
            return_dataarray=True,
        )
        # NOTE(review): only the row/col position is used as indexer; if the allocated array
        # has extra aesthetic dimensions the value is presumably assigned along all of
        # them. Confirm this is the intended behaviour.
        self._viz_dt[fun_label].loc[
            {"row_index": plot_da["row_index"], "col_index": plot_da["col_index"]}
        ] = aux_artist

    def map_upper(self, *args, **kwargs):
        """Call :meth:`~arviz_plots.PlotMatrix.map_triangle` with ``triangle="upper"``."""
        self.map_triangle(*args, triangle="upper", **kwargs)

    def map_lower(self, *args, **kwargs):
        """Call :meth:`~arviz_plots.PlotMatrix.map_triangle` with ``triangle="lower"``."""
        self.map_triangle(*args, triangle="lower", **kwargs)

    def map_triangle(
        self,
        fun,
        fun_label=None,
        *,
        data=None,
        loop_data=None,
        triangle="both",
        coords=None,
        ignore_aes=frozenset(),
        subset_info=False,
        store_artist=True,
        artist_dims=None,
        **kwargs,
    ):
        """Apply the given plotting function to all plots with the corresponding aesthetics.

        Parameters
        ----------
        fun : callable
            Function with signature ``fun(da_x, da_y, target, **fun_kwargs)`` which should
            be called for all couples of data pairs (each couple encoded in a :term:`plot`)
            and corresponding :term:`aesthetic`.
            The object returned by `fun` is assumed to be a scalar unless
            `artist_dims` are provided. There is also the option of adding
            extra keyword arguments with the `subset_info` flag.
        fun_label : str, optional
            Function identifier. It will be used as variable name to store the object
            returned by `fun`. Defaults to ``fun.__name__``.
        data : Dataset, optional
            Data to be subsetted into pair elements then looped over to cover all couple
            combinations. Defaults to the data used to initialize the ``PlotMatrix``.
        loop_data : Dataset or str, optional
            Object used to generate the elements to loop over. Defaults to `data`.
            If ``"plots"``, a Dataset built from the ``plot`` variable in ``viz`` is used.
            TODO: see if it works and if we want to keep it.
        triangle : {"both", "lower", "upper"}, default "both"
            Off-diagonal plots on which `fun` should be called. The diagonal is never
            covered by this method.
        coords : mapping, optional
            Dictionary of {coordinate names : coordinate values} that should
            be used to subset the aes, data and viz objects before any faceting
            or aesthetics mapping is applied.
        ignore_aes : set, optional
            Set of aesthetics present in ``aes`` that should be ignored for this
            ``map`` call.
        subset_info : boolean, default False
            Add the subset info from :func:`arviz_base.xarray_sel_iter` for the
            ``da_x``+``da_y`` couple to the keyword arguments passed to `fun`.
            If true, then `fun` must accept the keyword arguments ``var_name_x``, ``sel_x``,
            ``isel_x``, ``var_name_y``, ``sel_y`` and ``isel_y``.
            Moreover, if those were to be keys present in `**kwargs` their
            values in `**kwargs` would be ignored.
        store_artist : boolean, default True
            Whether to allocate a variable named `fun_label` in ``viz`` and store in it
            the objects returned by `fun`.
        artist_dims : mapping of {hashable : int}, optional
            Dictionary of sizes for proper allocation and storage when using
            ``map`` with functions that return an array of :term:`visual`.
        **kwargs
            Extra keyword arguments to be passed to `fun`.

        Raises
        ------
        ValueError
            If `triangle` is not one of the valid options or if the method is called
            while ``map_row`` or ``map_col`` are running.
        TypeError
            If `data` is not a Dataset.

        See Also
        --------
        arviz_plots.PlotMatrix.map
        arviz_plots.PlotMatrix.map_row
        arviz_plots.PlotMatrix.map_col
        """
        # validate everything before doing any work or modifying ``viz``
        if triangle not in {"lower", "upper", "both"}:
            raise ValueError(
                "Invalid value for `triangle` options are 'lower', 'upper' or 'both' "
                f"but got {triangle}"
            )
        if self._orientation is not None:
            raise ValueError(f"Orientation is set to {self._orientation}, it should be None")
        if coords is None:
            coords = {}
        if fun_label is None:
            fun_label = fun.__name__

        data = self.data if data is None else data
        if not isinstance(data, xr.Dataset):
            raise TypeError("data argument must be an xarray.Dataset")
        if isinstance(loop_data, str) and loop_data == "plots":
            if "plot" in self.viz.data_vars:
                loop_data = xr.Dataset({key: self.viz.ds["plot"] for key in data.data_vars})
            else:
                loop_data = xr.Dataset(
                    {var_name: ds["plot"] for var_name, ds in self.viz.children.items()}
                )
        loop_data = data if loop_data is None else loop_data

        facet_dims = self.facet_dims
        aes, all_loop_dims = self.update_aes(ignore_aes, coords)
        aes_dims = [dim for dim in all_loop_dims if dim not in facet_dims]
        # all variables must have all dimensions with aesthetics mapped to them
        # we only care about the dim+coord combinations
        aes_loopers = list(
            xarray_sel_iter(
                loop_data[list(loop_data.data_vars)[0]],
                skip_dims={dim for dim in loop_data.dims if dim not in aes_dims},
            )
        )
        plotters = list(
            xarray_sel_iter(
                loop_data, skip_dims={dim for dim in loop_data.dims if dim not in facet_dims}
            )
        )
        if store_artist:
            self.allocate_artist(
                fun_label=fun_label,
                data=loop_data,
                all_loop_dims=all_loop_dims,
                artist_dims=artist_dims,
                ignore_aes=ignore_aes,
            )

        for i, (var_name_x, sel_x_base, isel_x_base) in enumerate(plotters):
            # x indexes columns and y indexes rows, so elements that come before
            # the current one are above the diagonal and the ones after are below
            upper_elements = plotters[:i]
            lower_elements = plotters[i + 1 :]
            if triangle == "lower":
                second_loop_elements = lower_elements
            elif triangle == "upper":
                second_loop_elements = upper_elements
            else:
                second_loop_elements = lower_elements + upper_elements
            if not second_loop_elements:
                continue
            # the x subset doesn't depend on the y element, compute it once per column
            da_x_base = data[var_name_x].sel(sel_x_base)
            for var_name_y, sel_y_base, isel_y_base in second_loop_elements:
                da_y_base = data[var_name_y].sel(sel_y_base)
                for _, aes_sel, aes_isel in aes_loopers:
                    da_x = da_x_base.sel(aes_sel)
                    da_y = da_y_base.sel(aes_sel)
                    try:
                        if np.all(np.isnan(da_x)) or np.all(np.isnan(da_y)):
                            continue
                    except TypeError:
                        # ``isnan`` isn't defined for this dtype, nothing to skip
                        pass
                    sel_x = {**sel_x_base, **aes_sel}
                    sel_y = {**sel_y_base, **aes_sel}
                    isel_x = {**isel_x_base, **aes_isel}
                    isel_y = {**isel_y_base, **aes_isel}
                    sel_x_plus = {**sel_x, **coords}
                    sel_y_plus = {**sel_y, **coords}
                    target = self.get_target(var_name_x, sel_x_plus, var_name_y, sel_y_plus)

                    aes_kwargs = self.get_aes_kwargs(aes, var_name_x, aes_sel)
                    fun_kwargs = {
                        **aes_kwargs,
                        **{
                            key: process_kwargs_subset(values, var_name_x, aes_sel)
                            for key, values in kwargs.items()
                        },
                    }
                    if subset_info:
                        fun_kwargs = {
                            **fun_kwargs,
                            "var_name_x": var_name_x,
                            "sel_x": sel_x,
                            "isel_x": isel_x,
                            "var_name_y": var_name_y,
                            "sel_y": sel_y,
                            "isel_y": isel_y,
                        }

                    aux_artist = fun(da_x, da_y, target=target, **fun_kwargs)
                    if store_artist:
                        if np.size(aux_artist) == 1:
                            aux_artist = np.squeeze(aux_artist)
                        self.store_in_artist_da(
                            aux_artist,
                            fun_label,
                            var_name_x,
                            sel_x,
                            var_name_y=var_name_y,
                            sel_y=sel_y,
                        )

    def map(
        self,
        fun,
        fun_label=None,
        *,
        data=None,
        coords=None,
        ignore_aes=frozenset(),
        subset_info=False,
        store_artist=True,
        artist_dims=None,
        **kwargs,
    ):
        """Apply the given plotting function along the diagonal with the corresponding aesthetics.

        Parameters
        ----------
        fun : callable
            Function with signature ``fun(da, target, **fun_kwargs)`` which should
            be applied for all combinations of :term:`plot` and :term:`aesthetic`.
            The object returned by `fun` is assumed to be a scalar unless
            `artist_dims` are provided. There is also the option of adding
            extra required keyword arguments with the `subset_info` flag.
        fun_label : str, optional
            Variable name with which to store the object returned by `fun`.
            Defaults to ``fun.__name__``.
        data : Dataset, optional
            Data to be subsetted at each iteration and to pass to `fun` as first positional
            argument. Defaults to the data used to initialize the ``PlotMatrix``.
        coords : mapping, optional
            Dictionary of {coordinate names : coordinate values} that should
            be used to subset the aes, data and viz objects before any faceting
            or aesthetics mapping is applied.
        ignore_aes : set, optional
            Set of aesthetics present in ``aes`` that should be ignored for this
            ``map`` call.
        subset_info : boolean, default False
            Add the subset info from :func:`arviz_base.xarray_sel_iter` to
            the keyword arguments passed to `fun`. If true, then `fun` must
            accept the keyword arguments ``var_name``, ``sel`` and ``isel``.
            Moreover, if those were to be keys present in `**kwargs` their
            values in `**kwargs` would be ignored.
        store_artist : boolean, default True
        artist_dims : mapping of {hashable : int}, optional
            Dictionary of sizes for proper allocation and storage when using
            ``map`` with functions that return an array of :term:`visual`.
        **kwargs
            Keyword arguments passed as is to `fun`. Values within `**kwargs`
            with :class:`~xarray.DataArray` or :class:`~xarray.Dataset` type
            will be subsetted on the current selection (if possible) before calling `fun`.
            Slicing with dims and coords is applied to the relevant subset present in the
            xarray object so dimensions with mapped aesthetics not being present is not an
            issue. However, using Datasets that don't contain all the variable names in
            `data` will raise an error.

        See Also
        --------
        arviz_plots.PlotMatrix.map_triangle
        arviz_plots.PlotMatrix.map_row
        arviz_plots.PlotMatrix.map_col
        """
        super().map(
            fun=fun,
            fun_label=fun_label,
            data=data,
            coords=coords,
            ignore_aes=ignore_aes,
            subset_info=subset_info,
            store_artist=store_artist,
            artist_dims=artist_dims,
            **kwargs,
        )

    def set_fixed_var_attributes(self, index, orientation="row"):
        """Set fixed variable attributes for the current orientation according to given index.

        Parameters
        ----------
        index : int
            Position of the row or column to fix.
        orientation : {"row", "col"}, default "row"
            Whether `index` refers to a row or to a column.

        Raises
        ------
        ValueError
            If `orientation` is not ``"row"`` or ``"col"``.
        """
        if orientation not in {"row", "col"}:
            raise ValueError(
                f"Invalid value for `orientation`, it should be 'row' or 'col' but got {orientation}"
            )
        remove_list = ["col_index", "row_index", "var_name_x", "var_name_y"]
        # the ``_x`` and ``_y`` coordinates are generated from the same values, so the
        # ``_y`` ones can be used to get the variable and selection for both orientations
        fixed_line = self.viz.var_name_y[index]
        fixed_line_sel = {}
        for key, value in fixed_line.coords.items():
            if key in remove_list:
                continue
            coord_value = value.item()
            # ``None`` indicates the variable doesn't have that dimension
            if coord_value is not None:
                # remove the ``_y`` suffix to recover the dimension name
                fixed_line_sel[key[:-2]] = coord_value
        self._fixed_var_name = fixed_line.values.item()
        self._fixed_selection = fixed_line_sel
        self._orientation = orientation

    def _map_fixed_line(self, orientation, index, map_kwargs):
        """Run ``PlotCollection.map`` with one row or column fixed, then clear that state.

        `map_kwargs` is passed as a dictionary so user provided keyword arguments
        for the plotting function can't clash with the arguments of this method.
        """
        try:
            self.set_fixed_var_attributes(index, orientation)
            super().map(**map_kwargs)
        finally:
            # always reset, otherwise an error while plotting would leave the instance
            # with the orientation set, which breaks later ``map`` and ``map_triangle`` calls
            self._orientation = None
            self._fixed_selection = None
            self._fixed_var_name = None

    def map_row(
        self,
        fun,
        fun_label=None,
        index=0,
        *,
        data=None,
        coords=None,
        ignore_aes="all",
        subset_info=False,
        store_artist=True,
        artist_dims=None,
        **kwargs,
    ):
        """Apply the given plotting function along the row with the corresponding aesthetics.

        Parameters
        ----------
        fun : callable
            Function with signature ``fun(da, target, **fun_kwargs)`` which should
            be applied for all combinations of :term:`plot` and :term:`aesthetic`.
            The object returned by `fun` is assumed to be a scalar unless
            `artist_dims` are provided. There is also the option of adding
            extra required keyword arguments with the `subset_info` flag.
        fun_label : str, optional
            Variable name with which to store the object returned by `fun`.
            Defaults to ``fun.__name__``.
        index : int, default 0
            Index of the row to be mapped by the given plotting function.
        data : Dataset, optional
            Data to be subsetted at each iteration and to pass to `fun` as first positional
            argument. Defaults to the data used to initialize the ``PlotMatrix``.
        coords : mapping, optional
            Dictionary of {coordinate names : coordinate values} that should
            be used to subset the aes, data and viz objects before any faceting
            or aesthetics mapping is applied.
        ignore_aes : str or set of str, default "all"
            Set of aesthetics present in ``aes`` that should be ignored for this
            ``map`` call.
        subset_info : boolean, default False
            Add the subset info from :func:`arviz_base.xarray_sel_iter` to
            the keyword arguments passed to `fun`. If true, then `fun` must
            accept the keyword arguments ``var_name``, ``sel`` and ``isel``.
            Moreover, if those were to be keys present in `**kwargs` their
            values in `**kwargs` would be ignored.
        store_artist : boolean, default True
        artist_dims : mapping of {hashable : int}, optional
            Dictionary of sizes for proper allocation and storage when using
            ``map`` with functions that return an array of :term:`visual`.
        **kwargs
            Keyword arguments passed as is to `fun`. Values within `**kwargs`
            with :class:`~xarray.DataArray` or :class:`~xarray.Dataset` type
            will be subsetted on the current selection (if possible) before calling `fun`.
            Slicing with dims and coords is applied to the relevant subset present in the
            xarray object so dimensions with mapped aesthetics not being present is not an
            issue. However, using Datasets that don't contain all the variable names in
            `data` will raise an error.

        See Also
        --------
        arviz_plots.PlotMatrix.map_col
        arviz_plots.PlotMatrix.map
        arviz_plots.PlotMatrix.map_triangle
        """
        self._map_fixed_line(
            "row",
            index,
            {
                "fun": fun,
                "fun_label": fun_label,
                "data": data,
                "coords": coords,
                "ignore_aes": ignore_aes,
                "subset_info": subset_info,
                "store_artist": store_artist,
                "artist_dims": artist_dims,
                **kwargs,
            },
        )

    def map_col(
        self,
        fun,
        fun_label=None,
        index=0,
        *,
        data=None,
        coords=None,
        ignore_aes="all",
        subset_info=False,
        store_artist=True,
        artist_dims=None,
        **kwargs,
    ):
        """Apply the given plotting function along the column with the corresponding aesthetics.

        Parameters
        ----------
        fun : callable
            Function with signature ``fun(da, target, **fun_kwargs)`` which should
            be applied for all combinations of :term:`plot` and :term:`aesthetic`.
            The object returned by `fun` is assumed to be a scalar unless
            `artist_dims` are provided. There is also the option of adding
            extra required keyword arguments with the `subset_info` flag.
        fun_label : str, optional
            Variable name with which to store the object returned by `fun`.
            Defaults to ``fun.__name__``.
        index : int, default 0
            Index of the column to be mapped by the given plotting function.
        data : Dataset, optional
            Data to be subsetted at each iteration and to pass to `fun` as first positional
            argument. Defaults to the data used to initialize the ``PlotMatrix``.
        coords : mapping, optional
            Dictionary of {coordinate names : coordinate values} that should
            be used to subset the aes, data and viz objects before any faceting
            or aesthetics mapping is applied.
        ignore_aes : str or set of str, default "all"
            Set of aesthetics present in ``aes`` that should be ignored for this
            ``map`` call.
        subset_info : boolean, default False
            Add the subset info from :func:`arviz_base.xarray_sel_iter` to
            the keyword arguments passed to `fun`. If true, then `fun` must
            accept the keyword arguments ``var_name``, ``sel`` and ``isel``.
            Moreover, if those were to be keys present in `**kwargs` their
            values in `**kwargs` would be ignored.
        store_artist : boolean, default True
        artist_dims : mapping of {hashable : int}, optional
            Dictionary of sizes for proper allocation and storage when using
            ``map`` with functions that return an array of :term:`visual`.
        **kwargs
            Keyword arguments passed as is to `fun`. Values within `**kwargs`
            with :class:`~xarray.DataArray` or :class:`~xarray.Dataset` type
            will be subsetted on the current selection (if possible) before calling `fun`.
            Slicing with dims and coords is applied to the relevant subset present in the
            xarray object so dimensions with mapped aesthetics not being present is not an
            issue. However, using Datasets that don't contain all the variable names in
            `data` will raise an error.

        See Also
        --------
        arviz_plots.PlotMatrix.map_row
        arviz_plots.PlotMatrix.map
        arviz_plots.PlotMatrix.map_triangle
        """
        self._map_fixed_line(
            "col",
            index,
            {
                "fun": fun,
                "fun_label": fun_label,
                "data": data,
                "coords": coords,
                "ignore_aes": ignore_aes,
                "subset_info": subset_info,
                "store_artist": store_artist,
                "artist_dims": artist_dims,
                **kwargs,
            },
        )

    @property
    def viz(self):
        """Information about the visual elements in the plot as a DataTree.

        The DataTree only has variables in the root group.
        With all variables having the same dimensions: ``(row_index, col_index)``.
        The information about facetting is encoded in the coordinate values;
        ``row_index`` has all relevant coordinates to indicate the subset with ``_y`` suffix,
        ``col_index`` has coordinates with the ``_x`` suffix.

        The `viz` DataTree always contains the following variables:

        * ``figure`` (always on the home group) -> Scalar object containing the highest level
          plotting structure. i.e. the matplotlib figure or the bokeh layout
        * ``plot`` -> :term:`Plot` objects in this :term:`figure`.
          Generally, these are the target where :term:`visuals <visual>` are added,
          although it is possible to have visuals targetting the figure itself.

        Plus all the visuals that have been added to the plot and stored.
        See :meth:`arviz_plots.PlotMatrix.map` and :meth:`arviz_plots.PlotMatrix.map_triangle`
        for more details.

        Raises
        ------
        ValueError
            If the ``coords`` attribute is set.
        """
        if self.coords is None:
            return self._viz_dt
        raise ValueError("viz attribute can't be accessed with coords set")