NAS-Bench-101
Bases: Graph
Represents a search space for NAS-Bench-101, a dataset of neural architectures and their associated performance.
This class inherits from the Graph class and provides methods to handle architecture specs (representations), convert them to other forms, query performance metrics, and sample architectures.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_classes | int | Number of classes for the classification task. Defaults to 10. | 10 |
Attributes:
Name | Type | Description |
---|---|---|
QUERYABLE | bool | Flag indicating whether this search space can be queried. Always True for NAS-Bench-101. |
num_classes | int | Number of classes for the classification task. |
space_name | str | Name of the search space. |
spec | dict or None | Dict representation of the current architecture. None by default. |
labeled_archs | list | List of labeled architectures to sample from. |
instantiate_model | bool | If True, a model is instantiated when a new spec is set. |
sample_without_replacement | bool | If True, an architecture is removed from the list of available architectures once it has been sampled. |
convert_to_cell(matrix, ops)
Converts a given matrix and operations into a NAS-Bench-101 cell, represented as a dictionary.
The NAS-Bench-101 API always works with 7x7 adjacency matrices, so for compatibility this method always returns a 7x7 matrix: if the input matrix is smaller, it is padded with blank (all-zero) rows and columns.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
matrix | np.ndarray | The adjacency matrix of the cell. | required |
ops | list | List of operations in the cell. | required |
Returns:
Name | Type | Description |
---|---|---|
dict | dict | Dictionary representation of the NAS-Bench-101 cell, with keys 'matrix' and 'ops'. |
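A minimal sketch of the padding step, assuming NumPy arrays and the NAS-Bench-101 convention that the input node comes first and the output node last. The helper name pad_cell_to_7x7 and the filler op used for the blank nodes are hypothetical; because the blank nodes have no edges, the NAS-Bench-101 API prunes them when it builds a model spec.

```python
import numpy as np

NUM_VERTICES = 7  # NAS-Bench-101 cells have at most 7 nodes


def pad_cell_to_7x7(matrix: np.ndarray, ops: list, filler_op: str = "conv3x3-bn-relu") -> dict:
    """Hypothetical helper: pad an n x n cell (n <= 7) to 7 x 7.

    Blank (edge-free) nodes are inserted just before the output node, so the
    input stays at index 0 and the output moves to index 6.
    """
    n = matrix.shape[0]
    if n == NUM_VERTICES:
        return {"matrix": matrix, "ops": list(ops)}

    new_matrix = np.zeros((NUM_VERTICES, NUM_VERTICES), dtype=np.int8)
    # Copy the edges among the first n - 1 nodes (everything but the output).
    new_matrix[: n - 1, : n - 1] = matrix[: n - 1, : n - 1]
    # Edges that pointed into the old output column now point into column 6.
    new_matrix[: n - 1, -1] = matrix[: n - 1, n - 1]

    # Keep the original ops, fill the blank nodes, and end with the output op.
    new_ops = list(ops[: n - 1]) + [filler_op] * (NUM_VERTICES - n) + [ops[-1]]
    return {"matrix": new_matrix, "ops": new_ops}
```

Placing the blank nodes before the output, rather than after it, keeps the output node in the last position.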
encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)
Encodes the current architecture using a given encoding type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
encoding_type | EncodingType | The type of encoding to use. Defaults to ADJACENCY_ONE_HOT. | EncodingType.ADJACENCY_ONE_HOT |
Returns:
Type | Description |
---|---|
Union[List, np.ndarray, dict] | The encoded architecture. |
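To illustrate what an adjacency-plus-one-hot encoding can look like, the sketch below flattens the 21 upper-triangular entries of a 7x7 adjacency matrix and appends a 3-bit one-hot vector for each of the 5 intermediate operations, giving a 36-dimensional vector. The function name and the bit order in OP_ONE_HOT are assumptions; the encoding returned by encode() may order or size its features differently.

```python
import numpy as np

# Assumed one-hot layout for the three NAS-Bench-101 intermediate ops.
OP_ONE_HOT = {
    "conv1x1-bn-relu": [1, 0, 0],
    "conv3x3-bn-relu": [0, 1, 0],
    "maxpool3x3": [0, 0, 1],
}


def encode_adjacency_one_hot(spec: dict) -> np.ndarray:
    """Hypothetical encoder: upper-triangular edges followed by one-hot ops."""
    matrix = np.asarray(spec["matrix"])
    ops = spec["ops"]
    n = matrix.shape[0]
    # Strictly upper-triangular entries: every possible DAG edge i -> j, i < j.
    rows, cols = np.triu_indices(n, k=1)
    edges = matrix[rows, cols].tolist()
    # Skip the input (first) and output (last) nodes, whose ops are fixed.
    op_bits = [bit for op in ops[1:-1] for bit in OP_ONE_HOT[op]]
    return np.array(edges + op_bits, dtype=np.int8)
```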
forward_before_global_avg_pool(x)
Applies the forward pass of the architecture up to the global average pooling layer. Saves and returns the intermediate output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | The input tensor. | required |
Returns:
Name | Type | Description |
---|---|---|
list | list | The intermediate output of the forward pass. |
get_arch_iterator(dataset_api)
Fetches an iterator over all architectures in the NAS-Bench-101 dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | API of the NAS-Bench-101 dataset. | required |
Returns:
Name | Type | Description |
---|---|---|
Iterator | Iterator | Iterator over all architectures in the NAS-Bench-101 dataset. |
get_hash()
Retrieves the hash of the current architecture.
Returns:
Name | Type | Description |
---|---|---|
tuple | tuple | The hash of the current architecture. |
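One way to build a hashable tuple from a spec is sketched below, assuming the tuple is the flattened adjacency matrix followed by integer op indices. The OPS vocabulary and the function name are hypothetical, not necessarily the layout get_hash() uses. Because tuples are hashable, such keys can be used to deduplicate architectures in a set or dict.

```python
import numpy as np

# Hypothetical op vocabulary used to map op names to integers.
OPS = ["input", "conv1x1-bn-relu", "conv3x3-bn-relu", "maxpool3x3", "output"]


def spec_to_hash(spec: dict) -> tuple:
    """Flatten the adjacency matrix and append op indices into one tuple."""
    matrix = np.asarray(spec["matrix"], dtype=int)
    edges = tuple(int(x) for x in matrix.flatten())
    op_ids = tuple(OPS.index(op) for op in spec["ops"])
    return edges + op_ids
```

Note that a raw tuple like this distinguishes specs that differ only by a node relabeling, even though NAS-Bench-101 treats such isomorphic cells as the same architecture.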
get_loss_fn()
Returns the loss function to be used during optimization.
Returns:
Name | Type | Description |
---|---|---|
Callable | Callable | The cross entropy loss function from the PyTorch framework. |
get_nbhd(dataset_api)
Retrieves all valid neighbors of the current architecture. The method considers both operation neighbors (one operation changed) and edge neighbors (one edge flipped).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | The API for the NAS-Bench-101 dataset. | required |
Returns:
Name | Type | Description |
---|---|---|
list | list | List of all valid neighboring architectures. |
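A minimal sketch of neighbor generation, assuming a neighbor differs from the current cell by one operation change or one edge flip. Validity is approximated by the hypothetical is_valid_cell: at most 9 edges (the NAS-Bench-101 limit) and the output reachable from the input. The real check goes through the dataset API, and get_neighbors is likewise a hypothetical name.

```python
import numpy as np

MAX_EDGES = 9  # NAS-Bench-101 allows at most 9 edges per cell
OPS = ["conv1x1-bn-relu", "conv3x3-bn-relu", "maxpool3x3"]


def is_valid_cell(matrix: np.ndarray) -> bool:
    """Hypothetical check: edge budget respected and output reachable from input."""
    if matrix.sum() > MAX_EDGES:
        return False
    n = matrix.shape[0]
    reached, frontier = {0}, [0]
    while frontier:  # depth-first search from the input node
        src = frontier.pop()
        for dst in range(n):
            if matrix[src, dst] and dst not in reached:
                reached.add(dst)
                frontier.append(dst)
    return (n - 1) in reached


def get_neighbors(spec: dict) -> list:
    """Hypothetical helper: all one-op-change and one-edge-flip neighbors."""
    matrix = np.asarray(spec["matrix"], dtype=int)
    ops = spec["ops"]
    n = matrix.shape[0]
    nbhd = []
    # Operation neighbors: swap one intermediate op for a different one.
    # The edges are unchanged, so validity carries over from the current cell.
    for i in range(1, n - 1):
        for op in OPS:
            if op != ops[i]:
                new_ops = list(ops)
                new_ops[i] = op
                nbhd.append({"matrix": matrix.copy(), "ops": new_ops})
    # Edge neighbors: flip one upper-triangular entry and keep valid results.
    for src in range(n - 1):
        for dst in range(src + 1, n):
            new_matrix = matrix.copy()
            new_matrix[src, dst] = 1 - new_matrix[src, dst]
            if is_valid_cell(new_matrix):
                nbhd.append({"matrix": new_matrix, "ops": list(ops)})
    return nbhd
```

For a 7-node chain cell, the 5 intermediate ops each have 2 alternatives (10 operation neighbors). Removing any chain edge disconnects the output, while adding any of the 15 missing edges stays within the budget (15 edge neighbors).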
get_spec()
Returns the current architecture spec (representation).
Returns:
Name | Type | Description |
---|---|---|
dict | dict | The spec of the current architecture. |
get_type()
Returns the type of the search space, which is 'nasbench101' in this case.
Returns:
Name | Type | Description |
---|---|---|
str | str | The type of the search space. |
mutate(parent, dataset_api, edits=1)
Mutates a given parent architecture by randomly flipping edges and changing operations, each with a certain probability. The resulting architecture is set as the current spec.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parent | Graph | The parent graph from which to mutate. | required |
dataset_api | dict | The API for the NAS-Bench-101 dataset. | required |
edits | int | The number of mutations to apply. Defaults to 1. | 1 |
Code inspired by https://github.com/google-research/nasbench
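The mutation example in the nasbench repository redraws the whole mutation until the resulting cell is valid. A self-contained sketch in that spirit is shown below. The per-edge flip probability 1/7, the per-op resampling probability 1/5, the helper is_valid_cell (at most 9 edges and output reachable from input), and the name mutate_spec are all assumptions rather than the exact values and helpers used by mutate().

```python
import random
from typing import Optional

import numpy as np

NUM_VERTICES = 7
MAX_EDGES = 9
OPS = ["conv1x1-bn-relu", "conv3x3-bn-relu", "maxpool3x3"]


def is_valid_cell(matrix: np.ndarray) -> bool:
    """Hypothetical check: edge budget respected and output reachable from input."""
    if matrix.sum() > MAX_EDGES:
        return False
    reached, frontier = {0}, [0]
    while frontier:  # depth-first search from the input node
        src = frontier.pop()
        for dst in range(matrix.shape[0]):
            if matrix[src, dst] and dst not in reached:
                reached.add(dst)
                frontier.append(dst)
    return (matrix.shape[0] - 1) in reached


def mutate_spec(parent: dict, edits: int = 1, rng: Optional[random.Random] = None) -> dict:
    """Hypothetical mutation: random edge flips and op changes, retried until valid."""
    rng = rng or random.Random()
    matrix = np.asarray(parent["matrix"], dtype=int)
    ops = list(parent["ops"])
    for _ in range(edits):
        while True:  # redraw the whole mutation until the cell is valid
            new_matrix, new_ops = matrix.copy(), list(ops)
            # Flip each upper-triangular edge with probability 1/7 (assumed).
            for src in range(NUM_VERTICES - 1):
                for dst in range(src + 1, NUM_VERTICES):
                    if rng.random() < 1 / NUM_VERTICES:
                        new_matrix[src, dst] = 1 - new_matrix[src, dst]
            # Change each intermediate op with probability 1/5 (assumed).
            for i in range(1, NUM_VERTICES - 1):
                if rng.random() < 1 / (NUM_VERTICES - 2):
                    new_ops[i] = rng.choice([op for op in OPS if op != new_ops[i]])
            if is_valid_cell(new_matrix):
                matrix, ops = new_matrix, new_ops
                break
    return {"matrix": matrix, "ops": ops}
```

Since the parent's matrix is copied before flipping, the parent spec itself is never modified.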
query(metric, dataset='cifar10', path=None, epoch=-1, full_lc=False, dataset_api=None)
Queries the performance metrics of the current architecture from the NAS-Bench-101 dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metric | Metric | The performance metric to query. | required |
dataset | str | The dataset for which to query the metric. Only "cifar10" is currently supported. Defaults to "cifar10". | 'cifar10' |
path | str | The path to the NAS-Bench-101 dataset. | None |
epoch | int | The epoch for which to query the metric. If -1, returns the metric for all available epochs. NAS-Bench-101 provides results at 4, 12, 36, and 108 training epochs. Defaults to -1. | -1 |
full_lc | bool | If True, returns the full learning curve. Defaults to False. | False |
dataset_api | dict | API of the NAS-Bench-101 dataset. | None |
Returns:
Type | Description |
---|---|
Union[list, float] | The queried metric result from the NAS-Bench-101 dataset. |
Raises:
Type | Description |
---|---|
AssertionError | If the dataset is unknown, the epoch is not one of those available in NAS-Bench-101, or the spec of the architecture is None. |
NotImplementedError | If the metric or the dataset_api is not provided. |
sample_random_architecture(dataset_api, load_labeled=False)
Samples a random architecture and updates the edges in the NASLib object accordingly.
If load_labeled is True, it calls the sample_random_labeled_architecture() method instead.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | The API for the NAS-Bench-101 dataset. | required |
load_labeled | bool | Indicates whether to load a labeled architecture. Defaults to False. | False |
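A common way to draw a random NAS-Bench-101 cell is rejection sampling: pick a random upper-triangular adjacency matrix and random intermediate ops, and retry until the cell is valid. The sketch below assumes that approach, with a hypothetical is_valid_cell (at most 9 edges and output reachable from input) standing in for the dataset API's validity check. sample_random_spec is also a hypothetical name, and the sketch is not necessarily how sample_random_architecture() draws its samples.

```python
import random
from typing import Optional

import numpy as np

NUM_VERTICES = 7
MAX_EDGES = 9
OPS = ["conv1x1-bn-relu", "conv3x3-bn-relu", "maxpool3x3"]


def is_valid_cell(matrix: np.ndarray) -> bool:
    """Hypothetical check: edge budget respected and output reachable from input."""
    if matrix.sum() > MAX_EDGES:
        return False
    reached, frontier = {0}, [0]
    while frontier:  # depth-first search from the input node
        src = frontier.pop()
        for dst in range(matrix.shape[0]):
            if matrix[src, dst] and dst not in reached:
                reached.add(dst)
                frontier.append(dst)
    return (matrix.shape[0] - 1) in reached


def sample_random_spec(rng: Optional[random.Random] = None) -> dict:
    """Hypothetical sampler: rejection sampling over random cells."""
    rng = rng or random.Random()
    while True:
        matrix = np.zeros((NUM_VERTICES, NUM_VERTICES), dtype=int)
        # Each possible DAG edge i -> j (i < j) is present with probability 1/2.
        for src in range(NUM_VERTICES - 1):
            for dst in range(src + 1, NUM_VERTICES):
                matrix[src, dst] = rng.randint(0, 1)
        ops = ["input"] + [rng.choice(OPS) for _ in range(NUM_VERTICES - 2)] + ["output"]
        if is_valid_cell(matrix):
            return {"matrix": matrix, "ops": ops}
```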
sample_random_labeled_architecture()
Samples a random labeled architecture from the list of available labeled architectures (labeled_archs) of the NAS-Bench-101 dataset.
After the architecture is sampled, it is removed from the pool if the sample_without_replacement attribute is True. The sampled architecture is then set as the current spec.
Raises:
Type | Description |
---|---|
AssertionError | If labeled architectures are not provided. |
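The sampling-with-or-without-replacement logic can be sketched as follows. The function sample_labeled is hypothetical; its labeled_archs and without_replacement arguments stand in for the labeled_archs and sample_without_replacement attributes. In this sketch an empty pool is treated the same as a missing one.

```python
import random
from typing import Optional


def sample_labeled(labeled_archs: Optional[list], without_replacement: bool = True,
                   rng: Optional[random.Random] = None):
    """Hypothetical sampler over a pool of labeled architectures."""
    # A missing (None) or empty pool raises AssertionError.
    assert labeled_archs, "Labeled architectures must be provided"
    rng = rng or random.Random()
    idx = rng.randrange(len(labeled_archs))
    if without_replacement:
        # pop() removes the sampled architecture so it cannot be drawn again.
        return labeled_archs.pop(idx)
    return labeled_archs[idx]
```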
set_spec(spec, dataset_api=None)
Sets the spec of the architecture using a given representation.
The spec can be a string (hash), a dict with the matrix and operations, or a tuple (NASLib representation).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
spec | str or dict or tuple | The spec to set for the architecture. | required |
dataset_api | dict | API of the NAS-Bench-101 dataset. | None |
Raises:
Type | Description |
---|---|
AssertionError | If spec is not of type str, dict, or tuple. |
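A sketch of the type dispatch described above. Only the dict case is fully self-contained here. The hash lookup and the tuple decoding are passed in as hypothetical callables (hash_to_spec, tuple_to_spec), because their real behavior depends on the dataset API and on NASLib's tuple format. resolve_spec is also a hypothetical name.

```python
from typing import Callable, Union

SpecLike = Union[str, dict, tuple]


def resolve_spec(spec: SpecLike,
                 hash_to_spec: Callable[[str], dict],
                 tuple_to_spec: Callable[[tuple], dict]) -> dict:
    """Normalize the three accepted spec forms into a {'matrix', 'ops'} dict."""
    assert isinstance(spec, (str, dict, tuple)), f"Unsupported spec type: {type(spec)}"
    if isinstance(spec, str):
        # A hash string: look the cell up (e.g. through the dataset API).
        return hash_to_spec(spec)
    if isinstance(spec, tuple):
        # A tuple representation: decode it into matrix and ops.
        return tuple_to_spec(spec)
    # A dict: keep only the matrix and ops entries.
    return {"matrix": spec["matrix"], "ops": spec["ops"]}
```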