Using PlotCollection objects#

This tutorial covers how to handle a PlotCollection: its main attributes and methods, and how to use it to modify the figure it contains. It does not cover how to create a PlotCollection. Consequently, this should not be the first time you are hearing about PlotCollection. If it is, we recommend first going over one of the following two pages:

PlotCollection attributes#

viz: organized storage of plotting backend objects#

The .viz attribute contains most of the elements that comprise the visualization itself: the figure, plots and visuals.

We say “most of” because, while the figure and plot elements are created directly by methods of PlotCollection like grid or wrap, visuals are created by external functions executed through PlotCollection as many times as needed on the indicated plots, and some of these functions might not return an object from the plotting backend library to store.

from arviz_base import load_arviz_data
idata = load_arviz_data("rugby")
from arviz_plots import plot_dist, plot_forest, plot_trace_dist, style
style.use("arviz-variat")

ArviZ plotting functions aim to store as many visuals as possible. This makes all visual elements available for further customization after the function has been called. Let’s see the contents of the PlotCollection returned by plot_dist:

pc = plot_dist(idata, var_names=["home", "intercept", "atts", "defs"])
pc.viz
<xarray.DatasetView> Size: 8B
Dimensions:  ()
Data variables:
    figure   object 8B Figure(2880x1468.27)

As you can see by inspecting the HTML interactive view right above, the .viz attribute is a DataTree with 8 groups. There will always be one group per visual. In contrast, plots can be either groups or data variables, depending on the faceting strategy. In this case, we are faceting over the variables, so plot is a group. Thus, the eight groups are:

  • plot: the backend objects that correspond to the plot elements.

  • row_index and col_index: integer indicators of the row and column each plot occupies within the figure

  • kde, credible_interval, point_estimate, point_estimate_text and title: the visuals corresponding respectively to: the KDE line (blue line), credible interval line (gray horizontal line), the point estimate dot (gray circle), the point estimate annotation (gray text over the point estimate) and the title (in bolded black font over each plot).

Each group will have variables matching the variables in the input data (or a subset of them). The dimensions of each group are independent. So are the dimensions of each variable. These may be different for each visual group, and may even be different among the variables within each group.

Moreover, there is a global figure variable which is always a scalar.

Important

The structure of the .viz attribute is backend agnostic, but its contents are backend dependent.

pc = plot_dist(idata, backend="bokeh")
pc.viz
<xarray.DatasetView> Size: 8B
Dimensions:  ()
Data variables:
    figure   object 8B GridPlot(id='p1960', ...)

In the first case, we generated the plot with matplotlib so the objects stored are matplotlib objects like Figure, Axes, Line2D or Text.

In the last cell, we have instead generated the plot with bokeh. Thus, the objects stored are bokeh objects like Column, Figure, GlyphRenderer or Title.

Given the .viz attribute stores the visual elements that go into making one plot, let’s try inspecting the result of a different function:

pc = plot_forest(idata, var_names=["home", "atts", "defs"])
pc.viz
<xarray.DatasetView> Size: 104B
Dimensions:    (column: 2)
Coordinates:
  * column     (column) <U6 48B 'labels' 'forest'
Data variables:
    figure     object 8B Figure(2400x1761.48)
    plot       (column) object 16B Axes(0.0371183,0.0376897;0.231949x0.957579...
    row_index  (column) int64 16B 0 0
    col_index  (column) int64 16B 0 1

If we instead inspect the PlotCollection returned by plot_forest, we’ll see that different visuals are stored. In this case, all variables are in the same plot because they are differentiated by their y coordinate. The plot, row_index and col_index variables are now global, as they are shared by all variables. Still, we continue to have different shapes for different visuals and for the variables within them.

aes: mapping of aesthetic keys to values and storage all at once#

The other main attribute of PlotCollection is .aes. It is also a DataTree and it has a similar structure. Now the aesthetics are the groups. The contents of these groups depend on the type of mapping defined:

  • If the aesthetic mapping includes the variables, the group will have variables matching the ones in the input data, just like what we saw with .viz.

  • If the aesthetic mapping encodes only dimension information the group will have only 1 or 2 variables. The variable mapping will always be present. It is the one that contains the mapping from dimension(s) to aesthetic values. The second variable is optional: neutral_element. It is only present when the mapping defined is not applicable to all the variables in the input data. When present it is always a scalar containing the neutral element.
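The second case can be pictured with plain Python containers. The sketch below is only a conceptual model, not how arviz-plots stores or computes aesthetics: the helper names `build_mapping` and `resolve`, the linestyle values and the choice of `"-"` as neutral element are all assumptions made for illustration.

```python
from itertools import cycle


def build_mapping(coord_values, aes_values):
    """Pair each coordinate value with an aesthetic value, recycling values if needed."""
    return dict(zip(coord_values, cycle(aes_values)))


def resolve(mapping, neutral_element, coord=None):
    """Return the mapped value, or the neutral element when the dimension is absent."""
    if coord is None:
        return neutral_element
    return mapping[coord]


# hypothetical "linestyle" aesthetic mapped to the chain dimension
linestyle_mapping = build_mapping([0, 1, 2, 3], ["--", ":", "-."])
neutral_element = "-"  # assumed value for variables without a chain dimension

with_chain = resolve(linestyle_mapping, neutral_element, coord=2)
without_chain = resolve(linestyle_mapping, neutral_element)
```

In this toy model, a variable that has the chain dimension looks up its linestyle by coordinate value, while a variable without that dimension falls back to the single scalar neutral element.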

Instead of storing plotted objects, however, it stores aesthetic mappings as key-value pairs. This allows us to check which properties are being used for each of the coordinate values they represent. For example, we can access them for further manual plotting using the same mappings. Similarly, it is also possible to modify them before calling (more) plotting functions, which would then use the updated mappings instead.

pc = plot_trace_dist(idata, var_names=["home", "intercept", "atts", "defs"])

Inspecting the aes attribute we can see that the linestyle depends on the coordinate value of the chain dimension, and the color depends on both the data variable and the team dimension. All of this information clearly matches what we can see in the plot.

pc.aes
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

As the color depends on both the variable and the team, its group within .aes has variables matching those of the input data. On the other hand, as the linestyle only depends on the chain, it gets the mapping variable.

There is also an extra aesthetic called overlay whose value is ignored, but whose presence ensures we’ll loop over the right dimensions and draw the expected lines. This is helpful to plot multiple subsets all with the same visual properties, which is the default behaviour in plot_ppc, or to ensure the plot behaves as expected even if we disable some of the default aesthetic mappings, like we do in this example.

Customizing your PlotCollection#

Modify specific visual elements#

If you pass keyword arguments to map, those arguments will be used in every call that .map makes to the plotting function. However, in some cases we might want more control. The next cell shows an example. We directly manipulate these properties to highlight only variables that correspond to the national team of Scotland.

Important

As we have already mentioned, the structure of the .viz attribute is backend agnostic, but its contents are backend dependent.

Consequently, the steps to select a specific visual given variable names and coordinates are always the same, but the result is an object from the chosen plotting backend. Thus, modifying the visual element is backend dependent, and we consider that adding helper functions for such tasks is out of the scope of the library.

You can interact with the .viz attribute as you’d interact with any xarray.DataTree. It is also possible to use the get_viz helper method to simplify these calls a bit. See the differences below:

pc = plot_dist(idata, var_names=["home", "intercept", "atts", "defs"])
atts_scotland_kde = pc.viz["kde"]["atts"].sel(team="Scotland").item()
# atts_scotland_kde is now the Line2D object that
# corresponds to the kde line of the coordinate Scotland of variable atts
atts_scotland_kde.set(linewidth=3, color="lime")
pc.get_viz("kde", "defs", team="Scotland").set(linewidth=3, color="lime");

You are not limited to only manipulating visual element properties. In the next cell, we show how to manipulate plot properties; in this case to add a grid to only the intercept plot.

pc = plot_dist(idata, var_names=["home", "intercept", "atts", "defs"])
pc.get_viz("plot", "intercept").grid(True)

Let’s also see an example of a similar task but using Bokeh as backend:

from bokeh.plotting import output_notebook
output_notebook()
pc = plot_dist(
    idata,
    var_names=["home", "atts", "defs"],
    backend="bokeh",
    # make plot smaller
    figure_kwargs={"figsize": (1300, 600), "figsize_units": "dots"},
)
pe_glyph = pc.get_viz("point_estimate", "atts", team="Italy").glyph
pe_glyph
Scatter(
id = 'p3833', …)

We can inspect and modify any of the stored elements by their labels. We have saved the Bokeh object that corresponds to the point estimate dot in the atts[team=Italy] plot. We can now change some of its properties before rendering the figure:

pe_glyph.fill_color = "red"
pe_glyph.size = 20
pc.show()

In some cases, it is more convenient to select elements based on their positions in the plot grid, rather than by variable names or coordinates. The row_index and col_index groups are provided for this purpose.

Note

Selection with row and column is a bit more convoluted than it might need to be, but this also serves to illustrate an important issue. Some operations on the DataTree/Dataset/DataArray objects will trigger copies, which don’t play well with the majority of plotting backend objects.

Here, for example, attempting to use .where(condition, drop=True), which would make things more direct, will trigger a copy, and because of that the plotting backend will raise an error. We are forced to convert the .where operation into an indexing one.

pc = plot_dist(
    idata,
    var_names=["home", "atts", "defs"],
    backend="bokeh",
    # make plot smaller
    figure_kwargs={"figsize": (1300, 600), "figsize_units": "dots"},
)

import numpy as np
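# boolean mask over the team dimension: True only for the plot at row 2, column 1
condition = (pc.get_viz("row_index", "defs") == 2) & (pc.get_viz("col_index", "defs") == 1)
# convert the mask into a label-based selection, avoiding the copy .where would trigger
cond_sel = {"team": condition.coords["team"][condition]}
kde_glyph = pc.get_viz("kde", "defs", cond_sel).glyph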
kde_glyph.line_color = "lime"
kde_glyph.line_width = 4
pc.show()
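The idea of turning a boolean condition into a label-based selection can be sketched without xarray. The snippet below is a plain-Python illustration only: the grid layout in `positions` is made up, and the helper `labels_at` is hypothetical rather than part of arviz-plots.

```python
# made-up grid layout: team -> (row, column) of its plot
positions = {
    "Wales": (1, 1),
    "France": (1, 2),
    "Ireland": (2, 0),
    "Scotland": (2, 1),
    "Italy": (2, 2),
    "England": (3, 0),
}


def labels_at(positions, row, col):
    """Return the labels whose plot sits at the requested grid position."""
    # step 1: boolean mask, one entry per label
    mask = {label: pos == (row, col) for label, pos in positions.items()}
    # step 2: keep only the labels where the mask is True
    return [label for label, hit in mask.items() if hit]


selected = labels_at(positions, row=2, col=1)
# a label-based selection that indexes existing objects instead of copying them
cond_sel = {"team": selected}
```

The resulting dictionary has the same role as `cond_sel` in the cell above: it names the coordinate values to select, so the stored backend objects are retrieved by indexing.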

Add new visual elements to a PlotCollection#

Instead of modifying existing visual elements, we might instead want to add more elements to the plots. If we want to add something to a specific plot, the procedure is basically the same as above with the only difference of calling a plotting function instead of modifying properties of the existing elements.

For example, let’s add a vertical reference line to the defs plot of the France national team:

pc = plot_dist(idata, var_names=["home", "atts", "defs"])
ax = pc.get_viz("plot", "defs", team="France")
ax.axvline(0, color="red");

If we instead want to apply it to all plots, we can use map:

# to be able to use map, callables must accept 2 positional arguments
# a DataArray and the plotting target
def axvline(da, target, **kwargs):
    return target.axvline(0, **kwargs)

pc = plot_dist(idata, var_names=["home", "atts", "defs"])
pc.map(axvline, color="red")

See also

The map method is one of the main building blocks provided by PlotCollection. The Create your own figure with PlotCollection page covers the use of map more extensively.
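The calling convention described above (a callable that takes the data first and the plotting target second, plus shared keyword arguments) can be tried out without any plotting backend. The sketch below is an assumption-laden toy: `RecordingTarget` is a hypothetical stand-in for a backend plot object, `toy_map` is not how PlotCollection.map is implemented, and the data values are made up.

```python
class RecordingTarget:
    """Hypothetical stand-in for a backend plot object such as a matplotlib Axes."""

    def __init__(self, name):
        self.name = name
        self.calls = []

    def axvline(self, x, **kwargs):
        # record the call instead of drawing anything
        self.calls.append((x, kwargs))
        return (self.name, x)


def axvline(da, target, **kwargs):
    """Same signature as in the tutorial: data first, plotting target second."""
    return target.axvline(0, **kwargs)


def toy_map(fun, data, targets, **kwargs):
    """Toy loop (not the library's code): call fun once per plot with shared kwargs."""
    return {key: fun(data[key], targets[key], **kwargs) for key in targets}


data = {"home": [0.1, 0.2], "atts": [0.3, 0.4]}  # made-up values
targets = {key: RecordingTarget(key) for key in data}
results = toy_map(axvline, data, targets, color="red")
```

Each target receives exactly one call with the same keyword arguments, which mirrors the behaviour described for keyword arguments passed to map.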

Legends#

PlotCollection also provides a method to automatically generate legends for the plots.

Warning

The API of the add_legend method is still quite experimental.

For properties that are shared by all variables, generating the legend is relatively straightforward. Mappings are unique, and we have sensible defaults available: coordinate values as legend entries and the dimension name as the legend title.

pc = plot_trace_dist(idata, var_names=["home", "intercept", "atts", "defs"])
pc.add_legend("chain");

It is also possible to have properties that depend on both the data variable and dimensions. In general, aesthetic mappings can be complex, with dependencies on arbitrary combinations of variables and dimensions. There can even be combinations of aesthetics which map to combinations of variables and dimensions!

Moreover, we sometimes use aesthetic mappings as a way to distinguish different visual elements or groups of visual elements. In these cases we might not need a legend, or we might even prefer to omit it.

The example we have just seen, which we’ll also repeat below, has a bit of everything. On one hand, we might need two legends: one for the color encoding of variable+team and another for the linestyle encoding of the chain. On the other hand, in this particular example (and in general when using plot_trace_dist as a diagnostic) we don’t really care about the specific encodings. The different colors for different variable+team combinations allow us to check whether lines of the same color overlap, meaning all chains have converged to the same distribution. Knowing whether the yellow line is atts for the Italy team or defs for the Scotland team is irrelevant to our goal of diagnosing convergence. So is knowing whether the dashed line represents chain 0 or chain 3.

Therefore, it would be OK to skip both legends altogether. Using PlotCollection, you can choose in a couple of lines the option that best suits your particular use case: no legend, a legend for a subset of the mappings, or one legend for each aesthetic mapping.

To add a legend for a mapping over multiple dimensions, we use a sequence of dimensions (with __variable__ also being valid) as the first argument. Here we also add matplotlib-specific kwargs to get the legend to look better:

pc = plot_trace_dist(
    idata,
    var_names=["home", "intercept", "atts", "defs"],
)
pc.add_legend(("__variable__", "team"), loc="outside upper center", fontsize=10, ncols=5);

In this example, each combination of variable and dimension is encoded in a single aesthetic, but that is not necessarily the case. The Advanced examples section has some examples with multiple aesthetics mapped to the same dimension combination. In such cases, the legend requested for that dimension combination shows all the aesthetics mapped to it.

See also

  • coming soon: In depth explanation of how faceting and aesthetics are handled

  • coming soon: backend specific cookbooks

  • Create your own figure with PlotCollection shows how to create and fill visualizations from scratch using PlotCollection. There is nothing wrong with modifying an existing figure through its PlotCollection. In fact, the design of arviz-plots encourages it and goes to great lengths to make sure it is possible. That being said, if you find yourself constantly needing to modify the generated PlotCollections, it might be worth writing your own specific plotting functions.