NAS-Bench-201
Bases: Graph
Represents the NASBench201 search space.
This class provides methods for querying and manipulating architectures within the search space, including methods for mutation and random sampling.
Attributes:
Name | Type | Description |
---|---|---|
num_classes | int | Number of classes for classification tasks. |
in_channels | int | Number of input channels. |
max_epoch | int | Maximum number of epochs for training. |
space_name | str | The name of the search space. |
labeled_archs | list | A list of labeled architectures. |
instantiate_model | bool | Boolean indicating whether to instantiate the model during initialization. |
sample_without_replacement | bool | Boolean indicating whether to sample architectures without replacement. |
channels | list | Number of channels at different stages of the architecture. |
op_indices | list | Indices of the operations. |
OPTIMIZER_SCOPE | list | A list of the stages in the architecture, useful for scoping during optimization. |
QUERYABLE | bool | A boolean indicating whether the search space is queryable or not. |
__init__(n_classes=10, in_channels=3)
Constructor method.
Initializes the NasBench201SearchSpace object with the provided number of classes and input channels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_classes | int | The number of classes for the classification task. Defaults to 10. | 10 |
in_channels | int | The number of input channels. Defaults to 3. | 3 |
encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)
Encodes the current architecture based on a given encoding type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
encoding_type | EncodingType | The encoding type for the architecture. | EncodingType.ADJACENCY_ONE_HOT |
Returns:
Name | Type | Description |
---|---|---|
Any | Union[List, np.ndarray, dict] | The encoded architecture. The return type depends on the chosen encoding type. |
Raises:
Type | Description |
---|---|
NotImplementedError | If the given encoding type is not yet supported as an architecture encoding for nb201. |
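The sketch below illustrates one plausible form of an adjacency one-hot encoding for NAS-Bench-201: one one-hot vector per edge over that edge's candidate operations, concatenated. It assumes the standard NAS-Bench-201 cell with 6 edges and 5 candidate operations per edge. `ToyEncodingType`, `one_hot_op_indices`, and `toy_encode` are hypothetical names, and NASLib's actual encoding may order or lay out the bits differently.

```python
from enum import Enum
from typing import List

NUM_OPS = 5    # assumed: candidate operations per edge
NUM_EDGES = 6  # assumed: edges in one NAS-Bench-201 cell


class ToyEncodingType(Enum):
    """Hypothetical stand-in for EncodingType, not the NASLib enum."""
    ADJACENCY_ONE_HOT = "adjacency_one_hot"
    UNSUPPORTED = "unsupported"


def one_hot_op_indices(op_indices: List[int]) -> List[int]:
    """Concatenate one one-hot vector of length NUM_OPS per edge."""
    if len(op_indices) != NUM_EDGES:
        raise ValueError(f"expected {NUM_EDGES} op indices, got {len(op_indices)}")
    encoding: List[int] = []
    for op in op_indices:
        if not 0 <= op < NUM_OPS:
            raise ValueError(f"op index {op} is out of range")
        vector = [0] * NUM_OPS
        vector[op] = 1
        encoding.extend(vector)
    return encoding


def toy_encode(op_indices: List[int],
               encoding_type: ToyEncodingType = ToyEncodingType.ADJACENCY_ONE_HOT) -> List[int]:
    """Dispatch on the encoding type; unsupported types raise NotImplementedError."""
    if encoding_type is ToyEncodingType.ADJACENCY_ONE_HOT:
        return one_hot_op_indices(op_indices)
    raise NotImplementedError(f"{encoding_type} is not supported in this sketch")


print(toy_encode([0, 1, 2, 3, 4, 0])[:10])  # [1, 0, 0, 0, 0, 0, 1, 0, 0, 0]
```

Raising `NotImplementedError` for unknown types mirrors the behavior documented in the Raises table above.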
forward_before_global_avg_pool(x)
Performs a forward pass up to the global average pooling layer and returns the outputs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | The input tensor. | required |
Returns:
Name | Type | Description |
---|---|---|
list | list | List of outputs from the forward pass. |
get_arch_iterator(dataset_api=None)
Returns an iterator over all possible architectures in the search space. The iterator is the Cartesian product of the candidate operations for each edge in the graph.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | optional | The dataset api. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
Iterator | Iterator | An iterator over all possible architectures. |
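A Cartesian product over per-edge operation choices can be expressed with `itertools.product`. This minimal sketch assumes 6 edges and 5 candidate operations per edge and yields plain op-index tuples; `toy_arch_iterator` is a hypothetical helper, not the NASLib method, which may yield a different representation.

```python
import itertools
from typing import Iterator, Tuple

NUM_OPS = 5    # assumed: candidate operations per edge
NUM_EDGES = 6  # assumed: edges in the cell


def toy_arch_iterator() -> Iterator[Tuple[int, ...]]:
    """Lazily yield every op-index tuple: one operation choice per edge."""
    return itertools.product(range(NUM_OPS), repeat=NUM_EDGES)


first = next(toy_arch_iterator())
total = sum(1 for _ in toy_arch_iterator())
print(first, total)  # (0, 0, 0, 0, 0, 0) 15625
```

Because `itertools.product` is lazy, the full set of 5^6 = 15,625 architectures is never held in memory unless the caller materializes it.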
get_hash()
Gets a hash representation of the architecture. The hash is a tuple of the operation indices.
Returns:
Name | Type | Description |
---|---|---|
tuple | tuple | The hash of the architecture. |
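A list is unhashable, while a tuple of ints is hashable, so returning the op indices as a tuple lets an architecture serve directly as a set member or dictionary key, for example to skip architectures that were already evaluated. A minimal sketch with hypothetical helpers `toy_hash` and `deduplicate`:

```python
from typing import List, Set, Tuple


def toy_hash(op_indices: List[int]) -> Tuple[int, ...]:
    """A tuple of op indices is immutable and therefore hashable."""
    return tuple(op_indices)


def deduplicate(archs: List[List[int]]) -> List[List[int]]:
    """Keep the first occurrence of each architecture, preserving order."""
    seen: Set[Tuple[int, ...]] = set()
    unique = []
    for arch in archs:
        key = toy_hash(arch)
        if key not in seen:
            seen.add(key)
            unique.append(arch)
    return unique


candidates = [[0, 1, 2, 3, 4, 0], [4, 4, 4, 4, 4, 4], [0, 1, 2, 3, 4, 0]]
print(len(deduplicate(candidates)))  # 2
```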
get_loss_fn()
Returns the loss function to be used for this architecture.
Returns:
Name | Type | Description |
---|---|---|
Callable | Callable | A callable object (cross entropy loss function) that can be used as a loss function. |
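The returned callable computes cross-entropy. In a PyTorch code base this is typically `torch.nn.CrossEntropyLoss`, which averages over the batch by default. The plain-Python sketch below, using a hypothetical `cross_entropy` helper, only illustrates the per-sample quantity, `-log(softmax(logits)[target])`, computed with the log-sum-exp trick for numerical stability.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    # Subtract the max logit before exponentiating to avoid overflow.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


print(cross_entropy([0.0, 0.0], 0))  # ln 2, about 0.6931
```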
get_nbhd(dataset_api=None)
Returns all neighbors of the architecture.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | The api containing nasbench201 data. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
list | list | List of neighbor models. |
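One common definition of the neighborhood, assumed here, is every architecture reachable by changing the operation on a single edge. With 6 edges and 5 candidate operations per edge, that gives 6 × 4 = 24 neighbors. `toy_neighbors` is a hypothetical helper that works on op-index lists; the documented method returns neighbor models instead.

```python
from typing import List

NUM_OPS = 5  # assumed: candidate operations per edge


def toy_neighbors(op_indices: List[int]) -> List[List[int]]:
    """All architectures that differ from op_indices on exactly one edge."""
    nbhd = []
    for edge, current_op in enumerate(op_indices):
        for op in range(NUM_OPS):
            if op != current_op:
                neighbor = list(op_indices)  # copy, then change one edge
                neighbor[edge] = op
                nbhd.append(neighbor)
    return nbhd


print(len(toy_neighbors([0, 1, 2, 3, 4, 0])))  # 24 = 6 edges x 4 alternative ops
```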
get_op_indices()
Gets the operation indices of the architecture. If they are not yet defined, converts the NASLib object to operation indices and stores them.
Returns:
Name | Type | Description |
---|---|---|
list | list | The operation indices of the architecture. |
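The "convert once, then store" behavior is a lazy, cached getter. In this minimal sketch, the hypothetical `ToyArchitecture` class takes a conversion callable that stands in for turning the NASLib object into operation indices, and a counter shows that the conversion runs only on the first call.

```python
from typing import Callable, List, Optional


class ToyArchitecture:
    """Hypothetical stand-in showing lazy, cached op-index conversion."""

    def __init__(self, convert: Callable[[], List[int]]):
        self._convert = convert  # stands in for "NASLib object -> op indices"
        self.op_indices: Optional[List[int]] = None

    def get_op_indices(self) -> List[int]:
        if self.op_indices is None:  # not defined yet: convert once and store
            self.op_indices = self._convert()
        return self.op_indices


calls = []


def fake_conversion() -> List[int]:
    calls.append(1)
    return [0, 1, 2, 3, 4, 0]


arch = ToyArchitecture(fake_conversion)
arch.get_op_indices()
arch.get_op_indices()
print(len(calls))  # 1: the conversion ran only once
```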
get_type()
Returns the type of the search space.
Returns:
Name | Type | Description |
---|---|---|
str | str | The type of the search space, "nasbench201" in this case. |
mutate(parent, dataset_api=None)
Changes one operation in the parent's operation indices and sets the result as the operation indices of the current object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parent | Graph | The parent Graph object from which to mutate. | required |
dataset_api | dict | The api containing nasbench201 data. Defaults to None. | None |
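A single-edge mutation can be sketched on op-index lists as follows: copy the parent, pick a random edge, and replace its operation with a *different* one so the child always differs from the parent. The 5-operation assumption and the `toy_mutate` helper are illustrative, not the NASLib implementation.

```python
import random
from typing import List, Optional

NUM_OPS = 5  # assumed: candidate operations per edge


def toy_mutate(parent: List[int], rng: Optional[random.Random] = None) -> List[int]:
    """Return a copy of parent with the operation on one random edge replaced."""
    rng = rng or random.Random()
    child = list(parent)  # never modify the parent in place
    edge = rng.randrange(len(child))
    alternatives = [op for op in range(NUM_OPS) if op != child[edge]]
    child[edge] = rng.choice(alternatives)  # guarantees an actual change
    return child


parent = [0, 1, 2, 3, 4, 0]
child = toy_mutate(parent, random.Random(0))
print(sum(p != c for p, c in zip(parent, child)))  # 1
```

Excluding the current operation from the choices avoids "mutations" that silently return the parent unchanged.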
query(metric, dataset, path=None, epoch=-1, full_lc=False, dataset_api=None)
Queries results from the NAS-Bench-201 database for the specified metric and dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metric | Metric | The performance metric to query for. | required |
dataset | str | The dataset to query for. | required |
path | str | The path to the nasbench201 database. Defaults to None. | None |
epoch | int | The training epoch to query for. Defaults to -1, which means the last epoch. | -1 |
full_lc | bool | If True, returns the full learning curve. Defaults to False. | False |
dataset_api | dict | The api containing nasbench201 data. Defaults to None. | None |
Raises:
Type | Description |
---|---|
NotImplementedError | If the |
Returns:
Name | Type | Description |
---|---|---|
float | float | The queried result. |
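The sketch below models the lookup semantics of `epoch` and `full_lc` on a toy table. With `epoch=-1`, Python's negative indexing selects the last recorded epoch, and `full_lc=True` returns a learning curve instead of a single value. The table layout, the `toy_query` helper, the dataset key, and the way a non-default `epoch` truncates a full curve are assumptions for illustration. The values are placeholders, not benchmark data.

```python
from typing import Dict, List, Tuple, Union

# Hypothetical table: arch hash -> dataset -> per-epoch metric values.
# The numbers are placeholders, NOT real NAS-Bench-201 results.
ToyTable = Dict[Tuple[int, ...], Dict[str, List[float]]]

TABLE: ToyTable = {
    (0, 1, 2, 3, 4, 0): {"cifar10": [10.0, 20.0, 30.0, 40.0]},
}


def toy_query(table: ToyTable, arch_hash: Tuple[int, ...], dataset: str,
              epoch: int = -1, full_lc: bool = False) -> Union[float, List[float]]:
    curve = table[arch_hash][dataset]
    if full_lc:
        # Whole curve when epoch == -1, otherwise the curve before `epoch` (assumed).
        return list(curve) if epoch == -1 else curve[:epoch]
    return curve[epoch]  # epoch=-1 picks the last epoch


arch = (0, 1, 2, 3, 4, 0)
print(toy_query(TABLE, arch, "cifar10"))                # 40.0
print(toy_query(TABLE, arch, "cifar10", epoch=1))       # 20.0
print(toy_query(TABLE, arch, "cifar10", full_lc=True))  # [10.0, 20.0, 30.0, 40.0]
```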
sample_random_architecture(dataset_api=None, load_labeled=False)
Samples a random architecture and sets it as the current architecture.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | The api containing nasbench201 data. Defaults to None. | None |
load_labeled | bool | If True, a random labeled architecture is sampled instead. Defaults to False. | False |
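Sampling uniformly from NAS-Bench-201 amounts to drawing one operation index per edge independently. With `load_labeled=True`, one of the labeled architectures is picked instead. This sketch assumes 6 edges and 5 operations and represents labeled architectures as op-index lists. `toy_sample` and that representation are hypothetical, and the sketch does not model `sample_without_replacement`.

```python
import random
from typing import List, Optional, Sequence

NUM_OPS = 5    # assumed: candidate operations per edge
NUM_EDGES = 6  # assumed: edges in the cell


def toy_sample(rng: random.Random,
               labeled_archs: Optional[Sequence[List[int]]] = None,
               load_labeled: bool = False) -> List[int]:
    """Uniform random op indices, or a random pick from labeled_archs."""
    if load_labeled:
        if not labeled_archs:
            raise ValueError("load_labeled=True requires a non-empty labeled_archs")
        return list(rng.choice(labeled_archs))
    return [rng.randrange(NUM_OPS) for _ in range(NUM_EDGES)]


rng = random.Random(42)
print(toy_sample(rng))
print(toy_sample(rng, labeled_archs=[[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]], load_labeled=True))
```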
sample_random_labeled_architecture()
Samples a random labeled architecture and sets it as the current architecture.
set_op_indices(op_indices)
Sets the operation indices for the current architecture. If the model is to be instantiated, also converts the operation indices to a NASLib object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
op_indices | list | List of operation indices to set. | required |
set_spec(op_indices, dataset_api=None)
Sets the specification of the architecture from a list of operation indices.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
op_indices | list | List of operation indices to set. | required |
dataset_api | optional | The dataset api. Defaults to None. | None |