# Source code for arviz_plots.backend.bokeh

"""Bokeh interface layer."""
# pylint: disable=protected-access

import math
import warnings

import bokeh.colors.named as named_colors
import numpy as np
from bokeh.colors import Color
from bokeh.io.export import export_png, export_svg
from bokeh.layouts import GridBox, gridplot
from bokeh.models import (
    BoxAnnotation,
    ColumnDataSource,
    CustomJSTickFormatter,
    FixedTicker,
    GridPlot,
    Range1d,
    Span,
    Title,
)
from bokeh.plotting import figure as _figure
from bokeh.plotting import output_file, save
from bokeh.plotting import show as _show

from ..none import get_default_aes as get_agnostic_default_aes
from .legend import legend


class UnsetDefault:
    """Specific class to indicate an aesthetic hasn't been set.

    The class defines no attributes or methods; it only provides a distinct type
    for the module-level ``unset`` sentinel.
    """


# Module-level sentinel used as default value for aesthetics in this module.
# It is compared by identity (``value is unset``), never by equality.
unset = UnsetDefault()


def set_sqrt_yscale(target):
    """Transform existing plots on a figure to use sqrt(y) scale.

    The y data of the supported renderers is replaced by its square root, fixed
    ticks are placed at the square root of evenly spaced values and the tick
    formatter squares the tick positions back so labels show the original scale.

    Parameters
    ----------
    target : object
        Plotting target exposing ``renderers``, ``yaxis`` and ``y_range``
        (presumably a bokeh figure).

    Notes
    -----
    Only renderers backed by a `ColumnDataSource` are considered. Renderers whose
    data source already holds the ``original_y`` column (or both ``original_y_top``
    and ``original_y_bottom``) are skipped, so they are not transformed twice.
    """
    max_y = 0
    for renderer in target.renderers:
        ds = renderer.data_source
        if not isinstance(ds, ColumnDataSource):
            continue
        # ``default=0`` keeps empty columns from raising in ``max``
        if "y" in ds.data:
            max_y = max(max_y, max(ds.data["y"], default=0))
        elif "y_top" in ds.data:
            max_y = max(max_y, max(ds.data["y_top"], default=0))

    # Enforce at least one tick interval of at least one unit. Without the lower
    # bounds a figure with no (positive) y data yields ``num_ticks == 0`` and the
    # division below raises ZeroDivisionError.
    num_ticks = max(1, min(5, math.ceil(max_y / 2)))
    step_size = max(1, round(math.ceil(max_y) / num_ticks))
    end_tick = num_ticks * step_size
    # tick positions live in sqrt space, labels are squared back by the formatter
    ticks = [i**0.5 for i in range(0, end_tick + 1, step_size)]

    target.yaxis.formatter = CustomJSTickFormatter(
        code="""
        return (tick ** 2).toFixed(0)
    """
    )
    target.yaxis.ticker = FixedTicker(ticks=ticks)
    target.y_range.start = 0
    target.y_range.end = ticks[-1]

    for renderer in target.renderers:
        ds = renderer.data_source
        if not isinstance(ds, ColumnDataSource):
            continue
        glyph = renderer.glyph

        if hasattr(glyph, "y"):
            # glyphs with a single y coordinate
            if "original_y" in ds.data:
                continue

            original_y = ds.data[glyph.y]
            ds.data["original_y"] = original_y
            ds.data["y_sqrt"] = np.sqrt(original_y)

            glyph.y = "y_sqrt"
        elif hasattr(glyph, "y0") and hasattr(glyph, "y1"):
            # glyphs defined by two y coordinates (bottom ``y0`` and top ``y1``)
            if "original_y_top" in ds.data and "original_y_bottom" in ds.data:
                continue

            original_y_top = ds.data[glyph.y1]
            original_y_bottom = ds.data[glyph.y0]

            ds.data["original_y_top"] = original_y_top
            ds.data["y_top_sqrt"] = np.sqrt(original_y_top)

            ds.data["original_y_bottom"] = original_y_bottom
            ds.data["y_bottom_sqrt"] = np.sqrt(original_y_bottom)

            glyph.y0 = "y_bottom_sqrt"
            glyph.y1 = "y_top_sqrt"


