Tabular Classification with Custom Configuration Space

The following example shows how to adjust the configuration space of the search. Currently, two kinds of changes can be made to the space:

  1. Adjust individual hyperparameters in the pipeline.

  2. Include or exclude components (a sketch of both follows this list):

    1. include: Dictionary containing components to include. The key is the node
       name and the value is an Iterable of the names of the components to include.
       Only these components will be present in the search space.

    2. exclude: Dictionary containing components to exclude. The key is the node
       name and the value is an Iterable of the names of the components to exclude.
       All except these components will be present in the search space.
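
For illustration, both arguments are plain dictionaries mapping a pipeline node name to component names, mirroring the calls used further down in this example (a minimal sketch, not an exhaustive list of available components):

# Keep only these backbones and this encoder in the search space
include_components = {'network_backbone': ['MLPBackbone', 'ResNetBackbone'],
                      'encoder': ['OneHotEncoder']}

# Search over everything except these components
exclude_components = {'network_backbone': ['MLPBackbone'],
                      'encoder': ['OneHotEncoder']}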

import os
import tempfile as tmp
import warnings

os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

import sklearn.datasets
import sklearn.model_selection

from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates


def get_search_space_updates():
    """
    Search space updates can be added to the task using HyperparameterSearchSpaceUpdates.

    Returns:
        HyperparameterSearchSpaceUpdates
    """
    updates = HyperparameterSearchSpaceUpdates()
    updates.append(node_name="data_loader",
                   hyperparameter="batch_size",
                   value_range=[16, 512],
                   default_value=32)
    updates.append(node_name="lr_scheduler",
                   hyperparameter="CosineAnnealingLR:T_max",
                   value_range=[50, 60],
                   default_value=55)
    updates.append(node_name='network_backbone',
                   hyperparameter='ResNetBackbone:dropout',
                   value_range=[0, 0.5],
                   default_value=0.2)
    return updates
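
For hyperparameters that are better searched on a log scale (such as batch_size), append also accepts a log flag in recent autoPyTorch versions. This is an assumption about the installed version's signature, so verify it against HyperparameterSearchSpaceUpdates.append before relying on it:

# Assumed variant of the batch_size update above: sample log-uniformly over [16, 512]
# (the `log` argument may not exist in older autoPyTorch versions)
updates.append(node_name="data_loader",
               hyperparameter="batch_size",
               value_range=[16, 512],
               default_value=32,
               log=True)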

Data Loading

X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X,
    y,
    random_state=1,
)

Build and fit a classifier with include components

api = TabularClassificationTask(
    search_space_updates=get_search_space_updates(),
    include_components={'network_backbone': ['MLPBackbone', 'ResNetBackbone'],
                        'encoder': ['OneHotEncoder']}
)

Search for an ensemble of machine learning algorithms

api.search(
    X_train=X_train.copy(),
    y_train=y_train.copy(),
    X_test=X_test.copy(),
    y_test=y_test.copy(),
    optimize_metric='accuracy',
    total_walltime_limit=150,
    func_eval_time_limit_secs=30
)
<autoPyTorch.api.tabular_classification.TabularClassificationTask object at 0x7f9aa698ca30>
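
The task object is re-created below for the exclude example, so you may want to evaluate this include-restricted run first. A minimal optional check, using the same predict/score calls as the final evaluation at the end of this example:

y_pred_include = api.predict(X_test)
print(api.score(y_pred_include, y_test))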

Build and fit a classifier with exclude components

api = TabularClassificationTask(
    search_space_updates=get_search_space_updates(),
    exclude_components={'network_backbone': ['MLPBackbone'],
                        'encoder': ['OneHotEncoder']}
)

Search for an ensemble of machine learning algorithms

api.search(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test.copy(),
    y_test=y_test.copy(),
    optimize_metric='accuracy',
    total_walltime_limit=150,
    func_eval_time_limit_secs=30
)
<autoPyTorch.api.tabular_classification.TabularClassificationTask object at 0x7f9aa41e6f40>

Print the final ensemble performance

y_pred = api.predict(X_test)
score = api.score(y_pred, y_test)
print(score)
print(api.show_models())

# Print statistics from search
print(api.sprint_statistics())
{'accuracy': 0.8728323699421965}
|    | Preprocessing                                                                                | Estimator                                                       |   Weight |
|---:|:---------------------------------------------------------------------------------------------|:----------------------------------------------------------------|---------:|
|  0 | None                                                                                         | LGBMLearner                                                     |     0.36 |
|  1 | None                                                                                         | RFLearner                                                       |     0.26 |
|  2 | None                                                                                         | ETLearner                                                       |     0.14 |
|  3 | SimpleImputer,Variance Threshold,NoCoalescer,NoEncoder,StandardScaler,NoFeaturePreprocessing | no embedding,ShapedMLPBackbone,FullyConnectedHead,nn.Sequential |     0.1  |
|  4 | None                                                                                         | SVMLearner                                                      |     0.08 |
|  5 | SimpleImputer,Variance Threshold,NoCoalescer,NoEncoder,Normalizer,KernelPCA                  | no embedding,ShapedMLPBackbone,FullyConnectedHead,nn.Sequential |     0.02 |
|  6 | None                                                                                         | KNNLearner                                                      |     0.02 |
|  7 | SimpleImputer,Variance Threshold,NoCoalescer,NoEncoder,StandardScaler,NoFeaturePreprocessing | no embedding,ShapedMLPBackbone,FullyConnectedHead,nn.Sequential |     0.02 |
autoPyTorch results:
        Dataset name: f2793923-22f5-11ed-8835-b1fa420cf160
        Optimisation Metric: accuracy
        Best validation score: 0.8596491228070176
        Number of target algorithm runs: 20
        Number of successful target algorithm runs: 14
        Number of crashed target algorithm runs: 5
        Number of target algorithms that exceeded the time limit: 1
        Number of target algorithms that exceeded the memory limit: 0

Total running time of the script: ( 5 minutes 52.526 seconds)
