# pylint: disable=too-many-lines, too-many-public-methods
"""Plot collection class."""
import warnings
from importlib import import_module
from pathlib import Path
import numpy as np
import xarray as xr
from arviz_base import rcParams, xarray_sel_iter
from arviz_base.labels import BaseLabeller
def backend_from_object(obj, return_module=True):
"""Get the backend string or module that corresponds to a given object.
Parameters
----------
obj
The object to get its corresponding backend for.
return_module : bool, default True
Return the module from ``arviz_plots.backend` after importing it
Returns
-------
backend : module or str
"""
# cover none backend first, the figure object is a dictionary,
# and the plot objects are lists
if isinstance(obj, list | dict):
backend = "none"
else:
lib, *_, leaf = obj.__module__.split(".")
# for plotly, the target will actually be an arviz_plots.backend.plotly.PlotlyPlot
if lib == "arviz_plots":
backend = leaf
else:
backend = lib
if return_module:
return import_module(f"arviz_plots.backend.{backend}")
return backend
def concat_model_dict(data):
    """Concatenate a dictionary of Datasets along a new ``model`` dimension.

    Anything that is not a dictionary is returned untouched. For dictionaries,
    every value must be a :class:`xarray.Dataset`; the values are concatenated
    along ``model`` and the dictionary keys become the ``model`` coordinate.

    Raises
    ------
    TypeError
        If `data` is a dictionary and any of its values is not a Dataset.
    """
    if not isinstance(data, dict):
        return data
    datasets = data.values()
    for candidate in datasets:
        if not isinstance(candidate, xr.Dataset):
            raise TypeError("Provided data must be a Dataset or dictionary of Datasets")
    model_names = list(data)
    stacked = xr.concat(datasets, dim="model")
    return stacked.assign_coords(model=model_names)
def sel_subset(sel, ds_da):
    """Restrict a ``{name: value}`` selection to entries usable with ``.sel`` on `ds_da`.

    Entries whose key is a dimension of `ds_da` are always kept. Entries whose
    key is instead a one-dimensional coordinate of `ds_da` are kept only if the
    dimension they lie along is not already targeted by a kept entry, so
    indexers named like the dimension take preference and each dimension is
    indexed at most once. Keys that are neither dimensions nor coordinates of
    `ds_da` are dropped.

    Parameters
    ----------
    sel : mapping
        Candidate indexers.
    ds_da : Dataset or DataArray
        Object the indexers are meant for.

    Returns
    -------
    dict
    """
    kept = {}
    for name, target in sel.items():
        if name in ds_da.dims:
            kept[name] = target
    # dimensions that already have an indexer assigned
    covered_dims = list(kept)
    for name, target in sel.items():
        if name in kept or name not in ds_da.coords:
            continue
        coord = ds_da[name]
        if coord.ndim != 1:
            continue
        along = coord.dims[0]
        if along in covered_dims:
            continue
        kept[name] = target
        covered_dims.append(along)
    return kept
def subset_ds(ds, var_name, sel):
    """Extract a variable from `ds`, subset it and return plain Python/NumPy data.

    The variable `var_name` is taken from `ds` (it must exist), the part of
    `sel` that is applicable to it is computed with :func:`sel_subset` and
    applied with ``.sel``. Coordinates used as indexers that are neither
    dimensions nor already indexed get an index through ``set_xindex`` first.
    The result is not an xarray object: a single-element result is returned
    via ``.item()`` and anything else via ``.values``, which is convenient
    to retrieve backend objects such as matplotlib axes or bokeh figures.

    Parameters
    ----------
    ds : Dataset
    var_name : hashable
    sel : mapping

    Returns
    -------
    scalar or numpy.ndarray
    """
    target = ds[var_name]
    if isinstance(target, xr.DataTree):
        target = target.dataset
    indexers = sel_subset(sel, target)
    if indexers:
        for name in indexers:
            needs_index = name not in target.dims and name not in target.xindexes
            if needs_index:
                target = target.set_xindex(name)
        target = target.sel(indexers)
    return target.item() if target.size == 1 else target.values
def try_da_subset(da, sel):
    """Try subsetting a dataarray with `.sel`.

    The indexers actually used are the ones returned by :func:`sel_subset`.
    There are 3 possible cases:

    * None of the keys in `sel` are applicable to `da` -> `da` is returned as is
    * Some (or all) of the keys in `sel` are applicable to `da`:

      - `.sel` on the applicable subset works -> return `da` subset
      - `.sel` raises a KeyError -> return ``None``

    Parameters
    ----------
    da : DataArray
    sel : mapping

    Returns
    -------
    DataArray or None
    """
    subset_dict = sel_subset(sel, da)
    if not subset_dict:
        return da
    for key in subset_dict:
        # Only coordinate variables can be promoted to an index. A key that is
        # a dimension without coordinate values is skipped here (as done in
        # `subset_ds`) instead of being passed to ``set_xindex``.
        if key in da.coords and key not in da.xindexes:
            da = da.set_xindex(key)
    try:
        return da.sel(subset_dict)
    except KeyError:
        return None
def process_kwargs_subset(value, var_name, sel):
    """Subset a keyword argument value when it is an xarray object.

    * Dataset without a `var_name` variable: the applicable part of `sel`
      (see :func:`sel_subset`) is applied to the whole Dataset; ``None`` is
      returned if ``.sel`` raises a KeyError, and the Dataset itself if no
      indexer is applicable.
    * Dataset with a `var_name` variable: that variable is extracted and
      handled as a DataArray.
    * DataArray: delegated to :func:`try_da_subset`.
    * Anything else is returned as is.
    """
    if isinstance(value, xr.Dataset) and var_name not in value.data_vars:
        indexers = sel_subset(sel, value)
        if not indexers:
            return value
        try:
            return value.sel(indexers)
        except KeyError:
            return None
    if isinstance(value, xr.Dataset):
        value = value[var_name]
    if isinstance(value, xr.DataArray):
        return try_da_subset(value, sel)
    return value