def get_hex_from_color_name(color_name: str) -> str:
    """Return the HEX code of a color available in ``bokeh.colors.named``.

    Parameters
    ----------
    color_name : str
        Name of the color. It is lower-cased before the lookup.

    Returns
    -------
    str
        Result of calling ``to_hex()`` on the named color object.

    Raises
    ------
    ValueError
        If the lookup or the conversion raise an AttributeError.
    """
    try:
        return getattr(named_colors, color_name.lower()).to_hex()
    except AttributeError as exc:
        message = f"Color '{color_name}' is not a valid Bokeh named color."
        raise ValueError(message) from exc


def get_background_color():
    """Get the background color of the current Bokeh document.

    Returns
    -------
    str
        The ``background_fill_color`` set for ``Plot`` in the theme of the current
        document. Values starting with ``#`` are returned unchanged and
        any other value is converted with `get_hex_from_color_name`.
        ``"#ffffff"`` is returned if bokeh can't be imported or the theme does
        not define the property.

    Raises
    ------
    ValueError
        If the theme value is neither a ``#``-prefixed string nor a valid Bokeh
        named color.
    """
    try:
        from bokeh.io import curdoc

        bg_color = curdoc().theme._json["attrs"]["Plot"]["background_fill_color"]
    except (ImportError, KeyError):
        return "#ffffff"
    # Themes may already store a hex code, which is not a named color and would
    # make the name lookup below fail with ValueError.
    if isinstance(bg_color, str) and bg_color.startswith("#"):
        return bg_color
    return get_hex_from_color_name(bg_color)


# generation of default values for aesthetics
def get_default_aes(aes_key, n, kwargs=None):
    """Generate `n` *bokeh valid* default values for a given aesthetics keyword.

    Parameters
    ----------
    aes_key : str
        Aesthetic keyword.
    n : int
        Number of values requested, forwarded to the backend agnostic generator.
    kwargs : dict, optional
        If it contains `aes_key` it is forwarded as is and no bokeh specific
        defaults are generated.

    Returns
    -------
    object
        Whatever the backend agnostic ``get_default_aes`` returns.
    """
    kwargs = {} if kwargs is None else kwargs
    if aes_key in kwargs:
        return get_agnostic_default_aes(aes_key, n, kwargs)

    if "color" in aes_key:
        # colors defined by the theme of the current document take priority
        try:
            from bokeh.io import curdoc

            theme_colors = curdoc().theme._json["attrs"]["Cycler"]["colors"]
        except (ImportError, KeyError):
            theme_colors = None
        if theme_colors is None:
            # fmt: off
            defaults = [
                '#3f90da', '#ffa90e', '#bd1f01', '#94a4a2', '#832db6',
                '#a96b59', '#e76300', '#b9ac70', '#717581', '#92dadd'
            ]
            # fmt: on
        else:
            defaults = theme_colors
    elif aes_key in {"linestyle", "line_dash"}:
        defaults = ["solid", "dashed", "dotted", "dashdot"]
    elif aes_key == "marker":
        defaults = ["circle", "cross", "triangle", "x", "diamond", "square", "dot"]
    else:
        # no bokeh specific values for this aesthetic
        return get_agnostic_default_aes(aes_key, n)

    return get_agnostic_default_aes(aes_key, n, {aes_key: defaults})


def scale_fig_size(figsize, rows=1, cols=1, figsize_units=None):
    """Scale figure properties according to figsize, rows and cols.

    Parameters
    ----------
    figsize : (float, float) or None
        Size of figure in `figsize_units`. If ``None`` a default size is
        computed from `rows` and `cols`.
    rows : int
        Number of rows
    cols : int
        Number of columns
    figsize_units : {"inches", "dots"}, optional
        Defaults to "dots". Ignored if `figsize` is ``None``

    Returns
    -------
    figsize : (float, float)
        Size of figure in dots

    Raises
    ------
    ValueError
        If `figsize` is given and `figsize_units` is not "inches" nor "dots".

    Warns
    -----
    UserWarning
        If `figsize` is given in inches, as dpi=100 is assumed for the conversion.
    """
    if figsize_units is None:
        figsize_units = "dots"
    if figsize is None:
        width = cols * (400 if cols < 4 else 250)
        height = 100 * (rows + 1) ** 1.1
        figsize_units = "dots"
    else:
        width, height = figsize
    if figsize_units == "inches":
        width *= 100
        height *= 100
        # report the converted size: it is the value that has to be combined with
        # figsize_units='dots' to get the same figure
        warnings.warn(
            f"Assuming dpi=100. Use figsize_units='dots' and figsize={(width, height)} "
            "to stop seeing this warning"
        )
    elif figsize_units != "dots":
        raise ValueError(f"figsize_units must be 'dots' or 'inches', but got {figsize_units}")

    return (width, height)


