Source code for deepcave.utils.styled_plotty

# Copyright 2021-2024 The DeepCAVE Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# noqa: D400
"""
# Styled Plotty

This module provides utilities for styling and customizing different plots with plotly.
For this, it uses plotly as well as dash.
"""

from typing import Any, Callable, List, Optional, Tuple, Union

import itertools
import re

import numpy as np
import plotly.express as px
import plotly.graph_objs as go
from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter,
    Constant,
    Hyperparameter,
    IntegerHyperparameter,
)
from dash import html
from dash.development.base_component import Component

from deepcave import interactive
from deepcave.constants import CONSTANT_VALUE, NAN_LABEL, NAN_VALUE
from deepcave.runs import AbstractRun
from deepcave.utils.logs import get_logger

logger = get_logger(__name__)


@interactive
def save_image(figure: go.Figure, name: str) -> None:
    """
    Save a plotly figure as an image.

    Nothing is written unless ``config.SAVE_IMAGES`` is enabled. The image is
    written to ``config.CACHE_DIR / "figures" / name`` with a width of 500 px and
    a 16:9 aspect ratio. Missing parent directories are created.

    Parameters
    ----------
    figure : go.Figure
        Plotly figure.
    name : str
        Name of the image with extension. Will be automatically saved to the cache.
    """
    from deepcave import config

    if not config.SAVE_IMAGES:
        return

    ratio = 16 / 9
    width = 500
    height = int(width / ratio)

    path = config.CACHE_DIR / "figures" / name
    # The "figures" sub-folder (or a sub-folder contained in `name`) may not exist yet
    # in a fresh cache; writing the image into a missing directory would fail.
    path.parent.mkdir(parents=True, exist_ok=True)

    figure.write_image(path, width=width, height=height)
    logger.info(f"Saved figure {name} to {path}.")


def hex_to_rgb(hex_string: str) -> Tuple[int, int, int]:
    """
    Convert a hex color string to a tuple of RGB values.

    Requires the format ``#RRGGBB`` including the leading ``#``, e.g.:
    #000000
    #ff00ff

    Parameters
    ----------
    hex_string : str
        The hex string to be converted.

    Returns
    -------
    Tuple[int, int, int]
        A Tuple of the converted RGB values.

    Raises
    ------
    ValueError
        If the hex string is not exactly 7 characters long.
        If the hex string does not start with "#".
        If there are invalid characters in the hex string.
    """
    if len(hex_string) != 7:
        raise ValueError(f"Invalid length for #{hex_string}")
    if not hex_string.startswith("#"):
        raise ValueError(f"Missing leading '#' in {hex_string}")

    digits = hex_string[1:]
    # Validate explicitly: `int(..., 16)` alone would also accept signs or whitespace.
    if any(c not in "0123456789ABCDEF" for c in digits.upper()):
        raise ValueError(f"Invalid character in #{hex_string}")

    return int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)


def get_color(id_: int, alpha: float = 1) -> Union[str, Tuple[float, float, float, float]]:
    """
    Get an RGBA color string for a given ID.

    IDs below ``len(px.colors.qualitative.Plotly)`` (10 colors as of Plotly 5.3.1)
    use the ``Plotly`` qualitative palette. Larger IDs continue with the
    ``Alphabet`` palette and wrap around once it is exhausted, so colors repeat
    instead of raising an ``IndexError`` for large IDs. Negative IDs index the
    ``Plotly`` palette from its end (Python sequence indexing).

    Parameters
    ----------
    id_ : int
        ID for retrieving a specific color.
    alpha : float, optional
        Alpha value for the color, by default 1.

    Returns
    -------
    Union[str, Tuple[float, float, float, float]]
        The color from the color palette, formatted as ``"rgba(r, g, b, alpha)"``.
    """
    primary = px.colors.qualitative.Plotly
    secondary = px.colors.qualitative.Alphabet

    if id_ < len(primary):
        color = primary[id_]
    else:
        # Wrap around so callers requesting many colors (e.g. one per distinct
        # heatmap value) get repeated colors rather than an IndexError.
        color = secondary[(id_ - len(primary)) % len(secondary)]

    r, g, b = hex_to_rgb(color)
    return f"rgba({r}, {g}, {b}, {alpha})"