def process_facet_dims(data, facet_dims):
    """Count how many plots are required by the given faceting dimensions.

    The ``__variable__`` "special dimension name" is taken into account: when
    it is present, the number of facets is computed per variable (using only
    the faceting dimensions each variable has) and then added up.

    Parameters
    ----------
    data : Dataset
    facet_dims : iterable of hashable
        Faceting dimensions. If empty or ``None`` a single plot is needed.

    Returns
    -------
    n_facets : int
        Total number of plots.
    facets_per_var : dict
        Number of plots per variable. Empty unless ``__variable__`` is one
        of the faceting dimensions.

    Raises
    ------
    ValueError
        If ``__variable__`` is not used and some variable lacks any of the
        faceting dimensions.
    """
    if not facet_dims:
        return 1, {}

    def names_in(da):
        # both dimensions and coordinates are valid faceting targets
        return set(da.dims).union(da.coords)

    facets_per_var = {}
    if "__variable__" in facet_dims:
        for var_name, da in data.items():
            available = names_in(da)
            lengths = [len(np.unique(da[dim])) for dim in facet_dims if dim in available]
            facets_per_var[var_name] = np.prod(lengths) if lengths else 1
        return np.sum(list(facets_per_var.values())), facets_per_var

    missing_dims = {}
    for var_name, da in data.items():
        available = names_in(da)
        absent = [dim for dim in facet_dims if dim not in available]
        if absent:
            missing_dims[var_name] = absent
    if missing_dims:
        raise ValueError(
            "All variables must have all faceting dimensions, but found the following "
            f"dims to be missing in these variables: {missing_dims}"
        )
    n_facets = np.prod([len(np.unique(data[dim])) for dim in facet_dims])
    return n_facets, facets_per_var
def leaf_dataset(dt, leaf_name):
    """Collect the `leaf_name` entry of every child of `dt` into a Dataset.

    Parameters
    ----------
    dt : DataTree
    leaf_name : hashable

    Returns
    -------
    Dataset
        One variable per child of `dt`, named after the child and holding
        ``child[leaf_name]``.
    """
    leaves = {}
    for var_name, node in dt.children.items():
        leaves[var_name] = node[leaf_name]
    return xr.Dataset(leaves)
[docs]
class PlotCollection:
"""Low level base class for plotting with xarray Datasets.
This class instantiates a figure with multiple plots in it and provides methods to loop
over these plots and the provided data syncing each plot and data subset to
user given aesthetics.
Attributes
----------
viz : DataTree
aes : DataTree
See Also
--------
arviz_plots.PlotMatrix : Pairwise facetting manager
"""
[docs]
def __init__(self, data, viz_dt, aes_dt=None, aes=None, backend=None, **kwargs):
    """Initialize a PlotCollection.

    It is not recommended to initialize ``PlotCollection`` objects directly.
    Use its classmethods :meth:`~arviz_plots.PlotCollection.wrap` and
    :meth:`~arviz_plots.PlotCollection.grid` instead.

    Parameters
    ----------
    data : Dataset
        The data from which `viz_dt` was generated and
        from which to generate the aesthetic mappings.
    viz_dt : DataTree
        DataTree object with which to populate the ``viz`` attribute.
    aes_dt : DataTree, optional
        DataTree object with which to populate the ``aes`` attribute.
        If given, the `aes` argument and all `**kwargs` are ignored.
    aes : mapping of {str : list of hashable}, optional
        Dictionary with :term:`aesthetics` as keys and as values a list
        of the dimensions it should be mapped to.
        See :meth:`~arviz_plots.PlotCollection.generate_aes_dt` for more details.
    backend : str, optional
        Plotting backend. It will be stored and passed down to the plotting
        functions when using methods like :meth:`~arviz_plots.PlotCollection.map`.
        If omitted and `viz_dt` has a "figure" variable, the backend is
        inferred from the figure object.
    **kwargs : mapping, optional
        Dictionary with :term:`aesthetics` as keys and as values a list
        of the values that should be taken by that aesthetic.

    See Also
    --------
    arviz_plots.PlotCollection.grid, arviz_plots.PlotCollection.wrap
    """
    self._data = data
    self._coords = None
    self._viz_dt = viz_dt
    if backend is not None:
        self.backend = backend
    elif "figure" in viz_dt:
        # Read the figure from the `viz_dt` argument itself: the instance only
        # stores it as `_viz_dt`, so `self.viz_dt` is not an attribute set here.
        self.backend = backend_from_object(viz_dt["figure"].item(), return_module=False)
    # If neither branch ran, `backend` stays unset; `generate_aes_dt` checks
    # for that with `hasattr` and falls back to the "none" backend.
    if aes_dt is None:
        if aes is None:
            aes = {}
        self._aes_dt = self.generate_aes_dt(aes, data, **kwargs)
    else:
        self._aes_dt = aes_dt
@property
def aes(self):
    """Information about :term:`aesthetic mapping` as a DataTree.

    There is one group per aesthetic. For aesthetics where the variable is
    used to encode information (that is, "__variable__" was used), the group
    has one variable per data variable, so the subset
    ``ds[var_name].sel(**kwargs)`` of the input dataset is associated to
    ``aes[aes_key][var_name].sel(**kwargs)``.

    For aesthetic mappings that only use dimensions, the group has a variable
    "mapping" with shape inherited from the mapped dimensions in the original
    data, and might also have a "neutral_element" scalar variable.

    When :attr:`~.PlotCollection.coords` is set, a new DataTree is returned
    in which every group has been subset with the applicable coordinates.

    The docstring for :meth:`arviz_plots.PlotCollection.generate_aes_dt`
    has examples and covers the "neutral element" concept in more detail.

    See Also
    --------
    .PlotCollection.generate_aes_dt
    .PlotCollection.get_aes_kwargs
    """
    active_coords = self.coords
    if active_coords is None:
        return self._aes_dt
    sliced_groups = {}
    for group, node in self._aes_dt.children.items():
        indexers = sel_subset(active_coords, node)
        sliced_groups[group] = node.to_dataset().sel(indexers)
    return xr.DataTree.from_dict(sliced_groups)

@aes.setter
def aes(self, value):
    # with `coords` set the getter returns a sliced copy, so assignment is refused
    if self.coords is None:
        self._aes_dt = value
        return
    raise ValueError("Can't modify `aes` DataTree while `coords` is set")
@property
def viz(self):
    """Information about the visual elements in the plot as a DataTree.

    Plot elements like :term:`visuals`, :term:`plots` and the :term:`figure`
    are stored at the top level, if possible directly as DataArrays,
    otherwise as groups whose variables are variable names in the input
    Dataset.

    The `viz` DataTree generated by :meth:`~arviz_plots.PlotCollection.wrap`
    and :meth:`~arviz_plots.PlotCollection.grid` contains:

    * ``figure`` (always on the home group) -> Scalar object containing the highest level
      plotting structure. i.e. the matplotlib figure or the bokeh layout
    * ``plot`` -> :term:`Plot` objects in this :term:`figure`.
      Generally, these are the target where :term:`visuals <visual>` are added,
      although it is possible to have visuals targeting the figure itself.
    * ``row_index`` -> Integer row indicator
    * ``col_index`` -> Integer column indicator

    When :attr:`~.PlotCollection.coords` is set, a new DataTree is returned
    in which the root and every group have been subset with the applicable
    coordinates.

    See :meth:`arviz_plots.PlotCollection.map` for more details.
    """
    active_coords = self.coords
    if active_coords is None:
        return self._viz_dt
    # TODO: use .loc on DataTree directly (once available), otherwise, changes to
    # .viz aren't stored in the PlotCollection class, same in `aes`
    sliced = {}
    for group, node in self._viz_dt.children.items():
        sliced[group] = node.to_dataset().sel(sel_subset(active_coords, node))
    root_ds = self._viz_dt.to_dataset()
    sliced["/"] = root_ds.sel(sel_subset(active_coords, root_ds))
    return xr.DataTree.from_dict(sliced)

