Metrics

In Auto-sklearn, a model is optimized with respect to a single metric, which can be either built-in or user-defined. It is also possible to calculate additional metrics for each run. The following example shows how to calculate built-in and self-defined metrics for a classification problem.

import autosklearn.classification
import numpy as np
import pandas as pd
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
from autosklearn.metrics import balanced_accuracy, precision, recall, f1


def error(solution, prediction):
    # Custom metric: fraction of samples whose prediction differs from the label
    return np.mean(solution != prediction)


def get_metric_result(cv_results):
    # Tabulate successful runs with their rank, classifier, and metric scores
    results = pd.DataFrame.from_dict(cv_results)
    results = results[results["status"] == "Success"]
    cols = ["rank_test_scores", "param_classifier:__choice__", "mean_test_score"]
    # Each additional scoring function is stored under a "metric_" prefix
    cols.extend([key for key in cv_results.keys() if key.startswith("metric_")])
    return results[cols]
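
To see what get_metric_result keeps and drops without running a full search, the helper can be applied to a small hand-written dictionary. This is only a sketch: toy_cv_results and every value in it are made up for illustration and merely mimic the keys the helper reads, and the final sort by rank is an optional extra step that is not part of the original example.

```python
import pandas as pd


def get_metric_result(cv_results):
    # Same helper as above: keep successful runs and the metric columns
    results = pd.DataFrame.from_dict(cv_results)
    results = results[results["status"] == "Success"]
    cols = ["rank_test_scores", "param_classifier:__choice__", "mean_test_score"]
    cols.extend([key for key in cv_results.keys() if key.startswith("metric_")])
    return results[cols]


# Hypothetical stand-in for cls.cv_results_ (values are invented)
toy_cv_results = {
    "rank_test_scores": [2, 1, 3],
    "param_classifier:__choice__": ["mlp", "extra_trees", "adaboost"],
    "mean_test_score": [0.95, 0.98, 0.0],
    "status": ["Success", "Success", "Timeout"],
    "metric_f1": [0.96, 0.99, 0.0],
    "mean_fit_time": [1.2, 0.8, 30.0],
}

# The unsuccessful run and the unrelated column are dropped;
# sorting by rank puts the best run first.
table = get_metric_result(toy_cv_results).sort_values("rank_test_scores")
print(table.to_string(index=False))
```

Only the two rows with status "Success" remain, and only the rank, classifier, mean test score, and "metric_"-prefixed columns are shown.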

Data loading

X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X, y, random_state=1
)

Build and fit a classifier

# Wrap the custom function in a scorer: lower values are better and 0 is optimal
error_rate = autosklearn.metrics.make_scorer(
    name="custom_error",
    score_func=error,
    optimum=0,
    greater_is_better=False,
    needs_proba=False,
    needs_threshold=False,
)
cls = autosklearn.classification.AutoSklearnClassifier(
    time_left_for_this_task=120,
    per_run_time_limit=30,
    # Extra metrics computed for every run and reported in cv_results_
    scoring_functions=[balanced_accuracy, precision, recall, f1, error_rate],
)
cls.fit(X_train, y_train, X_test, y_test)
AutoSklearnClassifier(ensemble_class=<class 'autosklearn.ensembles.ensemble_selection.EnsembleSelection'>,
                      per_run_time_limit=30,
                      scoring_functions=[balanced_accuracy, precision, recall,
                                         f1, custom_error],
                      time_left_for_this_task=120)
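
Because the function behind custom_error is plain NumPy, it can be sanity-checked on its own before being wrapped in a scorer. The sketch below repeats the error definition and applies it to a few made-up labels (y_true and y_pred are illustrative, not taken from the dataset).

```python
import numpy as np


def error(solution, prediction):
    # Fraction of samples whose prediction differs from the label
    return np.mean(solution != prediction)


# Invented toy labels: exactly one of the five predictions is wrong
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])

err = error(y_true, y_pred)
acc = np.mean(y_true == y_pred)

print("error:", err)
print("accuracy:", acc)
```

The two values are complementary: every sample is either classified correctly or not, so the error rate and the accuracy sum to one.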

Get the score of the final ensemble

predictions = cls.predict(X_test)
print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))

print("#" * 80)
print("Metric results")
print(get_metric_result(cls.cv_results_).to_string(index=False))
Accuracy score 0.958041958041958
################################################################################
Metric results
 rank_test_scores param_classifier:__choice__  mean_test_score  metric_balanced_accuracy  metric_precision  metric_recall  metric_f1  metric_custom_error
                6               random_forest         0.971631                  0.969533          0.977528       0.977528   0.977528             0.028369
                6                         mlp         0.971631                  0.961538          0.956989       1.000000   0.978022             0.028369
               26                         mlp         0.943262                  0.935069          0.945055       0.966292   0.955556             0.056738
               16               random_forest         0.964539                  0.959918          0.966667       0.977528   0.972067             0.035461
                6                         mlp         0.971631                  0.961538          0.956989       1.000000   0.978022             0.028369
                1                 extra_trees         0.985816                  0.984767          0.988764       0.988764   0.988764             0.014184
               16               random_forest         0.964539                  0.963915          0.977273       0.966292   0.971751             0.035461
               22                 extra_trees         0.957447                  0.954300          0.966292       0.966292   0.966292             0.042553
                6               random_forest         0.971631                  0.969533          0.977528       0.977528   0.977528             0.028369
                6               random_forest         0.971631                  0.969533          0.977528       0.977528   0.977528             0.028369
               16           gradient_boosting         0.964539                  0.963915          0.977273       0.966292   0.971751             0.035461
                6           gradient_boosting         0.971631                  0.965536          0.967033       0.988764   0.977778             0.028369
                6                         mlp         0.971631                  0.965536          0.967033       0.988764   0.977778             0.028369
               24                         mlp         0.950355                  0.948682          0.965909       0.955056   0.960452             0.049645
                3           gradient_boosting         0.978723                  0.975151          0.977778       0.988764   0.983240             0.021277
               16           gradient_boosting         0.964539                  0.959918          0.966667       0.977528   0.972067             0.035461
               16               random_forest         0.964539                  0.959918          0.966667       0.977528   0.972067             0.035461
                6                 extra_trees         0.971631                  0.969533          0.977528       0.977528   0.977528             0.028369
               31          passive_aggressive         0.921986                  0.894231          0.890000       1.000000   0.941799             0.078014
                3                 extra_trees         0.978723                  0.975151          0.977778       0.988764   0.983240             0.021277
                6           gradient_boosting         0.971631                  0.965536          0.967033       0.988764   0.977778             0.028369
               24                         mlp         0.950355                  0.940687          0.945652       0.977528   0.961326             0.049645
               27               random_forest         0.929078                  0.923833          0.943820       0.943820   0.943820             0.070922
               22                    adaboost         0.957447                  0.950303          0.956044       0.977528   0.966667             0.042553
                6                 extra_trees         0.971631                  0.965536          0.967033       0.988764   0.977778             0.028369
                1                 extra_trees         0.985816                  0.984767          0.988764       0.988764   0.988764             0.014184
               27                bernoulli_nb         0.929078                  0.927831          0.954023       0.932584   0.943182             0.070922
               33                         mlp         0.865248                  0.817308          0.824074       1.000000   0.903553             0.134752
                3                 extra_trees         0.978723                  0.975151          0.977778       0.988764   0.983240             0.021277
               16               random_forest         0.964539                  0.963915          0.977273       0.966292   0.971751             0.035461
               31           gradient_boosting         0.921986                  0.930207          0.975610       0.898876   0.935673             0.078014
               27               decision_tree         0.929078                  0.931828          0.964706       0.921348   0.942529             0.070922
               27         k_nearest_neighbors         0.929078                  0.911841          0.915789       0.977528   0.945652             0.070922
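
The columns of the table are not independent, and two relationships can be checked directly from the printed numbers. The sketch below copies three rows from the table and verifies, up to the six-decimal rounding of the output, that metric_custom_error equals one minus mean_test_score and that metric_f1 is the harmonic mean of precision and recall. The first relationship is what one would expect if the optimization metric is accuracy, which is assumed here to be the default since no metric argument is passed to the classifier. The tolerance TOL is an assumption chosen to absorb rounding.

```python
# Rows copied from the table above:
# (mean_test_score, precision, recall, f1, custom_error)
rows = [
    (0.971631, 0.977528, 0.977528, 0.977528, 0.028369),  # random_forest
    (0.971631, 0.956989, 1.000000, 0.978022, 0.028369),  # mlp
    (0.921986, 0.975610, 0.898876, 0.935673, 0.078014),  # gradient_boosting
]

TOL = 5e-6  # assumed tolerance for values printed with six decimals

for score, prec, rec, f1_reported, err in rows:
    # Error rate and mean test score should add up to one
    sum_ok = abs(score + err - 1.0) < TOL
    # F1 is the harmonic mean of precision and recall
    f1_computed = 2 * prec * rec / (prec + rec)
    f1_ok = abs(f1_computed - f1_reported) < TOL
    print(sum_ok, f1_ok, round(f1_computed, 6))
```

Note also that the printed rows are not ordered by rank_test_scores, so the best runs (rank 1) appear in the middle of the table.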

Total running time of the script: ( 2 minutes 0.007 seconds)