def get_discrete_heatmap(
    x: List[Union[float, int]], y: List[int], values: List[Any], labels: List[Any]
) -> go.Heatmap:
    """
    Generate a heatmap with a discrete colorscale from a nested list of values.

    Every distinct value gets its own color band. The distinct values are sorted;
    the k-th smallest of N distinct values is encoded as ``k / N`` and the colorbar
    shows the matching label in the middle of each band. The input ``values`` are
    not modified.

    Parameters
    ----------
    x : List[Union[float, int]]
        List of values that present the x-axis of the heatmap.
    y : List[int]
        List of values that present the y-axis of the heatmap.
    values : List[Any]
        Contains the data values for the heatmap (nested, one inner sequence per row).
        The values must be hashable.
    labels : List[Any]
        Contains the labels corresponding to the values (same nesting as `values`).

    Returns
    -------
    go.Heatmap
        A Plotly Heatmap object corresponding to the input.
    """
    # Distinct values in first-seen order; the first label seen for a value wins.
    first_label: dict = {}
    for value, label in zip(itertools.chain(*values), itertools.chain(*labels)):
        if value not in first_label:
            first_label[value] = label

    unique_values = list(first_label)
    unique_labels = list(first_label.values())

    order = np.argsort(np.array(unique_values))
    sorted_values = [unique_values[idx] for idx in order]
    sorted_labels = [unique_labels[idx] for idx in order]

    n = len(sorted_values)
    # The k-th smallest value is encoded as k / n, the lower edge of the k-th band.
    encoded = {value: k / n for k, value in enumerate(sorted_values)}
    # Build a fresh matrix: writing into `values` would mutate the caller's data.
    z = [[encoded[value] for value in row] for row in values]

    # Band edges 0, 1/n, ..., 1; band k spans [k/n, (k+1)/n] and has one color.
    boundaries = [k / n for k in range(n + 1)]
    colors = [get_color(k) for k in range(n)]

    discrete_colorscale = []
    for k, color in enumerate(colors):
        discrete_colorscale.extend([[boundaries[k], color], [boundaries[k + 1], color]])

    # Center each colorbar label inside its band.
    tickvals = [np.mean(boundaries[k : k + 2]) for k in range(n)]

    return go.Heatmap(
        x=[str(i) for i in x],
        y=[str(i) for i in y],
        z=z,
        showscale=True,
        colorscale=discrete_colorscale,
        colorbar={"tickvals": tickvals, "ticktext": sorted_labels, "tickmode": "array"},
        zmin=0,
        zmax=1,
    )


def prettify_label(label: Union[str, float, int]) -> str:
    """
    Take a label and prettify it.

    Floats (Python floats and NumPy floating scalars) are shortened:

    - if ``0 < |label| < 0.01`` the value is written in compact scientific
      notation with at most 2 decimals, e.g. ``0.001 -> "1e-3"``,
      ``0.00123 -> "1.23e-3"``, ``-0.0005 -> "-5e-4"``;
    - otherwise the value is rounded to 2 decimals.

    Every other label is converted with ``str``.

    Parameters
    ----------
    label : Union[str, float, int]
        Label, which should be prettified.

    Returns
    -------
    str
        Prettified label.
    """
    if not isinstance(label, (float, np.floating)):
        return str(label)

    # Use the magnitude so small negative values are treated like small positive ones.
    magnitude = str(abs(label))
    if magnitude.startswith("0.00") or "e-" in magnitude:
        mantissa, exponent = np.format_float_scientific(label, precision=2).split("e")
        # Drop trailing zeros and a dangling decimal point ("1.00" / "1." -> "1").
        if "." in mantissa:
            mantissa = mantissa.rstrip("0").rstrip(".")
        # int() strips the exponent's zero padding ("-03" -> -3).
        return f"{mantissa}e{int(exponent)}"

    # Round to 2 decimals
    return str(np.round(label, 2))