@viz.setter
def viz(self, value):
    # with `coords` set the getter returns a sliced copy, so assignment is refused
    if self.coords is None:
        self._viz_dt = value
        return
    raise ValueError("Can't modify `viz` DataTree while `coords` is set")
@property
def coords(self):
    """Slicing operation that is always applied on the PlotCollection.

    It plays the same role as the ``coords`` argument in
    :meth:`~.PlotCollection.map`, but these coordinates are taken into account
    in every interaction with the `PlotCollection`, including accesses to
    :attr:`~.PlotCollection.viz` or :attr:`~.PlotCollection.aes`.
    ``None`` means no slicing.
    """
    stored_coords = self._coords
    return stored_coords

@coords.setter
def coords(self, value):
    # stored as given, no validation is performed
    self._coords = value
@property
def data(self):
    """Dataset to be used as data for plotting.

    On assignment, a dictionary of Datasets is first concatenated along
    a ``model`` dimension through :func:`concat_model_dict`.
    """
    stored_data = self._data
    return stored_data

@data.setter
def data(self, value):
    # might want/be possible to make some checks on the data before setting it
    merged = concat_model_dict(value)
    self._data = merged
@property
def aes_set(self):
    """Set with the names of all aesthetics that have a mapping defined.

    These are the names of the groups in :attr:`~.PlotCollection.aes`.
    """
    aes_groups = self.aes.children
    return set(aes_groups)
[docs]
def show(self):
    """Call the backend function to show this :term:`figure`.

    If the stored figure is ``None``, the single stored :term:`plot`
    is shown instead.

    Raises
    ------
    ValueError
        If there is no "figure" in :attr:`~.PlotCollection.viz`.
    """
    if "figure" not in self.viz:
        raise ValueError("No plot found to be shown")
    plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots")
    figure = self.viz["figure"].item()
    if figure is None:
        # The plot used to be retrieved here and then discarded, which showed
        # nothing; hand it to the backend instead.
        # NOTE(review): assumes backend ``show`` accepts a plot object - confirm.
        figure = self.viz["plot"].item()
    plot_bknd.show(figure)
[docs]
def savefig(self, filename, **kwargs):
    """Save this :term:`figure` through the backend ``savefig`` function.

    Parameters
    ----------
    filename : str or `~pathlib.Path`
        Converted to :class:`~pathlib.Path` before being handed to the backend.
    **kwargs
        Passed as is to the respective backend function

    Raises
    ------
    ValueError
        If there is no "figure" in :attr:`~.PlotCollection.viz`.
    """
    has_figure = "figure" in self.viz
    if not has_figure:
        raise ValueError("No plot found to be saved")
    backend_module = import_module(f".backend.{self.backend}", package="arviz_plots")
    figure = self.viz["figure"].item()
    destination = Path(filename)
    backend_module.savefig(figure, destination, **kwargs)
