Source code for cave.analyzer.budgets.budget_correlation

import warnings
from typing import List

from bokeh.embed import components
from bokeh.io import output_notebook
from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, CustomJS, Range1d
from bokeh.models.widgets import DataTable, TableColumn, Select
from bokeh.plotting import show, figure
from pandas import DataFrame
from scipy.stats import spearmanr

from cave.analyzer.base_analyzer import BaseAnalyzer
from cave.utils.hpbandster_helpers import format_budgets


class BudgetCorrelation(BaseAnalyzer):
    """
    Use spearman correlation to get a correlation-value and a p-value for every pairwise combination of budgets.
    First value is the correlation, second is the p-value (the p-value roughly estimates the likelihood to obtain this
    correlation coefficient with uncorrelated datasets).
    This can be used to estimate how well a budget approximates the function to be optimized.
    """
    # NOTE: the class docstring above is emitted at runtime as the tooltip in ``get_html`` (via ``self.__doc__``),
    # so its wording is user-facing text.

    def __init__(self, runscontainer):
        """
        Parameters
        ----------
        runscontainer: RunsContainer
            contains all important information about the configurator runs
        """
        super().__init__(runscontainer)
        self.runs = self.runscontainer.get_aggregated(keep_budgets=True, keep_folders=False)
        self.budget_names = list(format_budgets(self.runscontainer.get_budgets(), allow_whitespace=True).values())
        self.logger.debug("Budget names: %s", str(self.budget_names))

        # To be set
        self.dataframe = None

    def get_name(self) -> str:
        """Return the display name of this analyzer."""
        return "Budget Correlation"

    def _get_table(self, runs) -> DataFrame:
        """
        Build the pairwise budget-correlation table.

        For every pair of runs, the spearman correlation is computed over the costs of the configurations that
        appear in both combined runhistories. Only the diagonal and the upper triangle carry a value, formatted as
        ``"<rho> (<n> samples)"``; the p-value returned by ``spearmanr`` is not displayed. Cells of pairs without any
        common configuration contain ``"N/A"``, all remaining lower-triangle cells are empty strings.

        Parameters
        ----------
        runs: List[ConfiguratorRun]
            list with runs (budgets) to be compared

        Returns
        -------
        table: DataFrame
            square table, rows and columns are labelled with ``self.budget_names``
        """
        # Build each run's configuration set once instead of once per pair
        config_sets = [set(run.combined_runhistory.get_all_configs()) for run in runs]

        table = []
        for idx1, b1 in enumerate(runs):
            cells = []
            for idx2, b2 in enumerate(runs):
                configs = config_sets[idx1] & config_sets[idx2]
                if len(configs) < 1:
                    cells.append("N/A")
                elif idx2 < idx1:
                    # Only the upper triangle is filled, so skip the (discarded) correlation computation
                    cells.append("")
                else:
                    costs1 = [b1.combined_runhistory.get_cost(c) for c in configs]
                    costs2 = [b2.combined_runhistory.get_cost(c) for c in configs]
                    rho, _ = spearmanr(costs1, costs2)
                    cells.append("{:.2f} ({} samples)".format(rho, len(costs1)))
            table.append(cells)

        return DataFrame(data=table, columns=self.budget_names, index=self.budget_names)

    def plot(self):
        """Create table and plot that reacts to selection of cells by updating the plotted data to visualize
        correlation."""
        return self._plot(self.runs)

    def _get_scatter_data(self, runs) -> dict:
        """
        Collect, per budget, the cost of every configuration that occurs in any of the original runhistories.

        All returned lists are aligned (same configuration at the same position); configurations that are missing
        from a run's original runhistory are represented by ``None``.

        Parameters
        ----------
        runs: List[ConfiguratorRun]
            list with runs (budgets) to be compared

        Returns
        -------
        data: dict
            maps ``self.budget_names[idx]`` to the list of costs (or ``None``) of the idx-th run
        """
        all_configs = {config for run in runs for config in run.original_runhistory.get_all_configs()}
        data = {}
        for idx, run in enumerate(runs):
            # Membership is tested against a set that is built once per run
            evaluated = set(run.original_runhistory.get_all_configs())
            data[self.budget_names[idx]] = [run.original_runhistory.get_cost(config) if config in evaluated else None
                                            for config in all_configs]
        return data

    @staticmethod
    def _get_value_range(data):
        """
        Return the padded ``(min, max)`` over all non-``None`` values of all columns in ``data``.

        Only ``None`` entries are ignored, so a cost of exactly 0 still contributes to the range. The padding is a
        tenth of the total interval on each side. If there is no value at all, ``(0.0, 1.0)`` is returned.
        """
        values = [value for column_values in data.values() for value in column_values if value is not None]
        if not values:
            return 0.0, 1.0
        min_val, max_val = min(values), max(values)
        padding = (max_val - min_val) / 10  # Small padding to border (fraction of total interval)
        return min_val - padding, max_val + padding

    def _plot(self, runs):
        """
        Create table and plot that reacts to selection of cells by updating the plotted data to visualize correlation.

        Parameters
        ----------
        runs: List[ConfiguratorRun]
            list with runs (budgets) to be compared

        Returns
        -------
        layout: bokeh layout
            the correlation table on top, below it the scatter plot next to the two axis-selection widgets
        """
        df = self._get_table(runs)

        # Create CDS from pandas dataframe
        budget_names = list(df.columns.values)
        data = dict(df[budget_names])
        data["Budget"] = df.index.tolist()
        table_source = ColumnDataSource(data)

        # Create bokeh-datatable
        columns = [TableColumn(field='Budget', title="Budget", sortable=False, width=20)] + [
            TableColumn(field=header, title=header, default_sort='descending', width=10) for header in budget_names
        ]
        bokeh_table = DataTable(source=table_source, columns=columns, index_position=None, sortable=False,
                                height=20 + 30 * len(data["Budget"]))

        # Create CDS for scatter-plot
        data = self._get_scatter_data(runs)
        data['x'] = []
        data['y'] = []

        # Default scatter should be lowest vs highest:
        for x, y in zip(data[budget_names[0]], data[budget_names[-1]]):
            if x is not None and y is not None:
                data['x'].append(x)
                data['y'].append(y)

        with warnings.catch_warnings(record=True) as list_of_warnings:
            # Catch unmatching column lengths warning
            warnings.simplefilter('always')
            scatter_source = ColumnDataSource(data=data)
            for w in list_of_warnings:
                self.logger.debug("During budget correlation a %s was raised: %s", str(w.category), w.message)

        # Create figure and dynamically updating plot (linked with table)
        min_val, max_val = self._get_value_range(data)
        p = figure(plot_width=400, plot_height=400,
                   match_aspect=True,
                   y_range=Range1d(start=min_val, end=max_val, bounds=(min_val, max_val)),
                   x_range=Range1d(start=min_val, end=max_val, bounds=(min_val, max_val)),
                   x_axis_label=budget_names[0], y_axis_label=budget_names[-1])
        p.circle(x='x', y='y',
                 source=scatter_source,
                 size=5, color="navy", alpha=0.5)

        # The JS callbacks are assembled from snippets, because the table callback and the select callback only
        # differ in how the row/column (i.e. the two budgets to plot) are determined.
        code_budgets = 'var budgets = ' + str(budget_names) + '; console.log(budgets);'
        code_try = 'try {'
        code_get_selected_cell = """
            // This first part only extracts selected row and column!
            var grid = document.getElementsByClassName('grid-canvas')[0].children;
            var row = '';
            var col = '';
            for (var i=0,max=grid.length;i<max;i++){
                if (grid[i].outerHTML.includes('active')){
                    row=i;
                    for (var j=0,jmax=grid[i].children.length;j<jmax;j++){
                        if(grid[i].children[j].outerHTML.includes('active')){col=j}
                    }
                }
            }
            col = col - 1;
            console.log('row', row, budgets[row]);
            console.log('col', col, budgets[col]);
            table_source.selected.indices = [];  // Reset, so gets triggered again when clicked again
        """
        # ``var`` keeps row/col local to the callback instead of creating implicit globals
        code_get_selected_budgets = """
            var row = budgets.indexOf(select_x.value);
            var col = budgets.indexOf(select_y.value);
        """
        code_update_selection_values = """
            select_x.value = budgets[row];
            select_y.value = budgets[col];
        """
        code_update_plot = """
            // This is the actual updating of the plot
            if (row >= 0 && col >= 0) {
                // Copy relevant arrays
                var new_x = scatter_source.data[budgets[row]].slice();
                var new_y = scatter_source.data[budgets[col]].slice();
                // Remove all pairs where one value is null
                var next_null;
                while ((next_null = new_x.indexOf(null)) > -1) {
                    new_x.splice(next_null, 1);
                    new_y.splice(next_null, 1);
                }
                while ((next_null = new_y.indexOf(null)) > -1) {
                    new_x.splice(next_null, 1);
                    new_y.splice(next_null, 1);
                }
                // Assign new data to the plotted columns
                scatter_source.data['x'] = new_x;
                scatter_source.data['y'] = new_y;
                scatter_source.change.emit();
                // Update axis-labels
                xaxis.attributes.axis_label = budgets[row];
                yaxis.attributes.axis_label = budgets[col];
                // Update ranges (only if there is data, min/max of nothing would yield invalid ranges)
                if (new_x.length > 0) {
                    var min = Math.min(...[Math.min(...new_x), Math.min(...new_y)]);
                    var max = Math.max(...[Math.max(...new_x), Math.max(...new_y)]);
                    var padding = (max - min) / 10;
                    console.log(min, max, padding);
                    xr.start = min - padding;
                    yr.start = min - padding;
                    xr.end = max + padding;
                    yr.end = max + padding;
                }
            }
        """
        code_catch = """
            } catch(err) {
                console.log(err.message);
            }
        """

        # Callback for the two select-widgets
        code_select = code_budgets + code_try + code_get_selected_budgets + code_update_plot + code_catch
        select_x = Select(title="X-axis:", value=budget_names[0], options=budget_names)
        select_y = Select(title="Y-axis:", value=budget_names[-1], options=budget_names)
        callback_select = CustomJS(args=dict(scatter_source=scatter_source,
                                             select_x=select_x, select_y=select_y,
                                             xaxis=p.xaxis[0], yaxis=p.yaxis[0],
                                             xr=p.x_range, yr=p.y_range,
                                             ), code=code_select)
        select_x.js_on_change('value', callback_select)
        select_y.js_on_change('value', callback_select)

        # Callback for clicks on table cells
        code_table_cell = code_budgets + code_try + code_get_selected_cell + code_update_selection_values
        code_table_cell += code_update_plot + code_catch
        callback_table_cell = CustomJS(args=dict(table_source=table_source,
                                                 scatter_source=scatter_source,
                                                 select_x=select_x, select_y=select_y,
                                                 xaxis=p.xaxis[0], yaxis=p.yaxis[0],
                                                 xr=p.x_range, yr=p.y_range,
                                                 ), code=code_table_cell)
        table_source.selected.js_on_change('indices', callback_table_cell)

        layout = column(bokeh_table, row(p, column(select_x, select_y)))
        return layout

    def get_html(self, d=None, tooltip=None):
        """
        Render the plot into embeddable bokeh components.

        Parameters
        ----------
        d: dict, optional
            if given, the components and the tooltip (the class docstring) are stored in it under the key
            "Budget Correlation"
        tooltip: optional
            not used by this method

        Returns
        -------
        script, div: the bokeh components of the plot
        """
        script, div = components(self.plot())
        if d is not None:
            d["Budget Correlation"] = {
                "bokeh": (script, div),
                "tooltip": self.__doc__,
            }
        return script, div

    def get_jupyter(self):
        """Show the plot inline in a jupyter notebook."""
        output_notebook()
        show(self.plot())