# object creation and i/o
def show(figure):
    """Show the provided bokeh layout.

    Parameters
    ----------
    figure : object
        Forwarded unchanged to :func:`bokeh.plotting.show`.
    """
    _show(figure)
def savefig(figure, path, **kwargs):
    """Save the figure to a file.

    Parameters
    ----------
    figure : bokeh.plotting.Figure
        The figure to save.
    path : str or pathlib.Path
        The path to the file where the figure will be saved. Its suffix selects
        the output format.
    **kwargs : dict, optional
        Additional keyword arguments passed to the export or save function
        depending on the file extension.

    Raises
    ------
    ValueError
        If the suffix of `path` is not one of .png, .svg or .html
    """
    from pathlib import Path

    # accept plain strings too, ``str`` has no ``suffix`` attribute
    target_path = Path(path)
    suffix = target_path.suffix
    if suffix == ".png":
        export_png(figure, filename=target_path, **kwargs)
    elif suffix == ".svg":
        export_svg(figure, filename=target_path, **kwargs)
    elif suffix == ".html":
        output_file(target_path)
        save(figure, **kwargs)
    else:
        raise ValueError(
            f"Unsupported file format: {path}. Supported formats are .png, .svg, and .html."
        )


def get_figsize(plot_collection):
    """Get the size of the :term:`figure` element and its units.

    Parameters
    ----------
    plot_collection : object
        Object whose ``viz`` mapping holds the "figure" and "plot" elements.

    Returns
    -------
    figsize : (float, float)
        Width and height.
    figsize_units : str
        Always "dots".
    """
    figure = plot_collection.viz["figure"].item()
    if figure is None:
        # single plot without an enclosing layout
        plot = plot_collection.viz["plot"].item()
        return (plot.width, plot.height), "dots"
    if isinstance(figure, (GridBox, GridPlot)):
        gridbox = figure
    elif isinstance(figure, tuple):
        gridbox = figure[1]
    else:
        gridbox = figure.children[1]
    if not isinstance(gridbox, (GridBox, GridPlot)):
        # fallback size, returned with units like every other branch so callers
        # can always unpack ``figsize, units``
        return (800, 800), "dots"
    # children are (plot, row, col) triplets: add heights along the first column
    # and widths along the first row
    row_heights_sum = np.sum([plot.height for plot, _, col in gridbox.children if col == 0])
    col_widths_sum = np.sum([plot.width for plot, row, _ in gridbox.children if row == 0])
    return (col_widths_sum, row_heights_sum), "dots"