[docs]
def generate_aes_dt(self, aes, data=None, **kwargs):
    """Generate the aesthetic mappings.

    Build and return the ``DataTree`` used to populate the ``.aes`` attribute
    of the ``PlotCollection``.

    Parameters
    ----------
    aes : mapping of {str : list of hashable or False}
        Dictionary with :term:`aesthetics` as keys and as values a list
        of the dimensions it should be mapped to. The pseudo-dimension
        ``__variable__`` is also valid to indicate the variable should be
        part of the aesthetic mapping.

        It can also take ``False`` as value to indicate that no mapping
        should be considered for that aesthetic key.
    data : Dataset, optional
        Data for which to generate the aesthetic mappings.
        Defaults to the data stored in the ``PlotCollection``.
    **kwargs : mapping, optional
        Dictionary with :term:`aesthetics` as keys and as values a list
        of the values that should be taken by that aesthetic.

    Returns
    -------
    DataTree
        One group per active aesthetic.

    Raises
    ------
    ValueError
        If a key in `**kwargs` has no active aesthetic mapping in `aes`.

    Notes
    -----
    Mappings are applied only when all variables defined in the mapping are found.
    Thus, a mapping for ``["chain", "hierarchy"]`` would be applied if both
    dimensions are present in the variable, otherwise it is completely ignored.

    It can be the case that a mapping is ignored for a specific variable
    because it has none of the dimensions that define the mapping or because
    it doesn't have all of them. In such cases, out of the values in the property
    cycle, the first one is taken out and reserved as *neutral_element*.
    Then, the cycle excluding the first element is used when applying the mapping,
    and the neutral element is used when the mapping can't be applied.

    It is possible to force the inclusion of the neutral element from the
    property value cycle by providing the same value in both the first and second
    positions in the cycle, but this is generally not recommended.

    Examples
    --------
    Initialize a `PlotCollection` with the rugby dataset as data.
    Faceting and aesthetics mapping are independent. Thus, as
    we are limiting ourselves to the use of this method, we can
    provide an empty DataTree as ``viz_dt``.

    .. jupyter-execute::

        from arviz_base import load_arviz_data
        from arviz_plots import PlotCollection
        import xarray as xr
        idata = load_arviz_data("rugby_field")
        pc = PlotCollection(idata.posterior, xr.DataTree(), backend="matplotlib")
        aes_dt = pc.generate_aes_dt(
            aes={
                "color": ["__variable__", "team"],
                "y": ["field", "team"],
                "marker": ["field"],
                "linestyle": ["chain"],
            },
            color=[f"C{i}" for i in range(6)],
            y=list(range(13)),
            linestyle=["-", ":", "--", "-."],
        )

    The generated `aes_dt` has one group per aesthetic. Within each group
    there can be the variables from the Dataset used to initialize the
    PlotCollection or the variables "mapping" and "neutral_element".
    Let's inspect its contents for each aesthetic.

    We'll start with the color which had ``__variable__, team`` as dimensions
    to encode.

    .. jupyter-execute::

        aes_dt["color"]

    In this case, each unique combination of variable and coordinate value of the
    team dimension gets a different color. They only end up being repeated once
    the provided cycler runs out of elements. In the cases where ``__variable__``
    is used, the data subset ``ds[var_name].sel(coords)`` gets the aesthetic
    values in `aes_dt[aes_key][var_name].sel(coords)`, however, this isn't
    always as straightforward; thus, the recommended way to get the corresponding
    aes for a specific subset is using :meth:`~arviz_plots.PlotCollection.get_aes_kwargs`

    Next let's look at the marker. We didn't provide any defaults for the marker,
    but as we specified the backend, some default values were generated for us.
    Here, we asked to encode the "field" dimension information only:

    .. jupyter-execute::

        aes_dt["marker"]

    We have a "neutral_element" variable which will be used for variables
    where the field dimension is not present and a "mapping" variable
    with a different marker value per coordinate in the field dimension,
    with all these values being different to the "neutral_element" one.
    The "y" aesthetic is very similar.

    Lastly, the "linestyle" aesthetic, which we asked to use to encode the
    chain information.

    .. jupyter-execute::

        aes_dt["linestyle"]

    As all variables have the "chain" dimension, there is no "neutral_element"
    here, and the first element in the property cycle (here the solid line "-")
    is used as part of the "chain" mapping instead of being reserved for
    variables without a "chain" dimension. Note that in such cases,
    trying to use a data variable without "chain" as dimension would
    result in an error, the mapping is not defined.

    See Also
    --------
    .PlotCollection.get_aes_kwargs
    """
    if data is None:
        data = self.data
    # ``False`` disables the mapping for that aesthetic
    aes = {key: value for key, value in aes.items() if value is not False}
    extra_keys = [key for key in kwargs if key not in aes]
    if extra_keys:
        raise ValueError(
            f"Keyword arguments {extra_keys} have been passed as **kwargs but "
            "have no active aesthetic mapped to them. Keyword arguments must define "
            "values to use in their respective aesthetic mapping."
        )
    # `backend` might not have been set by `__init__`, use the "none" backend then
    backend_name = self.backend if hasattr(self, "backend") else "none"
    plot_bknd = import_module(f".backend.{backend_name}", package="arviz_plots")
    get_default_aes = plot_bknd.get_default_aes
    ds_dict = {aes_key: xr.Dataset() for aes_key in aes}
    all_dims = {dim for dims in aes.values() for dim in dims}
    clean_sizes = {}
    coords = {}
    for dim in all_dims:
        if dim == "__variable__":
            continue
        unique_values = np.unique(data[dim])
        clean_sizes[dim] = len(unique_values)
        if len(data[dim]) == len(unique_values):
            # preserve original order if there are no repeated values
            coords[dim] = data[dim]
        else:
            coords[dim] = unique_values
    for aes_key, dims in aes.items():
        if "__variable__" in dims:
            # Count the values needed over the same `data` that is iterated
            # below, so the count also holds when `data` is not `self.data`.
            total_aes_vals = int(
                np.sum(
                    [
                        np.prod(
                            [
                                clean_sizes[dim]
                                for dim in dims
                                if dim in set(da.dims).union(da.coords)
                            ]
                        )
                        for da in data.values()
                    ]
                )
            )
            aes_vals = get_default_aes(aes_key, total_aes_vals, kwargs)
            aes_cumulative = 0
            for var_name, da in data.items():
                aes_dims = [dim for dim in dims if dim in set(da.dims).union(da.coords)]
                aes_raw_shape = [clean_sizes[dim] for dim in aes_dims]
                if not aes_raw_shape:
                    # variable with none of the mapped dims: it takes a single value
                    ds_dict[aes_key][var_name] = np.asarray(aes_vals)[
                        aes_cumulative : aes_cumulative + 1
                    ].squeeze()
                    aes_cumulative += 1
                    continue
                n_aes = np.prod(aes_raw_shape)
                ds_dict[aes_key][var_name] = xr.DataArray(
                    np.array(aes_vals[aes_cumulative : aes_cumulative + n_aes]).reshape(
                        aes_raw_shape
                    ),
                    dims=aes_dims,
                    coords={dim: coords[dim] for dim in aes_dims},
                )
                aes_cumulative += n_aes
        else:
            aes_dims_in_var = {
                var_name: set(dims) <= set(da.dims).union(da.coords)
                for var_name, da in data.items()
            }
            if not any(aes_dims_in_var.values()):
                warnings.warn(f"Provided mapping for {aes_key} will only use the neutral element")
            aes_shape = [clean_sizes[dim] for dim in dims]
            total_aes_vals = int(np.prod(aes_shape))
            neutral_element_needed = not all(aes_dims_in_var.values())
            aes_vals = get_default_aes(aes_key, total_aes_vals + neutral_element_needed, kwargs)
            if neutral_element_needed:
                neutral_element = aes_vals[0]
                ds_dict[aes_key]["neutral_element"] = neutral_element
                aes_vals_no_neutral = [val for val in aes_vals if val != neutral_element]
                # Trim to a single period of the cycle. The emptiness check covers
                # a cycle made only of the neutral element (avoids an IndexError).
                if aes_vals_no_neutral and aes_vals_no_neutral[0] in aes_vals_no_neutral[1:]:
                    cycle_repeat_index = aes_vals_no_neutral[1:].index(aes_vals_no_neutral[0])
                    aes_vals_no_neutral = aes_vals_no_neutral[: cycle_repeat_index + 1]
                if aes_vals[1] == neutral_element:
                    # first value repeated on purpose: keep it in the mapping cycle
                    aes_vals = [neutral_element] + aes_vals_no_neutral
                else:
                    aes_vals = aes_vals_no_neutral
                aes_vals = get_default_aes(
                    aes_key,
                    total_aes_vals,
                    {aes_key: aes_vals},
                )
            ds_dict[aes_key]["mapping"] = xr.DataArray(
                np.array(aes_vals).reshape(aes_shape),
                dims=dims,
                coords={dim: coords[dim] for dim in dims},
            )
    return xr.DataTree.from_dict(ds_dict)
[docs]
def get_aes_as_dataset(self, aes_key):
    """Get the contents of the `aes_key` group of ``.aes`` as a Dataset.

    Parameters
    ----------
    aes_key : str
        Aesthetic mapping whose values should be returned as a Dataset.
        It is used to index :attr:`~.PlotCollection.aes`.

    Returns
    -------
    Dataset

    See Also
    --------
    arviz_plots.PlotCollection.update_aes_from_dataset
    """
    aes_group = self.aes[aes_key]
    return aes_group.to_dataset()
[docs]
def update_aes_from_dataset(self, aes_key, dataset):
    """Store `dataset` as the `aes_key` entry of the aesthetics DataTree.

    The assignment is done on the underlying, unsliced DataTree, replacing
    the entry if it already exists and adding it otherwise.

    Parameters
    ----------
    aes_key : str
        Aesthetic mapping whose values should be updated or added.
    dataset : Dataset
        Dataset containing the `aes_key` values.

    See Also
    --------
    arviz_plots.PlotCollection.get_aes_as_dataset
    """
    aes_tree = self._aes_dt
    aes_tree[aes_key] = dataset
