Bananas
Bases: MetaOptimizer
Implementation of BANANAS (Bayesian Optimization with Neural Architectures for Neural Architecture Search) as a meta optimizer. It combines elements of Bayesian optimization and neural architecture search.
Attributes:
Name | Type | Description |
---|---|---|
using_step_function | bool | Whether the optimizer uses a step function. Default is False. |
config | object | Configuration object containing various settings. |
epochs | int | Number of epochs for training. |
performance_metric | str | Performance metric for evaluation. |
dataset | str | Dataset used for training. |
k | int | Hyperparameter for tuning. |
num_init | int | Number of initializations. |
num_ensemble | int | Number of ensembles. |
predictor_type | str | Type of predictor to use. |
acq_fn_type | str | Type of acquisition function to use. |
acq_fn_optimization | str | Type of acquisition function optimization to use. |
encoding_type | str | Type of encoding used. |
num_arches_to_mutate | int | Number of architectures to mutate. |
max_mutations | int | Maximum number of mutations. |
num_candidates | int | Number of candidate architectures. |
max_zerocost | int | Maximum zero cost. |
train_data | list | List of data for training. |
next_batch | list | List of data for the next batch. |
history | torch.nn.ModuleList | Model history. |
zc | bool | Zero cost option. |
semi | bool | Semi-supervised learning option. |
zc_api | API | API for zero cost predictors. |
use_zc_api | bool | Whether to use the zero cost API. |
zc_names | list | Names of zero cost predictors. |
zc_only | bool | Whether to use only zero cost predictors. |
adapt_search_space(search_space, scope=None, dataset_api=None)
Adapts the provided search space for the meta optimizer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
search_space | SearchSpace | The search space to be used. | required |
scope | str | The optimizer scope to use. Defaults to the one provided by the search space. | None |
dataset_api | API | The API of the dataset to be used. | None |
Raises:
Type | Description |
---|---|
AssertionError | If the search space is not queryable. |
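The queryability check can be pictured with a small stand-in. Only the `AssertionError` comes from the description above; the `ToySearchSpace` class, its `QUERYABLE` flag, and the function body are assumptions made for illustration, not the library's actual code.

```python
class ToySearchSpace:
    """Hypothetical stand-in for a search space; not a library class."""

    def __init__(self, queryable):
        # Assumed flag: True when architectures can be looked up in a benchmark.
        self.QUERYABLE = queryable


def adapt_search_space_sketch(search_space, scope=None, dataset_api=None):
    """Illustrative guard mirroring the documented AssertionError."""
    assert search_space.QUERYABLE, "search space must be queryable"
    # Keep the arguments around for later use, as an optimizer plausibly would.
    return {"search_space": search_space, "scope": scope, "dataset_api": dataset_api}


# A queryable space is accepted.
state = adapt_search_space_sketch(ToySearchSpace(queryable=True))

# A non-queryable space is rejected with AssertionError.
try:
    adapt_search_space_sketch(ToySearchSpace(queryable=False))
    rejected = False
except AssertionError:
    rejected = True
```

In practice this means the optimizer is intended for benchmark-style search spaces where an architecture's performance can be queried directly.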
get_arch_as_string(arch)
Converts an architecture into a string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arch | dict | The architecture to convert. | required |
Returns:
Type | Description |
---|---|
str | The architecture as a string. |
get_checkpointables()
Retrieves the checkpointables for the model.
Returns:
Type | Description |
---|---|
dict | The checkpointables for the model. |
get_final_architecture()
Retrieves the final (best) architecture.
Returns:
Type | Description |
---|---|
dict | The final architecture. |
get_model_size()
Retrieves the model size in MB.
Returns:
Type | Description |
---|---|
float | The size of the model in MB. |
get_op_optimizer()
Retrieves the operation optimizer.
Raises:
Type | Description |
---|---|
NotImplementedError | This method should be implemented in a child class. |
get_zero_cost_predictors()
Generates zero-cost predictors for each method in self.zc_names.
Returns:
Type | Description |
---|---|
dict | A dictionary of zero-cost predictors. |
new_epoch(epoch)
Performs operations for a new epoch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | The epoch number. | required |
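To make the per-epoch behaviour concrete, here is a toy sketch of a BANANAS-style loop in pure Python. It is not the library's implementation: the tuple-based architectures, `true_accuracy`, the nearest-neighbour ensemble, and the way the constants map onto attributes such as `num_init`, `k`, `num_ensemble`, `num_candidates`, `num_arches_to_mutate`, and `max_mutations` are all assumptions chosen so the example runs without any dependencies.

```python
import random

random.seed(0)

NUM_OPS, ARCH_LEN = 4, 6               # toy search space: 6 slots, 4 ops each
NUM_INIT, K = 5, 2                     # random warm-up epochs, batch size per refit
NUM_ENSEMBLE, NUM_CANDIDATES = 3, 20
NUM_ARCHES_TO_MUTATE, MAX_MUTATIONS = 2, 1


def true_accuracy(arch):
    """Toy benchmark lookup: higher op indices score better."""
    return sum(arch) / (ARCH_LEN * (NUM_OPS - 1))


def random_arch():
    return tuple(random.randrange(NUM_OPS) for _ in range(ARCH_LEN))


def mutate(arch):
    """Change between 1 and MAX_MUTATIONS randomly chosen slots."""
    arch = list(arch)
    for _ in range(random.randint(1, MAX_MUTATIONS)):
        arch[random.randrange(ARCH_LEN)] = random.randrange(NUM_OPS)
    return tuple(arch)


def fit_ensemble(data):
    """Stand-in ensemble: each member is a bootstrap sample of the data."""
    return [[random.choice(data) for _ in data] for _ in range(NUM_ENSEMBLE)]


def predict(member, arch):
    """1-nearest-neighbour prediction by Hamming distance."""
    nearest = min(member, key=lambda item: sum(a != b for a, b in zip(item[0], arch)))
    return nearest[1]


def acquisition(ensemble, arch):
    """Thompson-style draw from a normal fitted to the ensemble predictions."""
    preds = [predict(member, arch) for member in ensemble]
    mean = sum(preds) / len(preds)
    std = (sum((p - mean) ** 2 for p in preds) / len(preds)) ** 0.5
    return random.gauss(mean, std)


train_data, next_batch = [], []


def new_epoch(epoch):
    if epoch < NUM_INIT:
        arch = random_arch()           # warm-up: evaluate random architectures
    else:
        if not next_batch:             # refit only when the batch is used up
            ensemble = fit_ensemble(train_data)
            ranked = sorted(train_data, key=lambda item: item[1], reverse=True)
            parents = [item[0] for item in ranked[:NUM_ARCHES_TO_MUTATE]]
            candidates = [mutate(random.choice(parents)) for _ in range(NUM_CANDIDATES)]
            candidates.sort(key=lambda a: acquisition(ensemble, a))
            next_batch.extend(candidates[-K:])   # keep the K highest-scoring
        arch = next_batch.pop()
    train_data.append((arch, true_accuracy(arch)))


for epoch in range(15):
    new_epoch(epoch)

best_arch, best_acc = max(train_data, key=lambda item: item[1])
```

The point of the sketch is the control flow: one architecture is evaluated per epoch, and the predictor is refit only when the pending batch is empty.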
query_zc_scores(arch)
Computes zero-cost scores for a given architecture.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arch | dict | The architecture to compute zero-cost scores for. | required |
Returns:
Type | Description |
---|---|
dict | A dictionary of zero-cost scores for the provided architecture. |
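As a rough illustration of the shape of the returned dictionary, the sketch below builds one predictor per name and collects one score per predictor. Everything in it is hypothetical: `ToyZeroCostPredictor`, the `"diversity"` method name, the toy scoring rules, the layout of the cached-score dictionary, and the clamping of infinite scores are assumptions, not the library's behaviour.

```python
import math


class ToyZeroCostPredictor:
    """Hypothetical predictor; real zero-cost proxies inspect an untrained network."""

    def __init__(self, method_name):
        self.method_name = method_name

    def query(self, arch):
        # Toy scoring rules so the example runs without a deep-learning framework.
        if self.method_name == "params":
            return float(sum(arch))
        return float(len(set(arch)))


def get_zero_cost_predictors(zc_names):
    """One predictor per method name, keyed by that name."""
    return {name: ToyZeroCostPredictor(name) for name in zc_names}


def query_zc_scores(arch, zc_names, zc_api=None):
    """Return {method name: score}, preferring cached scores when available."""
    cached = (zc_api or {}).get(arch, {})
    scores = {}
    for name, predictor in get_zero_cost_predictors(zc_names).items():
        score = cached[name] if name in cached else predictor.query(arch)
        if math.isinf(score):
            # Assumed safeguard: replace +/-inf with a large finite value.
            score = math.copysign(1e9, score)
        scores[name] = score
    return scores


arch = (0, 2, 2, 3)
computed = query_zc_scores(arch, ["params", "diversity"])
with_cache = query_zc_scores(
    arch, ["params", "diversity"], zc_api={arch: {"params": float("inf")}}
)
```

Here `computed` holds freshly calculated scores, while `with_cache` takes the `"params"` score from the precomputed dictionary and computes the rest.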
test_statistics()
Computes test statistics.
Returns:
Type | Description |
---|---|
float | The test statistics. |
train_statistics(report_incumbent=True)
Computes training statistics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
report_incumbent | bool | Whether to report the incumbent architecture. Default is True. | True |
Returns:
Type | Description |
---|---|
tuple | A tuple containing various training statistics. |