Source code for deepcave.plugins.hyperparameter.pdp

# Copyright 2021-2024 The DeepCAVE Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#  noqa: D400
"""
# PartialDependencies

This module provides utilities for generating Partial Dependency Plots (PDPs).

Provided utilities include getting input and output layout (filtered or non-filtered),
processing the data and loading the outputs.

## Classes
    - PartialDependencies: Generate a Partial Dependency Plot (PDP).

## Constants
    GRID_POINTS_PER_AXIS : int
    SAMPLES_PER_HP : int
    MAX_SAMPLES : int
    MAX_SHOWN_SAMPLES : int
"""

from typing import Any, Callable, Dict, List

import dash_bootstrap_components as dbc
import numpy as np
import plotly.graph_objs as go
from dash import dcc, html
from pyPDP.algorithms.pdp import PDP

from deepcave import config
from deepcave.evaluators.epm.random_forest_surrogate import RandomForestSurrogate
from deepcave.plugins.static import StaticPlugin
from deepcave.runs import Status
from deepcave.utils.layout import get_checklist_options, get_select_options, help_button
from deepcave.utils.styled_plotty import get_color, get_hyperparameter_ticks, save_image

GRID_POINTS_PER_AXIS = 20
SAMPLES_PER_HP = 10
MAX_SAMPLES = 10000
MAX_SHOWN_SAMPLES = 100