@property
def facet_dims(self):
    """Dimensions to loop over for faceting when using this PlotCollection.

    They are the dimensions of the "plot" entry in :attr:`~.PlotCollection.viz`.
    Adding specific visuals might require looping over more dimensions than
    these, depending on the defined aesthetic mappings.
    """
    plot_entry = self.viz["plot"]
    return set(plot_entry.dims)
[docs]
def get_viz(self, artist_name, var_name=None, sel=None, **sel_kwargs):
    """Get element from ``.viz`` that corresponds to the provided subset.

    Parameters
    ----------
    artist_name : str
        Name of the entry in :attr:`~.PlotCollection.viz` to retrieve.
    var_name : str, optional
        If given, the entry is further indexed with this variable name.
    sel : mapping, optional
        Indexers to apply; only the ones applicable to the retrieved
        element (see :func:`sel_subset`) are used.
    **sel_kwargs : mapping, optional
        kwargs version of `sel`. They are merged with `sel`, taking
        preference in case of repeated keys.

    Returns
    -------
    object
        The stored object itself when the result is a single element DataArray,
        otherwise the subset xarray object.
    """
    indexers = ({} if sel is None else sel) | sel_kwargs
    element = self.viz[artist_name]
    if isinstance(element, xr.DataTree):
        element = element.dataset
    if var_name is not None:
        element = element[var_name]
    applicable = sel_subset(indexers, element)
    if applicable:
        element = element.sel(applicable)
    is_single_item = isinstance(element, xr.DataArray) and element.size == 1
    return element.item() if is_single_item else element
def rename_visuals(self, name_dict=None, **names):
    """Rename visual data variables in the :attr:`~.PlotCollection.viz` DataTree.

    Each visual is assigned under its desired name and the nodes with the
    current names are then dropped.

    Parameters
    ----------
    name_dict, **names : mapping
        Keys are current visual names and values are desired names.
        At least one of these must be provided. If a key is present in
        both, the value in `name_dict` is used.
    """
    renames = names if name_dict is None else names | name_dict
    renamed_entries = {}
    for current_name, desired_name in renames.items():
        renamed_entries[desired_name] = self.viz[current_name]
    old_names = list(renames)
    self.viz = self.viz.assign(renamed_entries).drop_nodes(old_names)
[docs]
@classmethod
def wrap(
    cls,
    data,
    cols=None,
    col_wrap=4,
    backend=None,
    figure_kwargs=None,
    **kwargs,
):
    """Instantiate a PlotCollection and generate a plot grid iterating over subsets and wrapping.

    Parameters
    ----------
    data : Dataset or dict of {str: Dataset}
        If `data` is a dictionary, the Datasets stored as its values will be concatenated,
        creating a new dimension called ``model``.
    cols : iterable of hashable, optional
        Dimensions of the dataset for which different coordinate values
        should have different :term:`plots <plot>`. A special dimension
        called ``__variable__`` is also available, to indicate that
        each variable of the input Dataset should have their own plot;
        it can also be combined with other dimensions.
    col_wrap : int, default 4
        Number of columns in the generated grid. If more than `col_wrap`
        plots are needed from :term:`faceting` according to `cols`,
        new rows are created.
    backend : str, optional
        Plotting backend. Defaults to ``rcParams["plot.backend"]``.
    figure_kwargs : mapping, optional
        Passed to :func:`~.backend.create_plotting_grid` of the chosen plotting backend.
    **kwargs : mapping, optional
        Passed as is to the initializer of ``PlotCollection``. That is,
        used for ``aes`` and ``**kwargs`` arguments.
        See :meth:`~arviz_plots.PlotCollection.generate_aes_dt` for more
        details about these arguments.

    See Also
    --------
    arviz_plots.PlotCollection.grid
    """
    cols = [] if cols is None else cols
    figure_kwargs = {} if figure_kwargs is None else figure_kwargs
    if backend is None:
        backend = rcParams["plot.backend"]
    data = concat_model_dict(data)

    n_plots, plots_per_var = process_facet_dims(data, cols)
    if n_plots > col_wrap:
        # ceil division: an incomplete last row still counts as a row
        full_rows, remainder = divmod(n_plots, col_wrap)
        n_rows, n_cols = full_rows + (remainder != 0), col_wrap
    else:
        n_rows, n_cols = 1, n_plots

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    fig, ax_ary = plot_bknd.create_plotting_grid(
        n_plots, n_rows, n_cols, squeeze=False, **figure_kwargs
    )
    col_id, row_id = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    # flattened grid elements, trimmed to the number of plots actually needed
    flat_arrays = {
        "plot": ax_ary.flatten()[:n_plots],
        "row_index": row_id.flatten()[:n_plots],
        "col_index": col_id.flatten()[:n_plots],
    }

    if "__variable__" not in cols:
        # all variables share the same plots, everything is stored in the root group
        plots_raw_shape = []
        coords = {}
        for dim in cols:  # use provided dim orders, not existing ones
            unique_values = np.unique(data[dim])
            plots_raw_shape.append(len(unique_values))
            # preserve original order if there are no repeated values
            no_repeats = len(unique_values) == len(data[dim])
            coords[dim] = data[dim] if no_repeats else unique_values
        root_vars = {"figure": np.array(fig, dtype=object)}
        for name, flat in flat_arrays.items():
            root_vars[name] = (cols, flat.reshape(plots_raw_shape))
        root_ds = xr.Dataset(root_vars, coords=coords)
        groups = {}
    else:
        # each variable gets its own slice of plots, stored as one group per element
        root_ds = xr.Dataset({"figure": np.array(fig, dtype=object)})
        groups = {name: {} for name in flat_arrays}
        facet_cumulative = 0
        for var_name, da in data.items():
            coords = {}
            plots_raw_shape = []
            for dim in cols:
                if dim not in set(da.dims).union(da.coords):
                    continue
                unique_values = np.unique(da[dim])
                plots_raw_shape.append(len(unique_values))
                no_repeats = len(unique_values) == len(data[dim])
                coords[dim] = data[dim] if no_repeats else unique_values
            dims = list(coords.keys())
            if var_name in plots_per_var:
                col_slice = slice(facet_cumulative, facet_cumulative + plots_per_var[var_name])
            else:
                col_slice = slice(None, None)
            facet_cumulative += plots_per_var[var_name]
            aux_ds = xr.Dataset(
                {
                    name: (dims, flat[col_slice].reshape(plots_raw_shape))
                    for name, flat in flat_arrays.items()
                },
                coords=coords,
            )
            for name in flat_arrays:
                groups[name][var_name] = aux_ds[name]

    viz_dt = xr.DataTree(
        root_ds,
        children={
            name: xr.DataTree(xr.Dataset(group_vars)) for name, group_vars in groups.items()
        },
    )
    return cls(data, viz_dt, backend=backend, **kwargs)