def get_hyperparameter_ticks(
    hp: Hyperparameter,
    additional_values: Optional[List] = None,
    ticks: int = 4,
    include_nan: bool = True,
) -> Tuple[List, List]:
    """
    Generate tick data for both tickvals and ticktext.

    Encoded hyperparameter values lie between 0 and 1. Showing every encoded value
    would clutter the axis, so for numerical hyperparameters only `ticks` (4 by
    default) evenly spaced positions are used, plus any `additional_values`.
    Categorical hyperparameters get one tick per choice and constants get a single
    tick; `ticks` and `additional_values` are ignored for both.

    Parameters
    ----------
    hp : Hyperparameter
        Hyperparameter to generate ticks from.
    additional_values : Optional[List], optional
        Additional encoded values, which are forced as ticks in addition.
        None, NaN and `NAN_VALUE` entries are skipped. By default, None.
    ticks : int, optional
        Number of evenly spaced ticks, by default 4.
    include_nan : bool, optional
        Whether "nan" as tick should be included. By default True.

    Returns
    -------
    Tuple[List, List]
        tickvals and ticktext. Every ticktext entry is passed through `prettify_label`.
    """
    # This is basically the inverse of `encode_config`.

    def _is_missing(value: Any) -> bool:
        # None, NaN and the NaN placeholder are never turned into ticks.
        return value is None or np.isnan(value) or value == NAN_VALUE

    tickvals: List[Any]
    labels: List[Any]

    if isinstance(hp, CategoricalHyperparameter):
        labels = list(hp.choices)
        n_choices = len(hp.choices)
        if n_choices == 1:
            tickvals = [0]
        else:
            # Spread the encoded choice indices over [0, 1].
            tickvals = [hp.to_vector(choice) / (n_choices - 1) for choice in hp.choices]
    elif isinstance(hp, Constant):
        tickvals = [CONSTANT_VALUE]
        labels = [hp.value]
    else:
        # `ticks` evenly spaced positions in the encoded space: 0, 1/(ticks-1), ..., 1.
        positions: List[Union[float, int]] = [0]
        positions.extend(step / (ticks - 1) for step in range(1, ticks - 1))
        positions.append(1)

        position_labels = [hp.to_value(position) for position in positions]
        extra_values = [] if additional_values is None else additional_values

        tickvals = []
        labels = []
        if isinstance(hp, IntegerHyperparameter):
            # Integer labels are rounded, so several positions can share one label.
            # Each label is re-encoded and only the first tick per encoded value is kept.
            def _add_label(label: Any) -> None:
                encoded = hp.to_vector(label)
                if encoded not in tickvals:
                    tickvals.append(encoded)
                    labels.append(label)

            for label in position_labels:
                _add_label(label)
            for value in extra_values:
                if not _is_missing(value):
                    _add_label(hp.to_value(value))
        else:
            tickvals.extend(positions)
            labels.extend(position_labels)
            for value in extra_values:
                if not _is_missing(value) and value not in tickvals:
                    tickvals.append(value)
                    labels.append(hp.to_value(value))

    ticktext = [prettify_label(label) for label in labels]

    if include_nan:
        tickvals.append(NAN_VALUE)
        ticktext.append(NAN_LABEL)

    return tickvals, ticktext


def get_hyperparameter_ticks_from_values(
    values: List, labels: List, forced: Optional[List[bool]] = None, ticks: int = 6
) -> Tuple[List, List]:
    """
    Generate tick data for both values and labels.

    The background is that you might have encoded data, but you don't want to show
    all of them. With this function, only `ticks` (6 by default) values are shown:
    the minimum, the maximum and the ``ticks - 2`` unique values closest to evenly
    spaced points in between. If there are at most `ticks` unique values, or the
    values are strings, all unique values are shown.

    Parameters
    ----------
    values : List
        List of values.
    labels : List
        List of labels. Must be the same size as `values`.
    forced : Optional[List[bool]], optional
        List of booleans. If True, displaying the particular tick is enforced.
        Independent of `ticks`.
    ticks : int, optional
        Number of ticks and labels to show. By default 6.

    Returns
    -------
    Tuple[List, List]
        Returns tickvals and ticktext as list.

    Raises
    ------
    RuntimeError
        If values contain both strings and non-strings.
    """
    assert len(values) == len(labels)

    # Keep first occurrences only; a pair is skipped if its value OR its label was
    # already seen, so no label is shown twice.
    unique_values: List = []
    unique_labels: List = []
    for value, label in zip(values, labels):
        if value not in unique_values and label not in unique_labels:
            unique_values.append(value)
            unique_labels.append(label)

    # String values cannot be spaced numerically, so all of them are shown.
    return_all = False
    for v1, v2 in zip(unique_values, unique_values[1:]):
        if isinstance(v1, str) or isinstance(v2, str):
            if type(v1) != type(v2):
                raise RuntimeError("Values have strings and non-strings.")
            return_all = True

    tickvals: List = []
    ticktext: List = []

    if return_all or len(unique_values) <= ticks:
        tickvals.extend(unique_values)
        ticktext.extend(unique_labels)
    else:
        # Min and max are always shown.
        for idx in (np.argmin(values), np.argmax(values)):
            tickvals.append(values[idx])
            ticktext.append(labels[idx])

        # Fill the remaining `ticks - 2` slots with the unique values closest to the
        # evenly spaced points min + i / (ticks - 1) * (max - min), i = 1 .. ticks - 2.
        min_v = np.min(values)
        max_v = np.max(values)
        candidates = np.asarray(unique_values)
        for i in range(1, ticks - 1):
            target = (i / (ticks - 1)) * (max_v - min_v) + min_v
            idx = int(np.abs(candidates - target).argmin())
            value = unique_values[idx]
            # Ignore if it is already in the list
            if value not in tickvals:
                tickvals.append(value)
                ticktext.append(unique_labels[idx])

    # Show forced ones
    if forced is not None:
        for value, label, force in zip(values, labels, forced):
            if force and value not in tickvals:
                tickvals.append(value)
                ticktext.append(label)

    return tickvals, ticktext