[docs] class PartialDependencies(StaticPlugin): """ Generate Partial Dependency Plots (PDP). Provided utilities include getting input and output layout (filtered or non-filtered), processing the data and loading the outputs. """ id = "pdp" name = "Partial Dependencies" icon = "fas fa-grip-lines" help = "docs/plugins/partial_dependencies.rst" activate_run_selection = True
[docs] @staticmethod def get_input_layout(register: Callable) -> List[dbc.Row]: """ Get the layout for the input block. Parameters ---------- register : Callable Method to register (user) variables. The register_input function is located in the Plugin superclass. Returns ------- List[dbc.Row] The layout for the input block. """ return [ dbc.Row( [ dbc.Col( [ dbc.Label("Objective"), dbc.Select( id=register("objective_id", ["value", "options"], type=int), placeholder="Select objective ...", ), ], md=6, ), dbc.Col( [ dbc.Label("Budget"), help_button( "Budget refers to the multi-fidelity budget. " "Combined budget means that the trial on the highest" " evaluated budget is used. \n " "Note: Selecting combined budget might be misleading if" " a time objective is used. Often, higher budget take " " longer to evaluate, which might negatively influence " " the results." ), dbc.Select( id=register("budget_id", ["value", "options"], type=int), placeholder="Select budget ...", ), ], md=6, ), ], className="mb-3", ), dbc.Row( [ dbc.Col( [ dbc.Label("Hyperparameter #1"), dbc.Select( id=register("hyperparameter_name_1", ["value", "options"]), placeholder="Select hyperparameter ...", ), ], md=6, ), dbc.Col( [ dbc.Label("Hyperparameter #2"), dbc.Select( id=register("hyperparameter_name_2", ["value", "options"]), placeholder="Select hyperparameter ...", ), ], md=6, ), ], ), ]
[docs] @staticmethod def get_filter_layout(register: Callable) -> List[Any]: """ Get the layout for the filter block. Parameters ---------- register : Callable Method to register (user) variables. The register_input function is located in the Plugin superclass. Returns ------- List[Any] The layout for the filter block. """ return [ dbc.Row( [ dbc.Col( [ html.Div( [ dbc.Label("Show confidence"), help_button("Displays the confidence bands."), dbc.Select( id=register("show_confidence", ["value", "options"]) ), ] ) ], md=6, ), dbc.Col( [ html.Div( [ dbc.Label("Show ICE curves"), help_button( "Displays the ICE curves from which the PDP curve is " "derivied." ), dbc.Select(id=register("show_ice", ["value", "options"])), ] ) ], md=6, ), ], ), ]
[docs] def load_inputs(self) -> Dict[str, Dict[str, Any]]: """ Load the content for the defined inputs in 'get_input_layout' and 'get_filter_layout'. This method is necessary to pre-load contents for the inputs. If the plugin is called for the first time, or there are no results in the cache, the plugin gets its content from this method. Returns ------- Dict[str, Dict[str, Any]] Content to be filled. """ return { "show_confidence": {"options": get_select_options(binary=True), "value": "false"}, "show_ice": {"options": get_select_options(binary=True), "value": "true"}, }
[docs] def load_dependency_inputs(self, run, previous_inputs, inputs) -> Dict[str, Any]: # type: ignore # noqa: E501 """ Work like 'load_inputs' but called after inputs have changed. Note ---- Only the changes have to be returned. The returned dictionary will be merged with the inputs. Parameters ---------- run The selected run. inputs Current content of the inputs. previous_inputs Previous content of the inputs. Not used in this specific function. Returns ------- Dict[str, Any] Dictionary with the changes. """ objective_names = run.get_objective_names() objective_ids = run.get_objective_ids() objective_options = get_select_options(objective_names, objective_ids) budgets = run.get_budgets(human=True) budget_ids = run.get_budget_ids() budget_options = get_checklist_options(budgets, budget_ids) hp_names = list(run.configspace.keys()) # Get selected values objective_value = inputs["objective_id"]["value"] budget_value = inputs["budget_id"]["value"] hp1_value = inputs["hyperparameter_name_1"]["value"] if objective_value is None: objective_value = objective_ids[0] budget_value = budget_ids[-1] hp1_value = hp_names[0] return { "objective_id": {"options": objective_options, "value": objective_value}, "budget_id": {"options": budget_options, "value": budget_value}, "hyperparameter_name_1": { "options": get_checklist_options(hp_names), "value": hp1_value, }, "hyperparameter_name_2": { "options": get_checklist_options([None] + hp_names), }, }
[docs] @staticmethod def process(run, inputs) -> Dict[str, Any]: # type: ignore """ Return raw data based on a run and the input data. Warning ------- The returned data must be JSON serializable. Note ---- The passed inputs are cleaned and therefore differ compared to 'load_inputs' or 'load_dependency_inputs'. Please see '_clean_inputs' for more information. Parameters ---------- run The run to process. inputs The input data. Returns ------- Dict[str, Any] A serialized dictionary. Raises ------ RuntimeError If the objective is None. """ # Surrogate hp_names = list(run.configspace.keys()) objective = run.get_objective(inputs["objective_id"]) budget = run.get_budget(inputs["budget_id"]) hp1 = inputs["hyperparameter_name_1"] hp2 = inputs["hyperparameter_name_2"] if objective is None: raise RuntimeError("Objective not found.") # Encode data df = run.get_encoded_data( objective, budget, specific=True, statuses=Status.SUCCESS, ) X = df[hp_names].to_numpy() Y = df[objective.name].to_numpy() # Let's initialize the surrogate surrogate_model = RandomForestSurrogate(run.configspace, seed=0) surrogate_model.fit(X, Y) # Prepare the hyperparameters selected_hyperparameters = [hp1] if hp2 is not None and hp2 != "": selected_hyperparameters += [hp2] num_samples = SAMPLES_PER_HP * len(X) # The samples are limited to max 10k if num_samples > MAX_SAMPLES: num_samples = MAX_SAMPLES # And finally call PDP pdp = PDP.from_random_points( surrogate_model, selected_hyperparameter=selected_hyperparameters, seed=0, num_grid_points_per_axis=GRID_POINTS_PER_AXIS, num_samples=num_samples, ) x = pdp.x_pdp.tolist() y = pdp.y_pdp.tolist() # The ICE curves have to be cut because it's too much data x_ice = pdp._ice.x_ice.tolist() y_ice = pdp._ice.y_ice.tolist() if len(x_ice) > MAX_SHOWN_SAMPLES: x_ice = x_ice[:MAX_SHOWN_SAMPLES] y_ice = y_ice[:MAX_SHOWN_SAMPLES] return { "x": x, "y": y, "variances": pdp.y_variances.tolist(), "x_ice": x_ice, "y_ice": y_ice, }
[docs] @staticmethod def get_output_layout(register: Callable) -> dcc.Graph: """ Get the layout for the output block. Parameters ---------- register : Callable Method to register outputs. The register_input function is located in the Plugin superclass. Returns ------- dcc.Graph Layout for the output block. """ return dcc.Graph( register("graph", "figure"), style={"height": config.FIGURE_HEIGHT}, config={"toImageButtonOptions": {"scale": config.FIGURE_DOWNLOAD_SCALE}}, )
[docs] @staticmethod def get_pdp_figure( # type: ignore run, inputs, outputs, show_confidence, show_ice, title=None, fontsize=None ) -> go.Figure: """ Create a figure of the Partial Dependency Plot (PDP). Parameters ---------- run The selected run. inputs Input and filter values from the user. outputs Raw output from the run. show_confidence Whether to show confidence in the plot. show_ice Whether to show ice curves in the plot. title Title of the plot. fontsize Fontsize of the plot. Returns ------- go.Figure The figure of the Partial Dependency Plot (PDP). """ # Parse inputs hp1_name = inputs["hyperparameter_name_1"] hp1_idx = run.configspace.get_idx_by_hyperparameter_name(hp1_name) hp1 = run.configspace[hp1_name] hp2_name = inputs["hyperparameter_name_2"] hp2_idx = None hp2 = None if hp2_name is not None and hp2_name != "": hp2_idx = run.configspace.get_idx_by_hyperparameter_name(hp2_name) hp2 = run.configspace[hp2_name] objective = run.get_objective(inputs["objective_id"]) objective_name = objective.name # Parse outputs x = np.asarray(outputs["x"]) y = np.asarray(outputs["y"]) sigmas = np.sqrt(np.asarray(outputs["variances"])) x_ice = np.asarray(outputs["x_ice"]) y_ice = np.asarray(outputs["y_ice"]) traces = [] if hp2 is None: # 1D # Add ICE curves if show_ice: for x_, y_ in zip(x_ice, y_ice): traces += [ go.Scatter( x=x_[:, hp1_idx], y=y_, line=dict(color=get_color(1, 0.1)), hoverinfo="skip", showlegend=False, ) ] if show_confidence: traces += [ go.Scatter( x=x[:, hp1_idx], y=y + sigmas, line=dict(color=get_color(0, 0.1)), hoverinfo="skip", showlegend=False, ) ] traces += [ go.Scatter( x=x[:, hp1_idx], y=y - sigmas, fill="tonexty", fillcolor=get_color(0, 0.2), line=dict(color=get_color(0, 0.1)), hoverinfo="skip", showlegend=False, ) ] traces += [ go.Scatter( x=x[:, hp1_idx], y=y, line=dict(color=get_color(0, 1)), hoverinfo="skip", showlegend=False, ) ] tickvals, ticktext = get_hyperparameter_ticks(hp1) # Allow to pass a fontsize (necessary when leveraging PDP in Symbolic Explanation) if fontsize is None: fontsize = config.FIGURE_FONT_SIZE layout = go.Layout( { "xaxis": { "tickvals": tickvals, "ticktext": ticktext, "title": hp1_name, }, "yaxis": { "title": objective_name, }, "title": title, "font": dict(size=fontsize), } ) else: z = y if show_confidence: z = sigmas traces += [ go.Contour( z=z, x=x[:, hp1_idx], y=x[:, hp2_idx], colorbar=dict( title=objective_name if not show_confidence else "Confidence (1-Sigma)", ), hoverinfo="skip", ) ] x_tickvals, x_ticktext = get_hyperparameter_ticks(hp1) y_tickvals, y_ticktext = get_hyperparameter_ticks(hp2) layout = go.Layout( dict( xaxis=dict(tickvals=x_tickvals, ticktext=x_ticktext, title=hp1_name), yaxis=dict(tickvals=y_tickvals, ticktext=y_ticktext, title=hp2_name), margin=config.FIGURE_MARGIN, title=title, font=dict(size=fontsize), ) ) figure = go.Figure(data=traces, layout=layout) save_image(figure, "pdp.pdf") return figure
[docs] @staticmethod def load_outputs(run, inputs, outputs): # type: ignore """ Read the raw data and prepare it for the layout. Note ---- The passed inputs are cleaned and therefore differ compared to 'load_inputs' or 'load_dependency_inputs'. Please see '_clean_inputs' for more information. Parameters ---------- run The selected run. inputs Input and filter values from the user. outputs Raw output from the run. Returns ------- go.Figure The figure of the Partial Dependency Plot (PDP). """ show_confidence = inputs["show_confidence"] show_ice = inputs["show_ice"] figure = PartialDependencies.get_pdp_figure(run, inputs, outputs, show_confidence, show_ice) return figure