[docs]
@classmethod
def grid(
cls,
data,
cols=None,
rows=None,
backend=None,
figure_kwargs=None,
**kwargs,
):
"""Instatiate a PlotCollection and generate a plot grid iterating over rows and columns.
Parameters
----------
data : Dataset or dict of {str: Dataset}
If `data` is a dictionary, the Datasets stored as its values will be concatenated,
creating a new dimension called ``model``.
cols, rows : iterable of hashable, optional
Dimensions of the dataset for which different coordinate values
should have different :term:`plots <plot>`. A special dimension
called ``__variable__`` is also available, to indicate that
each variable of the input Dataset should have their own plot;
it can also be combined with other dimensions.
The generated grid will have as many plots as unique combinations
of values within `cols` and `rows`.
backend : str, optional
Plotting backend.
figure_kwargs : mapping, optional
Passed to :func:`~.backend.create_plotting_grid` of the chosen plotting backend.
**kwargs : mapping, optional
Passed as is to the initializer of ``PlotCollection``. That is,
used for ``aes`` and ``**kwargs`` arguments.
See :meth:`~arviz_plots.PlotCollection.generate_aes_dt` for more
details about these arguments.
See Also
--------
arviz_plots.PlotCollection.wrap
"""
if cols is None:
cols = []
if rows is None:
rows = []
if figure_kwargs is None:
figure_kwargs = {}
if backend is None:
backend = rcParams["plot.backend"]
repeated_dims = [col for col in cols if col in rows]
if repeated_dims:
raise ValueError("The same dimension can't be used for both cols and rows.")
data = concat_model_dict(data)
n_cols, cols_per_var = process_facet_dims(data, cols)
n_rows, rows_per_var = process_facet_dims(data, rows)
n_plots = n_cols * n_rows
plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
fig, ax_ary = plot_bknd.create_plotting_grid(
n_plots, n_rows, n_cols, squeeze=False, **figure_kwargs
)
col_id, row_id = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
viz_dict = {}
if "__variable__" not in cols and "__variable__" not in rows:
dims = tuple((*rows, *cols)) # use provided dim orders, not existing ones
plots_raw_shape = []
coords = {}
for dim in dims:
unique_values = np.unique(data[dim])
plots_raw_shape.append(len(unique_values))
if len(unique_values) == len(data[dim]):
coords[dim] = data[dim]
else:
coords[dim] = unique_values
viz_dict["/"] = xr.Dataset(
{
"figure": np.array(fig, dtype=object),
"plot": (dims, ax_ary.flatten().reshape(plots_raw_shape)),
"row_index": (dims, row_id.flatten().reshape(plots_raw_shape)),
"col_index": (dims, col_id.flatten().reshape(plots_raw_shape)),
},
coords=coords,
)
else:
viz_dict["/"] = xr.Dataset({"figure": np.array(fig, dtype=object)})
viz_dict["plot"] = {}
viz_dict["row_index"] = {}
viz_dict["col_index"] = {}
all_dims = tuple((*rows, *cols)) # use provided dim orders, not existing ones
facet_cumulative = 0
coords = {}
for var_name, da in data.items():
dims = [dim for dim in all_dims if dim in da.dims]
plots_raw_shape = []
dims = []
for dim in all_dims:
if dim not in set(da.dims).union(da.coords):
continue
if dim in coords:
unique_values = coords[dim]
else:
unique_values = np.unique(da[dim])
if len(unique_values) == len(data[dim]):
coords[dim] = data[dim]
else:
coords[dim] = unique_values
plots_raw_shape.append(len(unique_values))
dims.append(dim)
row_slice = (
slice(None, None)
if var_name not in rows_per_var
else slice(facet_cumulative, facet_cumulative + rows_per_var[var_name])
)
col_slice = (
slice(None, None)
if var_name not in cols_per_var
else slice(facet_cumulative, facet_cumulative + cols_per_var[var_name])
)
if rows_per_var:
facet_cumulative += rows_per_var[var_name]
else:
facet_cumulative += cols_per_var[var_name]
viz_dict["plot"][var_name] = (
dims,
ax_ary[row_slice, col_slice].flatten().reshape(plots_raw_shape),
)
viz_dict["row_index"][var_name] = (
dims,
row_id[row_slice, col_slice].flatten().reshape(plots_raw_shape),
)
viz_dict["col_index"][var_name] = (
dims,
col_id[row_slice, col_slice].flatten().reshape(plots_raw_shape),
)
viz_dict = {key: xr.Dataset(value, coords=coords) for key, value in viz_dict.items()}
viz_dt = xr.DataTree.from_dict(viz_dict)
return cls(data, viz_dt, backend=backend, **kwargs)
[docs]
def update_aes(self, ignore_aes=frozenset(), coords=None):
    """Update list of aesthetics after indicating ignores and extra subsets.

    Parameters
    ----------
    ignore_aes : set, optional
        Aesthetics to leave out.
    coords : mapping, optional
        Its keys are removed from the dimensions to loop over.

    Returns
    -------
    aes : list
        Aesthetics in :attr:`~.PlotCollection.aes_set` not present in `ignore_aes`.
    all_loop_dims : set
        Union of :attr:`~.PlotCollection.facet_dims` and the dimensions of the
        returned aesthetics, minus the keys of `coords`.
    """
    excluded = {} if coords is None else coords
    active_aes = []
    for aes_key in self.aes_set:
        if aes_key in ignore_aes:
            continue
        active_aes.append(aes_key)
    mapped_dims = []
    for aes_key in active_aes:
        mapped_dims.extend(self.aes[aes_key].dims)
    loop_dims = self.facet_dims.union(mapped_dims).difference(excluded.keys())
    return active_aes, loop_dims