def create_plotting_grid(
    number,
    rows=1,
    cols=1,
    *,
    figsize=None,
    figsize_units="inches",
    squeeze=True,
    sharex=False,
    sharey=False,
    polar=False,
    width_ratios=None,
    height_ratios=None,
    plot_hspace=None,
    subplot_kws=None,
    **kwargs,
):
    """Create a figure with a grid of plotting targets in it.

    Parameters
    ----------
    number : int
        Number of axes required
    rows, cols : int, default 1
        Number of rows and columns.
    figsize : (float, float), optional
        Size of the figure in `figsize_units`. It overwrites the values for "width"
        and "height" in `subplot_kws` if any.
    figsize_units : {"inches", "dots"}, default "inches"
        Units in which `figsize` is given.
    squeeze : bool, default True
        If True, a grid with a single element returns that element directly and
        larger grids are squeezed.
    sharex, sharey : bool or {"row", "col"}, default False
        ``True`` shares the range of the first plot with all plots, "row" and
        "col" share it within each row or column respectively.
    polar : bool
        If True, "x_axis_type" and "y_axis_type" default to ``None``.
    width_ratios, height_ratios : iterable of float, optional
        Relative widths of the columns and heights of the rows.
        Must have length `cols` and `rows` respectively.
    plot_hspace : optional
        Default value for "min_border_left" and "min_border_right".
    subplot_kws : dict, optional
        Passed to :func:`~bokeh.plotting.figure`
    **kwargs: dict, optional
        Passed to :func:`~bokeh.layouts.gridplot`

    Returns
    -------
    `~bokeh.layouts.gridplot` or None
    `~bokeh.plotting.figure` or ndarray of `~bokeh.plotting.figure`

    Raises
    ------
    ValueError
        If `figsize_units` is not valid or the ratios have the wrong length.
    """
    if subplot_kws is None:
        subplot_kws = {}
    subplot_kws = subplot_kws.copy()

    figures = np.empty((rows, cols), dtype=object)

    if polar:
        subplot_kws.setdefault("x_axis_type", None)
        subplot_kws.setdefault("y_axis_type", None)

    if plot_hspace is not None:
        subplot_kws.setdefault("min_border_left", plot_hspace)
        subplot_kws.setdefault("min_border_right", plot_hspace)

    if figsize is not None:
        if figsize_units == "inches":
            figsize = (figsize[0] * 100, figsize[1] * 100)
            warnings.warn(
                f"Assuming dpi=100. Use figsize_units='dots' and figsize={figsize} "
                "to stop seeing this warning"
            )
        elif figsize_units != "dots":
            raise ValueError(f"figsize_units must be 'dots' or 'inches', but got {figsize_units}")
        subplot_kws["width"] = int(np.ceil(figsize[0] / cols))
        subplot_kws["height"] = int(np.ceil(figsize[1] / rows))

    plot_widths = None
    if width_ratios is not None:
        if len(width_ratios) != cols:
            raise ValueError("width_ratios must be an iterable of length cols")
        plot_width = subplot_kws.get("width", 600)
        figure_width = plot_width * cols
        width_ratios = np.array(width_ratios, dtype=float)
        width_ratios /= width_ratios.sum()
        plot_widths = np.ceil(figure_width * width_ratios).astype(int)

    plot_heights = None
    if height_ratios is not None:
        if len(height_ratios) != rows:
            raise ValueError("height_ratios must be an iterable of length rows")
        plot_height = subplot_kws.get("height", 600)
        figure_height = plot_height * rows
        height_ratios = np.array(height_ratios, dtype=float)
        height_ratios /= height_ratios.sum()
        plot_heights = np.ceil(figure_height * height_ratios).astype(int)

    shared_xrange = {}
    shared_yrange = {}
    for row in range(rows):
        for col in range(cols):
            # Skip trailing empty cells before anything else: their row/column may
            # have no shared range registered, so the lookups below would raise
            # KeyError.
            if row * cols + (col + 1) > number:
                figures[row, col] = None
                continue

            subplot_kws_i = subplot_kws.copy()
            if col != 0 and sharex == "row":
                subplot_kws_i["x_range"] = shared_xrange[row]
            if row != 0 and sharex == "col":
                subplot_kws_i["x_range"] = shared_xrange[col]
            if col != 0 and sharey == "row":
                subplot_kws_i["y_range"] = shared_yrange[row]
            if row != 0 and sharey == "col":
                subplot_kws_i["y_range"] = shared_yrange[col]
            # The per-plot size goes into the per-plot copy; writing it to the base
            # dict would only take effect on the following subplot.
            if plot_widths is not None:
                subplot_kws_i["width"] = int(plot_widths[col])
            if plot_heights is not None:
                subplot_kws_i["height"] = int(plot_heights[row])

            p = _figure(**subplot_kws_i)
            figures[row, col] = p

            if row == 0 and col == 0:
                # with sharex/sharey=True every later plot reuses the ranges of the
                # first one, so they are stored in the base keyword arguments
                if sharex is True:
                    subplot_kws["x_range"] = p.x_range
                if sharey is True:
                    subplot_kws["y_range"] = p.y_range
            if col == 0:
                if sharex == "row":
                    shared_xrange[row] = p.x_range
                if sharey == "row":
                    shared_yrange[row] = p.y_range
            if row == 0:
                if sharex == "col":
                    shared_xrange[col] = p.x_range
                if sharey == "col":
                    shared_yrange[col] = p.y_range

    if squeeze and figures.size == 1:
        return None, figures[0, 0]
    layout = gridplot(figures.tolist(), **kwargs)
    return layout, figures.squeeze() if squeeze else figures
