Base Predictor
Bases: MetaOptimizer
BasePredictor is the base class for predictor-based Neural Architecture Search (NAS) methods, which use a machine learning model to predict the performance of neural architectures without training them. It derives from the MetaOptimizer class.
Attributes:
Name | Type | Description |
---|---|---|
using_step_function | bool | Flag indicating whether the optimizer uses a step function; this optimizer has none. |
config | CfgNode | Configuration settings for the search process. |
epochs | int | Number of epochs for the search process. |
performance_metric | Metric | The performance metric for evaluating the architectures. |
dataset | str | The dataset to be used for evaluation. |
k | int | Number of architectures to be evaluated in each cycle. |
num_init | int | Number of initial random architectures. |
test_size | int | Size of the test set for evaluating the predictor. |
predictor_type | str | Type of predictor to use (e.g., "LGB", "MLP"). |
num_ensemble | int | Number of models in the ensemble. |
encoding_type | str | Type of encoding for the architectures (e.g., "adjacency_one_hot"). |
debug_predictor | bool | If True, debug information will be printed. |
train_data | list | A list to store the training data (architecture-performance pairs). |
choices | list | A list to store the chosen architectures. |
history | torch.nn.ModuleList | A list to store the history of architectures. |
__init__(config)
Initializes the BasePredictor class with configuration settings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | CfgNode | Configuration settings for the search process. | required |
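To make the settings listed under Attributes more concrete, the sketch below collects them in a plain dataclass. `PredictorSearchSettings` is a hypothetical stand-in for the CfgNode, and every default value is illustrative only; neither the class nor the values come from the library.

```python
from dataclasses import dataclass


@dataclass
class PredictorSearchSettings:
    """Hypothetical stand-in for the CfgNode passed to __init__.

    Field names mirror the documented attributes; the defaults are
    illustrative placeholders, not values taken from the library.
    """

    epochs: int = 100
    dataset: str = "cifar10"
    k: int = 10
    num_init: int = 10
    test_size: int = 50
    predictor_type: str = "LGB"
    num_ensemble: int = 3
    encoding_type: str = "adjacency_one_hot"
    debug_predictor: bool = False

    def validate(self) -> None:
        """Reject counts that cannot describe a meaningful search."""
        for name in ("epochs", "k", "num_init", "test_size", "num_ensemble"):
            if getattr(self, name) < 1:
                raise ValueError(f"{name} must be at least 1")


settings = PredictorSearchSettings(k=5, predictor_type="MLP")
settings.validate()
```

Keeping the settings in one typed object makes it easy to see which values drive the initial random phase (`num_init`) and which drive each predictor cycle (`k`, `num_ensemble`).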
adapt_search_space(search_space, scope=None, dataset_api=None)
Adapts the search space for the predictor-based search.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
search_space | Graph | The search space to be adapted. | required |
scope | str | The scope for the search. Defaults to None. | None |
dataset_api | dict | API for the dataset. Defaults to None. | None |
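A minimal sketch of what adapting a search space can look like, assuming the optimizer works on a copy and falls back to a default scope when none is given. `ToySearchSpace`, `ToyOptimizer`, and the `OPTIMIZER_SCOPE` fallback are hypothetical stand-ins used for illustration, not the library's classes.

```python
import copy


class ToySearchSpace:
    """Hypothetical stand-in for a Graph search space."""

    OPTIMIZER_SCOPE = "all"  # assumed default scope, for illustration only

    def __init__(self, ops):
        self.ops = list(ops)

    def clone(self):
        # Deep copy so later changes do not leak back to the original.
        return copy.deepcopy(self)


class ToyOptimizer:
    """Hypothetical optimizer exposing only adapt_search_space."""

    def adapt_search_space(self, search_space, scope=None, dataset_api=None):
        # Assumption: the optimizer keeps its own copy of the search space.
        self.search_space = search_space.clone()
        # Assumption: a missing scope falls back to the search space default.
        self.scope = scope if scope is not None else search_space.OPTIMIZER_SCOPE
        self.dataset_api = dataset_api


space = ToySearchSpace(["conv3x3", "skip_connect"])
optimizer = ToyOptimizer()
optimizer.adapt_search_space(space)
```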
evaluate_predictor(xtrain, ytrain, xtest, test_pred, slice_size=4)
Evaluates the predictor for debugging purposes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xtrain | | Training data (architectures). | required |
ytrain | | Training labels (performances). | required |
xtest | | Test data (architectures). | required |
test_pred | | Predicted performances for the test data. | required |
slice_size | int | Number of items to print in each slice. Defaults to 4. | 4 |
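The kind of debugging output this method describes can be sketched as follows: predictions are printed in slices of `slice_size` next to the true values, and a mean absolute error summarizes the fit. The helper `summarize_predictions` is hypothetical, and passing the true performances explicitly as `test_true` is an assumption; the documented method receives only `xtest` and `test_pred`.

```python
from statistics import mean


def summarize_predictions(test_true, test_pred, slice_size=4):
    """Print true and predicted values slice by slice; return the mean absolute error.

    Hypothetical helper for illustration, not the library's implementation.
    """
    if len(test_true) != len(test_pred):
        raise ValueError("test_true and test_pred must have the same length")
    if slice_size < 1:
        raise ValueError("slice_size must be at least 1")

    for start in range(0, len(test_pred), slice_size):
        stop = start + slice_size
        print("true:", [round(v, 2) for v in test_true[start:stop]])
        print("pred:", [round(v, 2) for v in test_pred[start:stop]])

    # Average absolute gap between prediction and ground truth.
    return mean(abs(t - p) for t, p in zip(test_true, test_pred))


error = summarize_predictions(
    [90.0, 80.0, 70.0, 60.0, 50.0],
    [91.0, 79.0, 70.0, 62.0, 50.0],
    slice_size=2,
)
```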
get_checkpointables()
Gets the models that can be checkpointed.
Returns:
Type | Description |
---|---|
dict | A dictionary with "model" as the key and the history of architectures as the value. |
get_final_architecture()
Gets the final (best) architecture from the search.
Returns:
Type | Description |
---|---|
Graph | The best architecture found during the search. |
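One plausible way to pick the final architecture is to take the entry with the highest recorded performance from the history. The sketch below assumes each history entry pairs an architecture with an accuracy; `EvaluatedArch` and `best_architecture` are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass


@dataclass
class EvaluatedArch:
    """Hypothetical record pairing an architecture with its measured accuracy."""

    arch: str
    accuracy: float


def best_architecture(history):
    """Return the architecture with the highest recorded accuracy."""
    if not history:
        raise ValueError("history is empty; run the search first")
    return max(history, key=lambda record: record.accuracy).arch


history = [
    EvaluatedArch("arch_a", 91.2),
    EvaluatedArch("arch_b", 93.5),
    EvaluatedArch("arch_c", 92.0),
]
final_arch = best_architecture(history)
```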
get_model_size()
Gets the size of the model.
Returns:
Type | Description |
---|---|
float | The size of the model in megabytes (MB). |
get_op_optimizer()
Gets the optimizer for the operations. This method is not implemented in this class and raises an error when called.
Raises:
Type | Description |
---|---|
NotImplementedError | Always, because this method is not implemented in this class. |
new_epoch(epoch)
Starts a new epoch in the search process, sampling a new architecture to train or using the predictor to select an architecture.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | The current epoch number. | required |
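The two phases described above (random sampling for the first `num_init` epochs, then predictor-guided selection of `k` architectures per cycle) can be sketched in a self-contained toy loop. Everything here is an assumption made for illustration: architectures are tuples of 0/1 choices, `true_performance` stands in for training or querying an architecture, `NearestNeighbourPredictor` replaces the real predictor ensemble, and `num_candidates` is a made-up parameter.

```python
import random


def encode(arch):
    """Toy encoding: the architecture is already a tuple of 0/1 choices."""
    return list(arch)


def true_performance(arch):
    """Hypothetical stand-in for training or querying an architecture."""
    return float(sum(arch))


class NearestNeighbourPredictor:
    """Deliberately simple predictor: returns the label of the closest seen encoding."""

    def fit(self, xs, ys):
        self.xs, self.ys = xs, ys

    def predict(self, x):
        dists = [sum((a - b) ** 2 for a, b in zip(x, seen)) for seen in self.xs]
        return self.ys[dists.index(min(dists))]


def run_search(epochs, num_init, k, num_candidates=20, arch_len=6, seed=0):
    """Run the toy search and return the (architecture, performance) pairs."""
    if num_init < 1 or k < 1:
        raise ValueError("num_init and k must be at least 1")
    rng = random.Random(seed)

    def sample():
        return tuple(rng.randint(0, 1) for _ in range(arch_len))

    train_data, choices = [], []
    for epoch in range(epochs):
        if epoch < num_init:
            # Initial phase: evaluate randomly sampled architectures.
            arch = sample()
        else:
            if not choices:
                # Start of a cycle: refit on all data, then keep the top-k candidates.
                predictor = NearestNeighbourPredictor()
                predictor.fit(
                    [encode(a) for a, _ in train_data],
                    [perf for _, perf in train_data],
                )
                candidates = [sample() for _ in range(num_candidates)]
                candidates.sort(key=lambda a: predictor.predict(encode(a)), reverse=True)
                choices = candidates[:k]
            arch = choices.pop(0)
        train_data.append((arch, true_performance(arch)))
    return train_data


results = run_search(epochs=12, num_init=4, k=3)
```

Refitting only when `choices` is empty means the predictor is retrained once per cycle of `k` evaluations rather than every epoch, which keeps the cost of the predictor small relative to architecture evaluation.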
test_statistics()
Reports the test statistics.
Returns:
Type | Description |
---|---|
float | The raw performance metric for the best architecture. |
train_statistics(report_incumbent=True)
Reports the statistics after training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
report_incumbent | bool | Whether to report the incumbent or the most recent architecture. Defaults to True. | True |
Returns:
Type | Description |
---|---|
tuple | A tuple containing the training accuracy, validation accuracy, and test accuracy. |
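The effect of `report_incumbent` can be illustrated with a small sketch. It assumes the incumbent is the history entry with the highest validation accuracy, which is a guess made for this example; `ArchRecord` and `report_statistics` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class ArchRecord:
    """Hypothetical history entry holding the three accuracies of one architecture."""

    train_acc: float
    valid_acc: float
    test_acc: float


def report_statistics(history, report_incumbent=True):
    """Return (train, validation, test) accuracy for the incumbent or the latest entry."""
    if not history:
        raise ValueError("history is empty; nothing to report")
    if report_incumbent:
        # Assumption: the incumbent is the entry with the best validation accuracy.
        record = max(history, key=lambda r: r.valid_acc)
    else:
        record = history[-1]
    return record.train_acc, record.valid_acc, record.test_acc


history = [
    ArchRecord(99.0, 93.0, 92.5),
    ArchRecord(98.0, 94.1, 93.8),
    ArchRecord(97.5, 91.0, 90.7),
]
incumbent_stats = report_statistics(history)
latest_stats = report_statistics(history, report_incumbent=False)
```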