Model Explanation

The following example shows how to fit a simple classification model with auto-sklearn and use the inspection module from scikit-learn to understand what affects the predictions.

import sklearn.datasets
import sklearn.model_selection  # needed for train_test_split below
from sklearn.inspection import plot_partial_dependence, permutation_importance
import matplotlib.pyplot as plt
import autosklearn.classification

Load Data and Build a Model

We start by loading the “Run or walk” dataset from OpenML and training an auto-sklearn model on it. The goal is to predict whether a person is running or walking from accelerometer and gyroscope data collected by a phone. For more information, see the dataset’s page on OpenML (data ID 40922).

dataset = sklearn.datasets.fetch_openml(data_id=40922)

# Note: To speed up the example, we subsample the dataset
dataset.data = dataset.data.sample(n=5000, random_state=1, axis="index")
dataset.target = dataset.target[dataset.data.index]

X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    dataset.data, dataset.target, test_size=0.3, random_state=1
)

automl = autosklearn.classification.AutoSklearnClassifier(
    time_left_for_this_task=120,
    per_run_time_limit=30,
    tmp_folder="/tmp/autosklearn_inspect_predictions_example_tmp",
)
automl.fit(X_train, y_train, dataset_name="Run_or_walk_information")

s = automl.score(X_train, y_train)
print(f"Train score {s}")
s = automl.score(X_test, y_test)
print(f"Test score {s}")
Fitting to the training data:   0%|          | 0/120 [00:00<?, ?it/s, The total time budget for this task is 0:02:00]
/home/runner/work/auto-sklearn/auto-sklearn/autosklearn/data/target_validator.py:187: UserWarning: Fitting transformer with a pandas series which has the dtype category. Inverse transform may not be able preserve dtype when converting to np.ndarray
  warnings.warn(

Fitting to the training data: 100%|##########| 120/120 [01:50<00:00,  1.09it/s, The total time budget for this task is 0:02:00]
Train score 0.9948571428571429
Test score 0.984

Compute permutation importance - part 1

Since auto-sklearn implements the scikit-learn interface, it can be used with scikit-learn’s inspection module. We first look at the permutation importance, which is defined as the decrease in a model’s score when the values of a single feature are randomly permuted. The higher the importance, the more the model’s predictions depend on that feature.

Note: There are some pitfalls in interpreting these numbers, which can be found in the scikit-learn docs.
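To make the definition concrete, here is a minimal sketch of the idea in plain Python. It is not scikit-learn’s implementation: the helper names (`accuracy`, `permutation_importance_sketch`, `toy_predict`) and the toy data are assumptions made up for illustration, and accuracy is used as the score.

```python
import random


def accuracy(predict, X, y):
    """Fraction of rows for which predict(row) equals the label."""
    return sum(predict(row) == label for row, label in zip(X, y)) / len(y)


def permutation_importance_sketch(predict, X, y, n_repeats=10, seed=0):
    """Per feature: mean drop in accuracy after shuffling that column."""
    rng = random.Random(seed)
    baseline = accuracy(predict, X, y)
    importances = []
    for j in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            # Shuffle column j only; all other columns stay untouched.
            column = [row[j] for row in X]
            rng.shuffle(column)
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            drops.append(baseline - accuracy(predict, X_perm, y))
        importances.append(sum(drops) / n_repeats)
    return importances


# Hypothetical toy data: the label depends only on feature 0.
data_rng = random.Random(1)
X = [[data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)] for _ in range(200)]
y = [int(row[0] > 0) for row in X]


def toy_predict(row):
    """Stand-in for a fitted model; it ignores feature 1 entirely."""
    return int(row[0] > 0)


importances = permutation_importance_sketch(toy_predict, X, y)
print(importances)
```

Because the toy model never reads feature 1, shuffling it cannot change any prediction and its importance is exactly zero, whereas shuffling feature 0 destroys most of the accuracy.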

r = permutation_importance(automl, X_test, y_test, n_repeats=10, random_state=0)
sort_idx = r.importances_mean.argsort()[::-1]

plt.boxplot(
    r.importances[sort_idx].T, labels=[dataset.feature_names[i] for i in sort_idx]
)

plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

for i in sort_idx[::-1]:
    print(
        f"{dataset.feature_names[i]:10s}: {r.importances_mean[i]:.3f} +/- "
        f"{r.importances_std[i]:.3f}"
    )
Figure: box plot of the permutation importances, one box per feature.
gyro_y    : 0.000 +/- 0.002
gyro_x    : 0.029 +/- 0.003
gyro_z    : 0.040 +/- 0.003
acceleration_x: 0.058 +/- 0.007
acceleration_z: 0.131 +/- 0.006
acceleration_y: 0.276 +/- 0.006

Create partial dependence (PD) and individual conditional expectation (ICE) plots - part 2

ICE plots describe the relation between feature values and the response value for each sample individually: they show how the response value changes if the value of one feature is changed.

PD plots describe the relation between feature values and the response value, i.e. the expected response value with respect to one or multiple input features. Since we use a classification dataset, this corresponds to the predicted class probability.
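The relation between the two kinds of plot can be sketched in a few lines of plain Python: an ICE curve varies one feature for a single sample, and the PD curve is the pointwise average of all ICE curves. The function names and the toy probability model below are hypothetical and only serve as an illustration; scikit-learn’s own routines handle grids, estimators and plotting differently.

```python
import math


def ice_curves(predict_proba, X, feature, grid):
    """One curve per sample: the prediction as `feature` takes each grid value."""
    curves = []
    for row in X:
        curve = []
        for value in grid:
            modified = list(row)  # copy, so the original sample is unchanged
            modified[feature] = value
            curve.append(predict_proba(modified))
        curves.append(curve)
    return curves


def pd_curve(curves):
    """PD curve: average of the ICE curves at each grid point."""
    n = len(curves)
    return [sum(curve[k] for curve in curves) / n for k in range(len(curves[0]))]


def toy_proba(row):
    """Hypothetical stand-in for a model's predicted class probability."""
    return 1.0 / (1.0 + math.exp(-(2.0 * row[0] + row[1])))


X = [[0.0, -1.0], [0.0, 0.0], [0.0, 1.0]]
grid = [-1.0, 0.0, 1.0]

curves = ice_curves(toy_proba, X, feature=0, grid=grid)  # thin lines
pd = pd_curve(curves)                                    # thick line
print(pd)
```

Each sample gets its own curve because the other feature keeps its observed value, which is why ICE curves can differ from one another even though they share the same grid.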

Since acceleration_y and acceleration_z turned out to have the largest impact on the response value according to the permutation importance, we’ll first look at them and generate a plot combining ICE (thin lines) and PD (thick line):

features = [1, 2]
plot_partial_dependence(
    automl,
    dataset.data,
    features=features,
    grid_resolution=5,
    kind="both",
    feature_names=dataset.feature_names,
)
plt.tight_layout()
plt.show()
Figure: ICE curves (thin lines) and PD curve (thick line) for acceleration_y and acceleration_z.

Create partial dependence (PD) plots for more than one feature - part 3

A PD plot can also be generated for two features at once, which allows us to inspect the interaction between them. Again, we’ll look at acceleration_y and acceleration_z.

features = [[1, 2]]
plot_partial_dependence(
    automl,
    dataset.data,
    features=features,
    grid_resolution=5,
    feature_names=dataset.feature_names,
)
plt.tight_layout()
plt.show()
Figure: two-way partial dependence plot for acceleration_y and acceleration_z.
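What a two-feature PD plot displays can be sketched as follows: for every pair of grid values, both features are overwritten in every sample and the predictions are averaged. This is a minimal plain-Python illustration under assumed names (`partial_dependence_2d`, `toy_score`) with a made-up scoring function, not the routine scikit-learn uses internally.

```python
def partial_dependence_2d(predict, X, feature_a, grid_a, feature_b, grid_b):
    """Average prediction for every (a, b) pair of grid values."""
    surface = []
    for a in grid_a:
        surface_row = []
        for b in grid_b:
            total = 0.0
            for row in X:
                modified = list(row)  # copy, so the original sample is unchanged
                modified[feature_a] = a
                modified[feature_b] = b
                total += predict(modified)
            surface_row.append(total / len(X))
        surface.append(surface_row)
    return surface


def toy_score(row):
    """Hypothetical model output with an interaction between features 0 and 1."""
    return row[0] * row[1] + row[2]


X = [[0.0, 0.0, 1.0], [0.0, 0.0, 3.0]]
surface = partial_dependence_2d(
    toy_score, X, feature_a=0, grid_a=[0.0, 1.0], feature_b=1, grid_b=[0.0, 2.0]
)
print(surface)  # [[2.0, 2.0], [2.0, 4.0]]
```

The interaction is visible in the resulting grid: increasing feature 1 changes the averaged output only when feature 0 is non-zero, which a pair of one-feature PD curves could not reveal.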

Total running time of the script: (4 minutes 15.369 seconds)