[docs]
def allocate_artist(
    self,
    fun_label,
    data,
    all_loop_dims,
    dim_to_idx=None,
    artist_dims=None,
    ignore_aes=frozenset(),
):
    """Allocate a visual in the ``viz`` DataTree.

    A group named `fun_label` is stored in the underlying ``viz`` DataTree, with
    one object array filled with ``None`` per variable in `data`. Each array has
    the dimensions of the variable that are in `all_loop_dims` or are keys of
    `dim_to_idx` (replaced by their `dim_to_idx` value in that case), followed
    by the dimensions in `artist_dims`.

    Parameters
    ----------
    fun_label : str
        Name of the group to create.
    data : Dataset
    all_loop_dims : iterable of hashable
    dim_to_idx : mapping, optional
        Mapping from dimension name to the coordinate that replaces it.
    artist_dims : mapping of {hashable : int}, optional
        Extra dimensions (name to length) appended to each array.
    ignore_aes : set, optional
        If not empty, stored in the attributes of the created group.
    """
    extra_dims = {} if artist_dims is None else artist_dims
    replacements = {} if dim_to_idx is None else dim_to_idx
    artist_dt = xr.DataTree()
    if ignore_aes:
        artist_dt.attrs = {"ignore_aes": ignore_aes}
    for var_name, da in data.items():
        inherited_dims = []
        inherited_sizes = []
        for dim in da.dims:
            if dim not in all_loop_dims and dim not in replacements:
                continue
            dim_or_idx = replacements.get(dim, dim)
            inherited_dims.append(dim_or_idx)
            if dim_or_idx in da.sizes:
                inherited_sizes.append(da.sizes[dim_or_idx])
            else:
                # replacement coordinate: one element per unique value
                inherited_sizes.append(len(np.unique(da[dim_or_idx])))
        artist_coords = {}
        for dim in inherited_dims:
            if dim in replacements.values():
                artist_coords[dim] = np.unique(da[dim])
            else:
                artist_coords[dim] = da[dim]
        # TODO: once DataTree has a .loc attribute, this should work on .viz instead
        artist_dt[var_name] = xr.DataArray(
            np.full(inherited_sizes + list(extra_dims.values()), None, dtype=object),
            dims=inherited_dims + list(extra_dims.keys()),
            coords=artist_coords,
        )
    self._viz_dt[fun_label] = artist_dt
[docs]
def get_target(self, var_name, selection):
    """Get the "plot" element of ``.viz`` for the given variable and selection.

    Thin wrapper around :meth:`~arviz_plots.PlotCollection.get_viz`.
    """
    target = self.get_viz("plot", var_name, selection)
    return target
[docs]
def get_aes_kwargs(self, aes, var_name, selection):
    """Get the aesthetic mappings for the given variable and selection as a dictionary.

    Parameters
    ----------
    aes : list
        List of aesthetic keywords whose values should be retrieved. For each
        of them, the group of that name in the ``aes`` attribute is used and
        the value is taken from:

        * the variable `var_name` if the group has it, subset with `selection`
        * otherwise, the "mapping" variable subset with `selection` if all the
          dimensions of "mapping" are in `selection`
        * otherwise, the "neutral_element" variable

        :class:`.PlotCollection` considers aesthetics starting with "overlay"
        a special aesthetic keyword to indicate visual elements with potentially
        identical properties should be overlaid.
        Thus, if "overlay" or "overlay_xyz" are an element of the `aes` argument,
        it is skipped, no value is attempted to be retrieved and it isn't present
        as key in the returned output either.
    var_name : str
    selection : dict

    Returns
    -------
    dict
        Mapping of aesthetic keywords to the values corresponding to the provided
        `var_name` and `selection`.

    Raises
    ------
    ValueError
        If the neutral element is required but the group doesn't have one.

    See Also
    --------
    .PlotCollection.generate_aes_dt
    """
    resolved = {}
    for aes_key in aes:
        if aes_key.startswith("overlay"):
            continue
        aes_ds = self.aes[aes_key]
        if var_name in aes_ds.data_vars:
            source, indexers = var_name, selection
        elif all(dim in selection for dim in aes_ds["mapping"].dims):
            source, indexers = "mapping", selection
        elif "neutral_element" in aes_ds.data_vars:
            source, indexers = "neutral_element", {}
        else:
            raise ValueError(
                f"{aes_key} has no neutral element initialized but "
                f"{var_name} needs a neutral element."
            )
        resolved[aes_key] = subset_ds(aes_ds, source, indexers)
    return resolved
[docs]
def map(
    self,
    fun,
    fun_label=None,
    *,
    data=None,
    coords=None,
    ignore_aes=frozenset(),
    subset_info=False,
    store_artist=True,
    artist_dims=None,
    **kwargs,
):
    """Call a plotting function once per plot/aesthetic combination.

    Parameters
    ----------
    fun : callable
        Function with signature ``fun(da, target, **fun_kwargs)``, called for
        every combination of :term:`plot` and :term:`aesthetic`.
        Its return value is assumed to be a scalar unless `artist_dims`
        are provided. Extra required keyword arguments can be requested
        through the `subset_info` flag.
    fun_label : str, optional
        Variable name with which to store the object returned by `fun`.
        Defaults to ``fun.__name__``.
    data : Dataset or DataArray, optional
        Data to be subsetted at each iteration and passed to `fun` as first
        positional argument. A DataArray must be named.
        Defaults to the data used to initialize the ``PlotCollection``.
    coords : mapping, optional
        Dictionary of {coordinate names : coordinate values} used to subset
        the aes, data and viz objects before any faceting or aesthetics
        mapping is applied.
    ignore_aes : set or "all", optional
        Aesthetics present in ``aes`` that should be ignored for this ``map``
        call. The string "all" ignores every aesthetic so only faceting
        is taken into account.
    subset_info : boolean, default False
        Add the subset info from :func:`arviz_base.xarray_sel_iter` to the
        keyword arguments passed to `fun`, which must then accept ``var_name``,
        ``sel`` and ``isel``. Keys with those names in `**kwargs` are overridden.
    store_artist : boolean, default True
        Whether to allocate storage under `fun_label` and store in it the
        objects returned by `fun`.
    artist_dims : mapping of {hashable : int}, optional
        Dictionary of sizes for proper allocation and storage when using
        ``map`` with functions that return an array of :term:`visual`.
    **kwargs : mapping, optional
        Keyword arguments passed to `fun`. Values of :class:`~xarray.DataArray`
        or :class:`~xarray.Dataset` type are subsetted on the current selection
        (if possible) before calling `fun`. Slicing with dims and coords is
        applied to the relevant subset present in the xarray object, so
        dimensions with mapped aesthetics not being present is not an issue.
        However, Datasets that don't contain all the variable names in `data`
        will raise an error.

    Raises
    ------
    TypeError
        If `data` is neither a Dataset nor a DataArray.
    """
    coords = {} if coords is None else coords
    fun_label = fun.__name__ if fun_label is None else fun_label
    if isinstance(ignore_aes, str) and ignore_aes == "all":
        ignore_aes = self.aes_set

    if data is None:
        data = self.data
    if isinstance(data, xr.DataArray):
        data = data.to_dataset()
    if not isinstance(data, xr.Dataset):
        raise TypeError("data argument must be an xarray.Dataset")

    aes, all_loop_dims = self.update_aes(ignore_aes, coords)

    # non-dimension coordinates we loop over, keyed by the dimension they live on
    dim_to_idx = {}
    for coord_name in data.coords:
        if coord_name in all_loop_dims and coord_name not in data.dims:
            dim_to_idx[data[coord_name].dims[0]] = coord_name
    # such coordinates need an index so that ``.sel`` can be used on them
    for coord_name in dim_to_idx.values():
        if coord_name not in data.xindexes:
            data = data.set_xindex(coord_name)

    skip_dims = set(data.dims).difference(all_loop_dims, dim_to_idx)
    plotters = xarray_sel_iter(data, skip_dims=skip_dims, dim_to_idx=dim_to_idx)

    if store_artist:
        self.allocate_artist(
            fun_label=fun_label,
            data=data,
            all_loop_dims=all_loop_dims,
            dim_to_idx=dim_to_idx,
            artist_dims=artist_dims,
            ignore_aes=ignore_aes,
        )

    for var_name, sel, isel in plotters:
        da = data[var_name].sel(sel)

        # subsets that are entirely NaN are skipped; non numeric data is always plotted
        try:
            all_missing = bool(np.all(np.isnan(da)))
        except TypeError:
            all_missing = False
        if all_missing:
            continue

        sel_plus = dict(sel)
        sel_plus.update(coords)
        target = self.get_target(var_name, sel_plus)

        fun_kwargs = dict(self.get_aes_kwargs(aes, var_name, sel_plus))
        for key, values in kwargs.items():
            fun_kwargs[key] = process_kwargs_subset(values, var_name, sel)
        if subset_info:
            fun_kwargs.update(var_name=var_name, sel=sel, isel=isel)

        aux_artist = fun(da, target=target, **fun_kwargs)
        if not store_artist:
            continue
        if np.size(aux_artist) == 1:
            aux_artist = np.squeeze(aux_artist)
        self.store_in_artist_da(aux_artist, fun_label, var_name, sel)
