Fit a single configuration

Auto-PyTorch searches for the best combination of machine learning algorithms and their hyper-parameter configuration for a given task. This example shows how to fit one of these pipelines, either with a user-defined configuration or with one randomly sampled from the configuration space (the code below uses the search space's default configuration; a random one can be drawn with ``sample_configuration()`` instead). The pipelines that Auto-PyTorch fits are compatible with the Scikit-Learn API. You can find further documentation about Scikit-Learn models `here <https://scikit-learn.org/stable/getting_started.html>`_.

import os
import tempfile as tmp
import warnings

# Configure the environment BEFORE importing sklearn / autoPyTorch (and, through
# them, the numerical libraries): thread-count variables such as these are
# typically read when those libraries are first loaded, so setting them after
# the imports below may have no effect.
os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

# Hide UserWarning/FutureWarning noise from the libraries so the example
# output stays readable.
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

import sklearn.datasets
import sklearn.metrics
# ``sklearn.model_selection.train_test_split`` is used for the data split, so
# import the submodule explicitly instead of relying on another package having
# imported it as a side effect.
import sklearn.model_selection

from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes

Data Loading

# Download OpenML dataset 3 as pandas objects (features ``X``, target ``y``),
# then split it 50/50 into train and test parts with a fixed seed so the
# split is reproducible.
X, y = sklearn.datasets.fetch_openml(
    data_id=3,
    return_X_y=True,
    as_frame=True,
)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X,
    y,
    test_size=0.5,
    random_state=3,
)

Define an estimator

# Create the tabular classification task. Pipelines are evaluated with a
# single holdout split; ``val_share=0.5`` presumably reserves half of the
# training data for validation.
estimator = TabularClassificationTask(
    resampling_strategy_args=dict(val_share=0.5),
    resampling_strategy=HoldoutValTypes.holdout_validation,
)

Get a configuration of the pipeline for the current dataset

# Wrap the train/test splits in an autoPyTorch dataset object named 'kr-vs-kp'.
dataset = estimator.get_dataset(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    dataset_name='kr-vs-kp',
)

# Build the search space for this dataset and take its default configuration.
configuration = (
    estimator
    .get_search_space(dataset)
    .get_default_configuration()
)

print("Passed Configuration:", configuration)
Passed Configuration: Configuration(values={
  'coalescer:__choice__': 'NoCoalescer',
  'data_loader:batch_size': 64,
  'encoder:__choice__': 'OneHotEncoder',
  'feature_preprocessor:__choice__': 'NoFeaturePreprocessor',
  'lr_scheduler:ReduceLROnPlateau:factor': 0.1,
  'lr_scheduler:ReduceLROnPlateau:mode': 'min',
  'lr_scheduler:ReduceLROnPlateau:patience': 10,
  'lr_scheduler:__choice__': 'ReduceLROnPlateau',
  'network_backbone:ShapedMLPBackbone:activation': 'relu',
  'network_backbone:ShapedMLPBackbone:max_units': 200,
  'network_backbone:ShapedMLPBackbone:mlp_shape': 'funnel',
  'network_backbone:ShapedMLPBackbone:num_groups': 5,
  'network_backbone:ShapedMLPBackbone:output_dim': 200,
  'network_backbone:ShapedMLPBackbone:use_dropout': False,
  'network_backbone:__choice__': 'ShapedMLPBackbone',
  'network_embedding:__choice__': 'NoEmbedding',
  'network_head:__choice__': 'fully_connected',
  'network_head:fully_connected:activation': 'relu',
  'network_head:fully_connected:num_layers': 2,
  'network_head:fully_connected:units_layer_1': 128,
  'network_init:XavierInit:bias_strategy': 'Normal',
  'network_init:__choice__': 'XavierInit',
  'optimizer:AdamOptimizer:beta1': 0.9,
  'optimizer:AdamOptimizer:beta2': 0.9,
  'optimizer:AdamOptimizer:lr': 0.01,
  'optimizer:AdamOptimizer:weight_decay': 0.0,
  'optimizer:__choice__': 'AdamOptimizer',
  'scaler:__choice__': 'NoScaler',
  'trainer:StandardTrainer:weighted_loss': True,
  'trainer:__choice__': 'StandardTrainer',
})