# helper functions
def _filter_kwargs(kwargs, artist_kws):
    """Merge `artist_kws` with the entries of `kwargs` whose value is not ``unset``.

    Keys present in both are taken from `kwargs`.
    """
    merged = dict(artist_kws)
    merged.update((key, value) for key, value in kwargs.items() if value is not unset)
    return merged


def _float_or_str_size(size):
    """Bokeh only accepts string sizes with units.

    Convert float sizes to string ones in px units. ``unset`` and strings are
    returned unchanged.
    """
    if size is unset or isinstance(size, str):
        return size
    return f"{size:.0f}px"


# "geoms"
def hist(
    y,
    l_e,
    r_e,
    target,
    *,
    bottom=0,
    color=unset,
    facecolor=unset,
    edgecolor=unset,
    alpha=unset,
    **artist_kws,
):
    """Interface to Bokeh for a histogram bar plot.

    `color` is used for `facecolor` and `edgecolor` when these are not set.
    Returns the output of ``target.quad``.
    """
    if color is not unset:
        facecolor = color if facecolor is unset else facecolor
        edgecolor = color if edgecolor is unset else edgecolor
    quad_kwargs = _filter_kwargs(
        {"bottom": bottom, "fill_color": facecolor, "line_color": edgecolor, "alpha": alpha},
        artist_kws,
    )
    return target.quad(top=y, left=l_e, right=r_e, **quad_kwargs)