def store_in_artist_da(self, aux_artist, fun_label, var_name, sel):
    """Write `aux_artist` into the `fun_label` entry of ``viz`` at `var_name`+`sel`."""
    artist_da = self.viz[fun_label][var_name]
    artist_da.loc[sel] = aux_artist
[docs]
def add_legend(
    self,
    dim,
    aes=None,
    artist_kwargs=None,
    title=None,
    text_only=False,
    # position=(0, -1),  # TODO: add argument
    labeller=None,
    **kwargs,
):
    """Add a legend for the given visual/aesthetic to the plot.

    Warnings
    --------
    This method is still in early stages of experimentation and anything beyond
    the basic usage ``add_legend("dim_name")`` will probably change in breaking ways.

    Parameters
    ----------
    dim : hashable or iterable of hashable
        Dimension or dimensions for which to generate the legend.
        The pseudo-dimension ``__variable__`` is allowed too.
        It should have at least one :term:`aesthetic mapped <aesthetic mapping>` to it.
        Only the mappings that match will be taken into account; if a legend is requested
        for the "chain" dimension but there is only one aesthetic mapping
        for ("chain", "group") no legend can be generated.
    aes : str or iterable of str, optional
        Specific aesthetics to take into account when generating the legend.
        They should all be mapped to `dim`. Defaults to all aesthetics matching
        that mapping with the exception of "x" and "y" which are never included
        by default.
    artist_kwargs : mapping, optional
        Keyword arguments passed to the backend visual function used to
        generate the miniatures in the legend. Replaced by fixed "invisible"
        settings when `text_only` is True.
    title : str, optional
        Legend title. Defaults to `dim`. No title is passed to the backend
        when `text_only` is True.
    text_only : bool, optional
        If True, creates a text-only legend without graphical markers.
    labeller : labeller instance, optional
        Labeller to generate the legend entries. Defaults to ``BaseLabeller()``.
    **kwargs : mapping, optional
        Keyword arguments passed to the backend function that generates the legend.

    Returns
    -------
    legend : object
        The corresponding legend object for the backend of the ``PlotCollection``.

    Raises
    ------
    ValueError
        If no aesthetic is mapped to `dim`, or if no aesthetic is left to build
        the legend from (for example when only "x" and/or "y" are mapped to `dim`
        and `aes` is not given, or when `aes` is empty).

    Notes
    -----
    A ``position`` argument is planned but not accepted yet.
    """
    if isinstance(dim, str):
        dim = (dim,)
    else:
        dim = tuple(dim)
    dim_str = ", ".join(("variable" if d == "__variable__" else d for d in dim))
    if title is None:
        title = dim_str

    # aesthetics without a "mapping" variable are the ones mapped to the variable itself
    aes_mappings = {
        aes_key: list(ds.dims) + ([] if "mapping" in ds.data_vars else ["__variable__"])
        for aes_key, ds in self.aes.children.items()
    }
    valid_aes = [
        aes_key for aes_key, aes_dims in aes_mappings.items() if set(dim) == set(aes_dims)
    ]
    if not valid_aes:
        raise ValueError(
            f"Legend can't be generated. Found no aesthetics mapped to dimension {dim}. "
            f"Existing mappings are {aes_mappings}."
        )

    if aes is None:
        aes = [aes_key for aes_key in valid_aes if aes_key not in ("x", "y")]
    elif isinstance(aes, str):
        aes = [aes]
    else:
        # materialize: `aes` is indexed below and iterated once per legend entry,
        # which sets or one-shot iterators would not support
        aes = list(aes)
    if not aes:
        raise ValueError(
            f"Legend can't be generated. No aesthetics available to build the legend "
            f"for dimension {dim}. Aesthetics mapped to it are {valid_aes}; "
            "'x' and 'y' are only used if requested explicitly through `aes`."
        )

    if labeller is None:
        labeller = BaseLabeller()

    # all requested aesthetics share the same mapping, so any of them defines the entries
    sample_aes_ds = self.aes[aes[0]].dataset
    subset_iterator = list(xarray_sel_iter(sample_aes_ds, skip_dims=set()))
    if "__variable__" in dim:
        label_list = [labeller.make_label_flat(*subset) for subset in subset_iterator]
    else:
        label_list = [
            labeller.sel_to_str(sel, isel) if var_name == "mapping" else "∅"
            for var_name, sel, isel in subset_iterator
        ]

    if text_only:
        kwarg_list = [{} for _ in subset_iterator]
        # make the legend miniatures invisible so only the labels are shown
        artist_kwargs = {"linestyle": "none", "linewidth": 0, "color": "none"}
    else:
        kwarg_list = [
            self.get_aes_kwargs(aes, var_name, sel) for var_name, sel, _ in subset_iterator
        ]

    plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots")
    legend_title = None if text_only else title
    # TODO: store, maybe have a group in viz called legend as if it were a visual more
    # but then it has only scalar variables with name `dim_str`
    return plot_bknd.legend(
        self.viz["figure"].item(),
        kwarg_list,
        label_list,
        title=legend_title,
        artist_kwargs=artist_kwargs,
        **kwargs,
    )