Fit the configuration

# Fit the chosen configuration once, with a budget of 5 epochs and a run time
# limit of 75 seconds. ``dataset`` is rebound to the dataset object returned
# by ``fit_pipeline``.
pipeline, run_info, run_value, dataset = estimator.fit_pipeline(
    dataset=dataset,
    configuration=configuration,
    budget_type='epochs',
    budget=5,
    run_time_limit_secs=75,
)

# ``run_info`` is a named tuple describing the run that was executed
# (configuration, seed, cutoff, budget, ...).
print(run_info)

# ``run_value`` is a named tuple with the outcome of the run: cost, timing,
# status, and (in ``additional_info``) the train/opt/test losses.
print(run_value)

# NOTE: autoPyTorch annotates the returned pipeline as Optional; it can be
# ``None`` when the run did not succeed. Report the run status in that case
# instead of crashing with an AttributeError on ``None.named_steps``.
if pipeline is None:
    print("Fitting the pipeline failed with status:", run_value.status)
else:
    # This object complies with the Scikit-Learn Pipeline API.
    # https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
    print(pipeline.named_steps)
RunInfo(config=Configuration(values={
  'coalescer:__choice__': 'NoCoalescer',
  'data_loader:batch_size': 64,
  'encoder:__choice__': 'OneHotEncoder',
  'feature_preprocessor:__choice__': 'NoFeaturePreprocessor',
  'lr_scheduler:ReduceLROnPlateau:factor': 0.1,
  'lr_scheduler:ReduceLROnPlateau:mode': 'min',
  'lr_scheduler:ReduceLROnPlateau:patience': 10,
  'lr_scheduler:__choice__': 'ReduceLROnPlateau',
  'network_backbone:ShapedMLPBackbone:activation': 'relu',
  'network_backbone:ShapedMLPBackbone:max_units': 200,
  'network_backbone:ShapedMLPBackbone:mlp_shape': 'funnel',
  'network_backbone:ShapedMLPBackbone:num_groups': 5,
  'network_backbone:ShapedMLPBackbone:output_dim': 200,
  'network_backbone:ShapedMLPBackbone:use_dropout': False,
  'network_backbone:__choice__': 'ShapedMLPBackbone',
  'network_embedding:__choice__': 'NoEmbedding',
  'network_head:__choice__': 'fully_connected',
  'network_head:fully_connected:activation': 'relu',
  'network_head:fully_connected:num_layers': 2,
  'network_head:fully_connected:units_layer_1': 128,
  'network_init:XavierInit:bias_strategy': 'Normal',
  'network_init:__choice__': 'XavierInit',
  'optimizer:AdamOptimizer:beta1': 0.9,
  'optimizer:AdamOptimizer:beta2': 0.9,
  'optimizer:AdamOptimizer:lr': 0.01,
  'optimizer:AdamOptimizer:weight_decay': 0.0,
  'optimizer:__choice__': 'AdamOptimizer',
  'scaler:__choice__': 'NoScaler',
  'trainer:StandardTrainer:weighted_loss': True,
  'trainer:__choice__': 'StandardTrainer',
})
, instance=None, instance_specific=None, seed=1, cutoff=69, capped=False, budget=5, source_id=0)
RunValue(cost=0.03379224030037542, time=30.665780305862427, status=<StatusType.SUCCESS: 1>, starttime=1661267309.8057685, endtime=1661267341.5303836, additional_info={'opt_loss': {'accuracy': 0.03379224030037542}, 'duration': 30.586732387542725, 'num_run': 2, 'train_loss': {'accuracy': 0.0012515644555695093}, 'test_loss': {'accuracy': 0.028785982478097605}, 'configuration': {'coalescer:__choice__': 'NoCoalescer', 'data_loader:batch_size': 64, 'encoder:__choice__': 'OneHotEncoder', 'feature_preprocessor:__choice__': 'NoFeaturePreprocessor', 'lr_scheduler:__choice__': 'ReduceLROnPlateau', 'network_backbone:__choice__': 'ShapedMLPBackbone', 'network_embedding:__choice__': 'NoEmbedding', 'network_head:__choice__': 'fully_connected', 'network_init:__choice__': 'XavierInit', 'optimizer:__choice__': 'AdamOptimizer', 'scaler:__choice__': 'NoScaler', 'trainer:__choice__': 'StandardTrainer', 'lr_scheduler:ReduceLROnPlateau:factor': 0.1, 'lr_scheduler:ReduceLROnPlateau:mode': 'min', 'lr_scheduler:ReduceLROnPlateau:patience': 10, 'network_backbone:ShapedMLPBackbone:activation': 'relu', 'network_backbone:ShapedMLPBackbone:max_units': 200, 'network_backbone:ShapedMLPBackbone:mlp_shape': 'funnel', 'network_backbone:ShapedMLPBackbone:num_groups': 5, 'network_backbone:ShapedMLPBackbone:output_dim': 200, 'network_backbone:ShapedMLPBackbone:use_dropout': False, 'network_head:fully_connected:num_layers': 2, 'network_init:XavierInit:bias_strategy': 'Normal', 'optimizer:AdamOptimizer:beta1': 0.9, 'optimizer:AdamOptimizer:beta2': 0.9, 'optimizer:AdamOptimizer:lr': 0.01, 'optimizer:AdamOptimizer:weight_decay': 0.0, 'trainer:StandardTrainer:weighted_loss': True, 'network_head:fully_connected:activation': 'relu', 'network_head:fully_connected:units_layer_1': 128}, 'budget': 5, 'configuration_origin': None})
{'imputer': SimpleImputer(random_state=RandomState(MT19937) at 0x7F9AA42F6A40), 'variance_threshold': VarianceThreshold(random_state=RandomState(MT19937) at 0x7F9AA6665D40), 'coalescer': <autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer.CoalescerChoice object at 0x7f9aa471fa00>, 'encoder': <autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.EncoderChoice object at 0x7f9aa471fe20>, 'scaler': <autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.ScalerChoice object at 0x7f9aa471f820>, 'feature_preprocessor': <autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.feature_preprocessing.FeatureProprocessorChoice object at 0x7f9aa471f280>, 'tabular_transformer': TabularColumnTransformer(random_state=RandomState(MT19937) at 0x7F9AA42F6A40), 'preprocessing': EarlyPreprocessing(random_state=RandomState(MT19937) at 0x7F9AA42F6A40), 'network_embedding': <autoPyTorch.pipeline.components.setup.network_embedding.NetworkEmbeddingChoice object at 0x7f9aa4a89b80>, 'network_backbone': <autoPyTorch.pipeline.components.setup.network_backbone.NetworkBackboneChoice object at 0x7f9aa49774f0>, 'network_head': <autoPyTorch.pipeline.components.setup.network_head.NetworkHeadChoice object at 0x7f9a9f39d850>, 'network': NetworkComponent(network=Sequential(
  (0): _NoEmbedding()
  (1): Sequential(
    (0): Linear(in_features=73, out_features=200, bias=True)
    (1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=200, out_features=200, bias=True)
    (4): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5):...
    (10): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Linear(in_features=200, out_features=200, bias=True)
  )
  (2): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=200, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=2, bias=True)
  )
),
                 random_state=RandomState(MT19937) at 0x7F9AA42F6A40), 'network_init': <autoPyTorch.pipeline.components.setup.network_initializer.NetworkInitializerChoice object at 0x7f9aa48e3880>, 'optimizer': <autoPyTorch.pipeline.components.setup.optimizer.OptimizerChoice object at 0x7f9aa48e3040>, 'lr_scheduler': <autoPyTorch.pipeline.components.setup.lr_scheduler.SchedulerChoice object at 0x7f9aa4179a90>, 'data_loader': FeatureDataLoader(random_state=RandomState(MT19937) at 0x7F9AA42F6A40), 'trainer': <autoPyTorch.pipeline.components.training.trainer.TrainerChoice object at 0x7f9a9f6e8ee0>}

Total running time of the script: ( 0 minutes 43.037 seconds)

Gallery generated by Sphinx-Gallery