def line(x, y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
    """Interface to bokeh for a line plot.

    `x` and `y` are converted to at least 1d arrays. Returns the output of
    ``target.line``.
    """
    x_values = np.atleast_1d(x)
    y_values = np.atleast_1d(y)
    line_kwargs = _filter_kwargs(
        {"color": color, "alpha": alpha, "line_width": width, "line_dash": linestyle},
        artist_kws,
    )
    return target.line(x_values, y_values, **line_kwargs)
def multiple_lines(
    x, y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws
):
    """Interface to bokeh for multiple lines.

    Parameters
    ----------
    x : iterable
        x values, shared by all lines.
    y : array-like with a ``T`` attribute
        y values. After transposing, iterating over it yields one line at a time.
    target : object
        Plotting target providing ``multi_line`` (presumably a bokeh figure).

    Returns
    -------
    object
        Output of ``target.multi_line``.

    Raises
    ------
    ValueError
        If any of the lines has a different number of points than `x`.
    """
    x = list(x)
    lines = [np.atleast_1d(yi) for yi in y.T]
    # Compare against the number of points of each line. Checking after `x` has
    # been replicated once per line could never fail.
    if any(len(line_y) != len(x) for line_y in lines):
        raise ValueError("x and y must have the same length")
    source = ColumnDataSource(data={"x": [list(x) for _ in lines], "y": lines})
    kwargs = {"line_color": color, "line_alpha": alpha, "line_width": width, "line_dash": linestyle}
    return target.multi_line(xs="x", ys="y", source=source, **_filter_kwargs(kwargs, artist_kws))
def scatter(
    x,
    y,
    target,
    *,
    size=unset,
    marker=unset,
    alpha=unset,
    color=unset,
    facecolor=unset,
    edgecolor=unset,
    width=unset,
    **artist_kws,
):
    """Interface to bokeh for a scatter plot.

    `color` is used for `facecolor` and `edgecolor` when these are not set and
    `alpha` is applied to both line and fill. The ``"|"`` marker is drawn as a
    "dash" marker rotated by pi/2. Returns the output of ``target.scatter``.
    """
    if color is not unset:
        if facecolor is unset:
            facecolor = color
        if edgecolor is unset:
            edgecolor = color

    scatter_kwargs = _filter_kwargs(
        {
            "size": size,
            "marker": marker,
            "line_alpha": alpha,
            "fill_alpha": alpha,
            "fill_color": facecolor,
            "line_color": edgecolor,
            "line_width": width,
        },
        artist_kws,
    )
    if marker == "|":
        scatter_kwargs["marker"] = "dash"
        scatter_kwargs["angle"] = np.pi / 2

    data = {"x": np.atleast_1d(x), "y": np.atleast_1d(y)}
    source = ColumnDataSource(data=data)
    return target.scatter(x="x", y="y", source=source, **scatter_kwargs)
def step(
    x,
    y,
    target,
    *,
    color=unset,
    alpha=unset,
    width=unset,
    linestyle=unset,
    step_mode=unset,
    **artist_kws,
):
    """Interface to bokeh for a step line.

    `step_mode` is passed to ``target.step`` as ``mode``. Returns the output of
    ``target.step``.
    """
    x_values = np.atleast_1d(x)
    y_values = np.atleast_1d(y)
    step_kwargs = _filter_kwargs(
        {
            "color": color,
            "alpha": alpha,
            "line_width": width,
            "line_dash": linestyle,
            "mode": step_mode,
        },
        artist_kws,
    )
    return target.step(x_values, y_values, **step_kwargs)
def text(
    x,
    y,
    string,
    target,
    *,
    size=unset,
    alpha=unset,
    color=unset,
    vertical_align="middle",
    horizontal_align="center",
    **artist_kws,
):
    """Interface to bokeh for adding text to a plot.

    Numeric sizes are converted to px strings; `horizontal_align` and
    `vertical_align` are passed as ``text_align`` and ``text_baseline``.
    Returns the output of ``target.text``.
    """
    text_kwargs = {
        "text_font_size": _float_or_str_size(size),
        "alpha": alpha,
        "color": color,
        "text_align": horizontal_align,
        "text_baseline": vertical_align,
    }
    x_values = np.atleast_1d(x)
    y_values = np.atleast_1d(y)
    strings = np.atleast_1d(string)
    return target.text(x_values, y_values, strings, **_filter_kwargs(text_kwargs, artist_kws))
def fill_between_y(x, y_bottom, y_top, target, **artist_kws):
    """Fill the area between y_bottom and y_top.

    Boundaries with a single element are passed to ``target.varea`` as scalars,
    otherwise as arrays.
    """

    def _scalar_if_single(values):
        values = np.atleast_1d(values)
        return values.item() if values.size == 1 else values

    x = np.atleast_1d(x)
    lower = _scalar_if_single(y_bottom)
    upper = _scalar_if_single(y_top)
    return target.varea(x=x, y1=lower, y2=upper, **artist_kws)


def vline(x, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
    """Interface to bokeh for a vertical line spanning the whole axes.

    Adds a ``Span`` with "height" dimension to `target` and returns it.
    """
    span_kwargs = _filter_kwargs(
        {"line_color": color, "line_alpha": alpha, "line_width": width, "line_dash": linestyle},
        artist_kws,
    )
    vertical_span = Span(location=x, dimension="height", **span_kwargs)
    target.add_layout(vertical_span)
    return vertical_span


def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset, **artist_kws):
    """Interface to bokeh for a horizontal line spanning the whole axes.

    Adds a ``Span`` with "width" dimension to `target` and returns it.
    """
    span_kwargs = _filter_kwargs(
        {"line_color": color, "line_alpha": alpha, "line_width": width, "line_dash": linestyle},
        artist_kws,
    )
    horizontal_span = Span(location=y, dimension="width", **span_kwargs)
    target.add_layout(horizontal_span)
    return horizontal_span


def vspan(xmin, xmax, target, *, color=unset, alpha=unset, **artist_kws):
    """Interface to bokeh for a vertical shaded region spanning the whole axes.

    Adds a ``BoxAnnotation`` bounded on the left and right to `target` and returns it.
    """
    box_kwargs = _filter_kwargs({"fill_color": color, "fill_alpha": alpha}, artist_kws)
    box = BoxAnnotation(left=xmin, right=xmax, **box_kwargs)
    target.add_layout(box)
    return box


def hspan(ymin, ymax, target, *, color=unset, alpha=unset, **artist_kws):
    """Interface to bokeh for a horizontal shaded region spanning the whole axes.

    Adds a ``BoxAnnotation`` bounded at the bottom and top to `target` and returns it.
    """
    box_kwargs = _filter_kwargs({"fill_color": color, "fill_alpha": alpha}, artist_kws)
    box = BoxAnnotation(bottom=ymin, top=ymax, **box_kwargs)
    target.add_layout(box)
    return box


def ciliney(
    x,
    y_bottom,
    y_top,
    target,
    *,
    color=unset,
    alpha=unset,
    width=unset,
    linestyle=unset,
    **artist_kws,
):
    """Interface to bokeh for a line from y_bottom to y_top at given value of x.

    Returns the output of ``target.segment``.
    """
    segment_kwargs = {
        "color": color,
        "alpha": alpha,
        "line_width": width,
        "line_dash": linestyle,
    }
    columns = {
        "x": np.atleast_1d(x),
        "y_bottom": np.atleast_1d(y_bottom),
        "y_top": np.atleast_1d(y_top),
    }
    source = ColumnDataSource(data=columns)
    return target.segment(
        x0="x",
        x1="x",
        y0="y_bottom",
        y1="y_top",
        source=source,
        **_filter_kwargs(segment_kwargs, artist_kws),
    )


# general plot appearance
def title(string, target, *, size=unset, color=unset, **artist_kws):
    """Interface to bokeh for adding a title to a plot.

    Sets ``target.title`` to a new ``Title`` and returns ``target.title``.
    Numeric sizes are converted to px strings.
    """
    title_kwargs = _filter_kwargs(
        {"text_font_size": _float_or_str_size(size), "text_color": color},
        artist_kws,
    )
    target.title = Title(text=string, **title_kwargs)
    return target.title
def ylabel(string, target, *, size=unset, color=unset, **artist_kws):
    """Interface to bokeh for adding a label to the y axis.

    Every property is set on ``target.yaxis`` with the ``axis_label_`` prefix.
    """
    label_props = {"text_font_size": _float_or_str_size(size), "text_color": color}
    target.yaxis.axis_label = string
    label_props = _filter_kwargs(label_props, artist_kws)
    for prop, prop_value in label_props.items():
        setattr(target.yaxis, f"axis_label_{prop}", prop_value)
def xlabel(string, target, *, size=unset, color=unset, **artist_kws):
    """Interface to bokeh for adding a label to the x axis.

    Every property is set on ``target.xaxis`` with the ``axis_label_`` prefix.
    """
    label_props = {"text_font_size": _float_or_str_size(size), "text_color": color}
    target.xaxis.axis_label = string
    label_props = _filter_kwargs(label_props, artist_kws)
    for prop, prop_value in label_props.items():
        setattr(target.xaxis, f"axis_label_{prop}", prop_value)
def xticks(ticks, labels, target, *, rotation=unset, **artist_kws):
    """Interface to bokeh for setting ticks and labels of the x axis.

    `rotation` is converted with ``math.radians`` and set as the major label
    orientation. Tick values exposing ``item`` are converted with it before being
    used as keys of the label overrides.
    """
    target.xaxis.ticker = ticks
    if labels is not None:
        overrides = {}
        for tick, label in zip(ticks, labels):
            tick_key = tick.item() if hasattr(tick, "item") else tick
            overrides[tick_key] = label
        target.xaxis.major_label_overrides = overrides
    if rotation is not unset:
        rotation = math.radians(rotation)
    label_props = _filter_kwargs({"orientation": rotation}, artist_kws)
    for prop, prop_value in label_props.items():
        setattr(target.xaxis, f"major_label_{prop}", prop_value)
def yticks(ticks, labels, target, *, rotation=unset, **artist_kws):
    """Interface to bokeh for setting ticks and labels of the y axis.

    `rotation` is converted with ``math.radians`` and set as the major label
    orientation. Tick values exposing ``item`` are converted with it before being
    used as keys of the label overrides.
    """
    target.yaxis.ticker = ticks
    if labels is not None:
        overrides = {}
        for tick, label in zip(ticks, labels):
            tick_key = tick.item() if hasattr(tick, "item") else tick
            overrides[tick_key] = label
        target.yaxis.major_label_overrides = overrides
    if rotation is not unset:
        rotation = math.radians(rotation)
    label_props = _filter_kwargs({"orientation": rotation}, artist_kws)
    for prop, prop_value in label_props.items():
        setattr(target.yaxis, f"major_label_{prop}", prop_value)
def xlim(lims, target, **artist_kws):
    """Interface to bokeh for setting limits for the x axis.

    Replaces ``target.x_range`` with a new ``Range1d`` built from `lims`.
    """
    new_range = Range1d(*lims, **artist_kws)
    target.x_range = new_range


def ylim(lims, target, **artist_kws):
    """Interface to bokeh for setting limits for the y axis.

    Replaces ``target.y_range`` with a new ``Range1d`` built from `lims`.
    """
    new_range = Range1d(*lims, **artist_kws)
    target.y_range = new_range
def ticklabel_props(target, *, axis="both", size=unset, color=unset, **artist_kws):
    """Interface to bokeh for setting ticks size.

    Every property is set with the ``major_label_`` prefix on the y axis, the x
    axis or both depending on `axis`.
    """
    label_props = _filter_kwargs(
        {"text_font_size": _float_or_str_size(size), "text_color": color},
        artist_kws,
    )
    for prop, prop_value in label_props.items():
        for axis_key, axis_attr in (("y", "yaxis"), ("x", "xaxis")):
            if axis in {axis_key, "both"}:
                setattr(getattr(target, axis_attr), f"major_label_{prop}", prop_value)
def set_ticklabel_visibility(target, *, axis="both", visible=True):
    """Set the visibility of tick labels on a Bokeh plot.

    Parameters
    ----------
    target : object
        Plotting target whose ``xaxis`` and ``yaxis`` are iterables of axes.
    axis : {"x", "y", "both"}, default "both"
    visible : bool, default True

    Raises
    ------
    ValueError
        If `axis` is not one of the valid options.
    """
    if axis not in ("x", "y", "both"):
        raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")

    # Labels are hidden through their font size, "0pt" effectively hides them.
    if visible:
        font_size = "1em"
    else:
        font_size = "0pt"

    axis_attrs = []
    if axis != "y":
        axis_attrs.append("xaxis")
    if axis != "x":
        axis_attrs.append("yaxis")
    for axis_attr in axis_attrs:
        for ax in getattr(target, axis_attr):
            ax.major_label_text_font_size = font_size
def remove_ticks(target, *, axis="y"):
    """Interface to bokeh for removing ticks from a plot.

    For the selected axes ("y", "x" or "both") the major tick lengths are set to
    0 and the major label font size to "0pt". Any other value of `axis` leaves
    `target` untouched.
    """
    for axis_key, axis_attr in (("y", "yaxis"), ("x", "xaxis")):
        if axis not in {axis_key, "both"}:
            continue
        selected = getattr(target, axis_attr)
        selected.major_tick_out = 0
        selected.major_tick_in = 0
        selected.major_label_text_font_size = "0pt"
def remove_axis(target, axis="y"):
    """Interface to bokeh for removing axis from a plot.

    Sets ``visible`` to False on ``target.yaxis``, ``target.xaxis`` or
    ``target.axis`` for `axis` equal to "y", "x" or "both" respectively.

    Raises
    ------
    ValueError
        If `axis` is not one of the valid options.
    """
    if axis not in ("x", "y", "both"):
        raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")
    axis_attr = "axis" if axis == "both" else f"{axis}axis"
    getattr(target, axis_attr).visible = False
def set_y_scale(target, scale):
    """Interface to bokeh for setting the y scale of a plot.

    Only the "sqrt" scale triggers any action, other values are ignored.
    """
    if scale != "sqrt":
        return
    set_sqrt_yscale(target)


def grid(target, axis, color):
    """Interface to bokeh for setting a grid in any axis.

    Sets the grid line color of ``target.ygrid``, ``target.xgrid`` or both
    for `axis` equal to "y", "x" or "both" respectively.
    """
    for axis_key, grid_attr in (("y", "ygrid"), ("x", "xgrid")):
        if axis in [axis_key, "both"]:
            getattr(target, grid_attr).grid_line_color = color