def get_hovertext_from_config(
    run: AbstractRun, config_id: int, budget: Optional[Union[int, float]] = None
) -> str:
    """
    Generate hover text with metrics for a configuration.

    The text (HTML) contains a link to the configuration, its average objective
    costs (with standard deviation when it is non-zero and known) on the given
    budget, and its hyperparameter values.

    Parameters
    ----------
    run : AbstractRun
        The run instance
    config_id : int
        The id of the configuration. Negative ids yield an empty string.
    budget : Optional[Union[int, float]]
        Budget to get the hovertext for.
        If no budget (None or -1) is given, the highest budget is shown in the header.
        By default None.

    Returns
    -------
    str
        The hover text string containing the configuration information.
    """
    if config_id < 0:
        return ""

    # Imported locally rather than at module level -- presumably to avoid a
    # circular import with the plugins package; verify before moving it.
    from deepcave.plugins.summary.configurations import Configurations

    # Retrieve the link for the config id
    link = Configurations.get_link(run, config_id)
    parts = [
        "<b>Configuration ID: ",
        f"<a href='{link}' style='color: #ffffff'>{int(config_id)}</a></b><br><br>",
    ]

    # It's also nice to see the metrics
    objectives = run.get_objectives()
    if budget is None or budget == -1:
        highest_budget = run.get_highest_budget(config_id)
        assert highest_budget is not None
        parts.append(
            f"<b>Objectives</b> (on highest found budget {round(highest_budget, 2)})<br>"
        )
    else:
        parts.append(f"<b>Objectives</b> (on budget {round(budget, 2)})<br>")

    try:
        avg_c, std_c = run.get_avg_costs(config_id, budget=budget)
        avg_costs: List[Optional[float]] = list(avg_c)
        std_costs: List[Optional[float]] = list(std_c)
    except ValueError:
        # No costs available: show a placeholder per objective instead.
        avg_costs = [None] * len(objectives)
        std_costs = [None] * len(objectives)

    for objective, cost, std_cost in zip(objectives, avg_costs, std_costs):
        # Only show the spread when it is known and non-zero (avoids "None ± None").
        if std_cost is None or std_cost == 0.0:
            parts.append(f"{objective.name}: {cost}<br>")
        else:
            parts.append(f"{objective.name}: {cost} ± {std_cost}<br>")

    parts.append("<br><b>Hyperparameters</b>:<br>")
    config = run.get_config(config_id)
    parts.extend(f"{k}: {v}<br>" for k, v in config.items())

    return "".join(parts)


def generate_config_code(register: Callable, variables: List[str]) -> List[Component]:
    """
    Generate HTML components to display a ConfigSpace code snippet.

    The snippet is rendered line by line as ``html.Code`` elements separated by
    ``html.Br``; indentation beyond the literal's base indent becomes a CSS left
    margin. A ``{{name}}`` placeholder whose ``name`` is in `variables` is replaced
    by an empty ``html.Code`` whose id is ``register(name, "children")``.

    Parameters
    ----------
    register : Callable
        A Callable for registering Dash components.
        The register_input function is located in the Plugin class.
    variables : List[str]
        A List of variable names.

    Returns
    -------
    List[Component]
        A List of Dash components.
    """
    code = """
    from ConfigSpace.configuration_space import ConfigurationSpace, Configuration

    # Create configspace
    cs = ConfigurationSpace.from_json({{path}})

    # Create config
    values = {{config_dict}}
    config = Configuration(cs, values=values)
    """
    placeholder = re.compile("{{(.+?)}}")

    rendered: List[Component] = []
    for row in code.splitlines():
        if not row:
            rendered.append(html.Br())
            continue

        # The literal is indented by 4 spaces, so 4 leading spaces means "no indent";
        # every further 4 spaces add 2em of margin.
        leading = len(row) - len(row.lstrip(" "))
        margin = {"margin-left": f"{(leading - 4) / 4 * 2}em"}

        found = placeholder.search(row)
        if found is not None and found.group(1) in variables:
            rendered.extend(
                [
                    # Text before the placeholder
                    html.Code(row[: found.start()], style=margin),
                    # Registered placeholder component
                    html.Code(id=register(found.group(1), "children")),
                    # Text after the placeholder
                    html.Code(row[found.end() :]),
                    html.Br(),
                ]
            )
        else:
            rendered.extend([html.Code(row, style=margin), html.Br()])

    # Drop the <br> produced by the literal's opening newline and the final <br>.
    return rendered[1 : len(rendered) - 1]