Source code for cave.plot.scatter

import itertools
import logging
import os

import matplotlib.pyplot as plt

plt.style.use(os.path.join(os.path.dirname(__file__), 'mpl_style'))  # noqa
from matplotlib.pyplot import setp

import numpy as np

"""
Mostly taken from https://bitbucket.org/aadfreiburg/plotting_scripts
"""


[docs]def plot_scatter_plot(x_data, y_data, labels, title="", min_val=None, max_val=1000, grey_factor=1, linefactors=None, user_fontsize=22, dpi=100, metric="runtime", jitter_timeout=True, markers=None, sizes=None, out_fn=None): """ Method to generate a scatter plot Parameters ---------- x_data: numpy.array performance values of one algorithm y_data: numpy.array performance values of the other algorithm labels: tuple (xlabel, ylabel) title: str title of plot min_val: float minimal value to plot max_val: float maximal value to plot grey_factor: float grey factor of points with a speedup of less 2 linefactors: list of floats factors of speedups user_fontsize: int font size dpi: int resolution metric: str "runtime" or something else jitter_timeout: bool Add some noise to remove timeout clutter """ if markers is None or len(markers) != 3: regular_marker = 'x' timeout_marker = '+' grey_marker = '.' else: regular_marker = markers[0] timeout_marker = markers[1] grey_marker = markers[2] if sizes is None or len(sizes) != 3: s_r = 5 s_t = 5 s_g = 5 else: s_r = sizes[0] s_t = sizes[1] s_g = sizes[2] c_angle_bisector = "#e41a1c" # Red c_good_points = "#999999" # Grey c_other_points = "k" size = 1 st_ref = "--" ticklabel_size = user_fontsize linefactor_size = user_fontsize - 2 label_size = user_fontsize + 1 # # ------ # maximum_value: location for timeout points # max_val : Initially user-defined timeout, then set to axes limit # time_out_val : location for timeout points # ----- if max_val is None: max_val = 1000 # raise ValueError("max_val cannot be None") maximum_value = max_val # Colors ref_colors = itertools.cycle([ # "#e41a1c", # Red "#377eb8", # Blue "#4daf4a", # Green "#984ea3", # Purple "#ff7f00", # Orange "#ffff33", # Yellow "#a65628", # Brown "#f781bf", # Pink # "#999999", # Grey ]) # set initial limits x_min = min([min(x) for x in x_data]) y_min = min([min(y) for y in y_data]) x_max = max([max(x) for x in x_data]) y_max = max([max(y) for y in y_data]) x_min = min([x_min, y_min]) y_min = x_min x_max = max([x_max, y_max]) y_max = x_max if min_val is not None: auto_min_val = min([x_min, y_min, min_val]) else: auto_min_val = min([x_min, y_min]) if metric == "runtime" or metric == "quality": timeout_factor = 2 timeout_val = maximum_value * timeout_factor auto_max_val = maximum_value else: timeout_factor = 1 timeout_val = 1 auto_max_val = max([x_max, y_max]) # Set up figure if len(x_data) > 1: fig = plt.figure(1, dpi=dpi, figsize=(10, 5)) ax1 = fig.add_subplot(1, 2, 1, adjustable='box', aspect=1) ax2 = fig.add_subplot(1, 2, 2, adjustable='box', aspect=1) axes = [ax1, ax2] else: fig = plt.figure(1, dpi=dpi, figsize=(10, 10)) ax1 = fig.add_subplot(1, 1, 1, adjustable='box', aspect=1) axes = [ax1] for ax in axes: ax.grid(True, linestyle='-', which='major', color='lightgrey', alpha=0.5) # Plot angle bisector and reference_lines out_up = auto_max_val out_lo = max(10**-6, auto_min_val) # if metric == "runtime" or metric == "quality": for ax in axes: ax.plot([out_lo, out_up], [out_lo, out_up], c=c_angle_bisector) if linefactors is not None: for f in linefactors: c = next(ref_colors) # Lower reference lines ax.plot([f*out_lo, out_up], [out_lo, (1.0/f)*out_up], c=c, linestyle=st_ref, linewidth=size*1.5) # Upper reference lines ax.plot([out_lo, (1.0/f)*out_up], [f*out_lo, out_up], c=c, linestyle=st_ref, linewidth=size*1.5) offset = 1.1 if int(f) == f: lf_str = "%dx" % f else: lf_str = "%2.1fx" % f ax.text((1.0/f)*out_up, out_up*offset+1000, lf_str, color=c, fontsize=linefactor_size) ax.text(out_up*offset+1000, (1.0/f)*out_up, lf_str, color=c, fontsize=linefactor_size) ####### # Scatter def scatter(x_data_, y_data_, ax): """ Encapsulated to support subplots if train and test are differentiated. """ logger = logging.getLogger("cave.scatter") logger.debug("Incumbent better: %d, default better: %d", len([x for x in x_data_ > y_data_ if x]), len([x for x in x_data_ < y_data_ if x])) grey_idx = list() timeout_x = list() timeout_y = list() timeout_both = list() rest_idx = list() for idx_x, x in enumerate(x_data_): if x >= max_val > y_data_[idx_x]: # timeout of x algo timeout_x.append(idx_x) elif y_data_[idx_x] >= max_val > x: # timeout of y algo timeout_y.append(idx_x) elif y_data_[idx_x] >= max_val and x >= max_val: # timeout of both algos timeout_both.append(idx_x) elif y_data_[idx_x] < grey_factor*x and x < grey_factor*y_data_[idx_x]: grey_idx.append(idx_x) else: rest_idx.append(idx_x) # Regular points if len(grey_idx) > 1: ax.scatter(x_data_[grey_idx], y_data_[grey_idx], marker=grey_marker, edgecolor='', facecolor=c_good_points, s=s_g) ax.scatter(x_data_[rest_idx], y_data_[rest_idx], marker=regular_marker, c=c_other_points, s=s_r) if metric == "runtime" or metric == "quality": # max_val lines ax.plot([maximum_value, maximum_value], [auto_min_val, maximum_value], c=c_other_points, linestyle="--", zorder=0, linewidth=size) ax.plot([auto_min_val, maximum_value], [maximum_value, maximum_value], c=c_other_points, linestyle="--", zorder=0, linewidth=size) # Timeout points if jitter_timeout: scat_x = np.random.randn(len(timeout_x), 1)*0.1*timeout_val + timeout_val scat_y = np.random.randn(len(timeout_y), 1)*0.1*timeout_val + timeout_val scat_both = (np.random.randn(len(timeout_both), 1)*0.1*timeout_val + timeout_val, np.random.randn(len(timeout_both), 1)*0.1*timeout_val + timeout_val) else: scat_x = [timeout_val]*len(timeout_x) scat_y = [timeout_val]*len(timeout_y) scat_both = ([timeout_val]*len(timeout_both), [timeout_val]*len(timeout_both)) ax.scatter(scat_x, y_data_[timeout_x], marker=timeout_marker, c=c_other_points, s=s_t) ax.scatter(scat_both[0], scat_both[1], marker=timeout_marker, c=c_other_points, s=s_t) ax.scatter(x_data_[timeout_y], scat_y, marker=timeout_marker, c=c_other_points, s=s_t) for x, y, ax in zip(x_data, y_data, axes): scatter(x, y, ax) # Set axes scale and limits # if metric == "runtime": for ax in axes: ax.set_xscale("log") ax.set_yscale("log") # Set axes labels for ax in axes: ax.set_xlabel(labels[0], fontsize=label_size) ax.set_ylabel(labels[1], fontsize=label_size) # if debug: # # Plot legend # for ax in axes: # leg = ax.legend(loc='best', fancybox=True) # leg.get_frame().set_alpha(0.5) max_val = timeout_val * timeout_factor auto_min_val *= 0.9 for ax in axes: ax.set_autoscale_on(False) if max_val is not None and min_val is None: # User sets max val ax.set_ylim([auto_min_val, max_val]) ax.set_xlim(ax.get_ylim()) elif max_val > min_val and max_val is not None and min_val is not None: # User sets both, min and max -val ax.set_ylim([min_val, max_val]) ax.set_xlim(ax.get_ylim()) else: # User sets nothing ax.set_xlim([auto_min_val, max_val]) ax.set_ylim(ax.get_xlim()) # Plot maximum value as tick if int(maximum_value) == maximum_value: maximum_value = int(maximum_value) maximum_str = r"$%d$" % maximum_value else: maximum_str = r"$%5.2f$" % maximum_value # if metric == "runtime" or metric == "quality": for ax in axes: if int(np.log10(maximum_value)) != np.log10(maximum_value): # If we do not already have this ticklabel as a regular label ax.text(ax.get_ylim()[0] - 0.1 * np.abs(ax.get_ylim()[0]), maximum_value, maximum_str, horizontalalignment='right', verticalalignment="center", fontsize=user_fontsize) ax.text(maximum_value, ax.get_ylim()[0] - 0.1 * np.abs(ax.get_ylim()[0]), maximum_str, horizontalalignment='center', verticalalignment="top", fontsize=user_fontsize) # Plot 'timeout' ax.text(ax.get_xlim()[0] - 0.1 * np.abs(ax.get_ylim()[0]), timeout_val, "timeout ", horizontalalignment='right', verticalalignment="center", fontsize=user_fontsize, rotation=30) ax.text(timeout_val, ax.get_ylim()[0] - 0.1 * np.abs(ax.get_ylim()[0]), "timeout ", horizontalalignment='center', verticalalignment="top", fontsize=user_fontsize, rotation=30) ######### # Adjust ticks > max_val ax.xaxis.set_ticks_position('bottom') ax.yaxis.set_ticks_position('left') # major axes for tic in ax.xaxis.get_major_ticks(): if tic._loc > maximum_value: tic.tick1On = tic.tick2On = False for tic in ax.yaxis.get_major_ticks(): if tic._loc > maximum_value: tic.tick1On = tic.tick2On = False # minor axes for tic in ax.xaxis.get_minor_ticks(): if tic._loc > maximum_value: tic.tick1On = tic.tick2On = False for tic in ax.yaxis.get_minor_ticks(): if tic._loc > maximum_value: tic.tick1On = tic.tick2On = False # tick labels for ax in axes: ticks_x = ax.get_xticks() new_ticks_label = list() for l_idx in range(len(ticks_x)): if ticks_x[l_idx] < maximum_value: if 0 < ticks_x[l_idx] < 1: new_ticks_label.append(str(r"$10^{%d}$" % int(np.log10(ticks_x[l_idx])))) if 1 <= ticks_x[l_idx] < 1000: new_ticks_label.append(str(r"$%d^{ }$" % int(ticks_x[l_idx]))) if 1000 <= ticks_x[l_idx]: new_ticks_label.append(str(r"$10^{%d}$" % int(np.log10(ticks_x[l_idx])))) ax.set_xticklabels(new_ticks_label) # , rotation=45) ax.set_yticklabels(new_ticks_label) # , rotation=45) # Change fontsize for ticklabels for ax in axes: setp(ax1.get_yticklabels(), fontsize=ticklabel_size) setp(ax1.get_xticklabels(), fontsize=ticklabel_size) fig.tight_layout() fig.savefig(out_fn) plt.close(fig) return out_fn