NAS-Bench-301
Bases: Graph
This class represents a CIFAR-10 search space as outlined in:
Liu et al., 2019. "DARTS: Differentiable Architecture Search"
The search space includes a predefined macrograph that is not optimized, and two types of learnable cells: normal and reduction cells. Each edge comprises 8 primitive operations.
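As a rough illustration of the cell structure described above, the following sketch lists the eight DARTS primitives and enumerates the edges of a single cell. The operation names follow the DARTS reference code; the cell size (two input nodes, four intermediate nodes) and the helper `cell_edges` are assumptions made for illustration and are not taken from this class.

```python
# Candidate operations from the DARTS paper (Liu et al., 2019).
PRIMITIVES = [
    "none",
    "max_pool_3x3",
    "avg_pool_3x3",
    "skip_connect",
    "sep_conv_3x3",
    "sep_conv_5x5",
    "dil_conv_3x3",
    "dil_conv_5x5",
]

NUM_INPUT_NODES = 2         # outputs of the two preceding cells
NUM_INTERMEDIATE_NODES = 4  # assumption: the standard DARTS cell size


def cell_edges(num_inputs=NUM_INPUT_NODES, num_intermediate=NUM_INTERMEDIATE_NODES):
    """Enumerate (source, target) edges of one cell before discretization.

    Each intermediate node receives an edge from every earlier node.
    """
    edges = []
    for target in range(num_inputs, num_inputs + num_intermediate):
        for source in range(target):
            edges.append((source, target))
    return edges


edges = cell_edges()
print(len(PRIMITIVES), len(edges))  # 8 operations, 2 + 3 + 4 + 5 = 14 edges
```

Under these assumptions, each of the 14 edges carries a mixture over the 8 primitives during the search phase.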
Attributes:
Name | Type | Description |
---|---|---|
OPTIMIZER_SCOPE | List[str] | Targets for different instances of the same cell during the optimization process. The cells are divided into normal/reduction cell types and stages. This division is crucial to set the correct channels at each stage. The architecture optimizer should consider all instances equally. |
QUERYABLE | bool | Flag to indicate if the search space can be queried. |
sample_without_replacement = False
instance-attribute
Build the search space with the parameters specified in init.
__init__(n_classes=10, in_channels=3, auxiliary=True)
Constructs a new instance of the DARTS search space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_classes | int | The number of classes under consideration. Defaults to 10. | 10 |
in_channels | int | The number of input channels. Defaults to 3. | 3 |
auxiliary | bool | Flag to enable or disable auxiliary output. Defaults to True. | True |
Note that the init method does not take parameters because of networkx's implementation. To change the number of classes, set the static attribute NUM_CLASSES before instantiating the class. The default is 10, for CIFAR-10.
auxiliary_logits()
Fetches the auxiliary logits from the model graph.
Returns:
Type | Description |
---|---|
torch.Tensor | The auxiliary logits from the model graph. |
encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)
Encodes the architecture graph into the specified encoding type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
encoding_type | EncodingType | The type of encoding to use. Defaults to EncodingType.ADJACENCY_ONE_HOT. | EncodingType.ADJACENCY_ONE_HOT |
Returns:
Type | Description |
---|---|
Any | The encoded representation of the architecture. |
forward_before_global_avg_pool(x)
Run the model forward until the global average pooling layer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | Input tensor. | required |
Returns:
Name | Type | Description |
---|---|---|
list | list | List of output tensors from each layer. |
get_arch_iterator(dataset_api)
Get an iterator for the architectures in the nasbench301 data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | The dataset API. | required |
Returns:
Name | Type | Description |
---|---|---|
Iterator | Iterator | An iterator over the architectures. |
get_compact()
Get the compact representation of the architecture. If the model is instantiated and the compact representation doesn't exist, it converts the model to compact form.
Returns:
Name | Type | Description |
---|---|---|
tuple | tuple | The compact form of the architecture. |
get_configspace(path_to_configspace_obj=os.path.join(get_project_root(), 'search_spaces/nasbench301/configspace.json'))
staticmethod
Returns the configuration space object for the search space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path_to_configspace_obj | str | The path to the ConfigSpace JSON encoding. | os.path.join(get_project_root(), 'search_spaces/nasbench301/configspace.json') |
Returns:
Type | Description |
---|---|
ConfigSpace.ConfigurationSpace | A ConfigSpace object. |
get_hash()
Get the compact hash of the architecture.
Returns:
Name | Type | Description |
---|---|---|
tuple | tuple | The hash of the architecture. |
get_loss_fn()
Get the loss function for training the architecture.
Returns:
Name | Type | Description |
---|---|---|
Callable | Callable | The loss function. |
get_nbhd(dataset_api=None)
Get all neighbors of the current architecture.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | The dataset API. | None |
Returns:
Name | Type | Description |
---|---|---|
list | list | A list of all neighbors of the current architecture. |
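To make the idea of a neighborhood concrete, here is a minimal sketch that enumerates every architecture differing from the current one in exactly one operation. It assumes a compact form made of two cells (normal and reduction), each a tuple of `(input_node, op_index)` pairs; that layout, the function name `op_neighbors`, and the `num_ops` parameter are illustrative assumptions, and the actual method may also consider other kinds of changes, such as rewiring inputs.

```python
def op_neighbors(compact, num_ops):
    """Return all architectures that differ from `compact` in one operation.

    `compact` is assumed to be a tuple of cells, each cell a tuple of
    (input_node, op_index) pairs.
    """
    neighbors = []
    for cell_idx, cell in enumerate(compact):
        for edge_idx, (src, op) in enumerate(cell):
            for new_op in range(num_ops):
                if new_op == op:
                    continue  # skip the architecture itself
                new_cell = cell[:edge_idx] + ((src, new_op),) + cell[edge_idx + 1:]
                neighbors.append(compact[:cell_idx] + (new_cell,) + compact[cell_idx + 1:])
    return neighbors


# Hypothetical architecture: every edge uses operation 0.
normal = ((0, 0), (1, 0), (0, 0), (1, 0), (0, 0), (1, 0), (0, 0), (1, 0))
reduction = normal
arch = (normal, reduction)

neighbors = op_neighbors(arch, num_ops=7)
print(len(neighbors))  # 2 cells * 8 edges * 6 alternative operations = 96
```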
get_type()
Get the type of the architecture.
Returns:
Name | Type | Description |
---|---|---|
str | str | The type of the architecture. |
load_labeled_architecture(dataset_api=None)
Loads a random architecture from the NasBench301 training data and updates the graph object to match the architecture. This method is meant to be called by a fresh NasBench301SearchSpace() object, one that has not already been discretized.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | The dataset API containing architecture information. | None |
mutate(parent, mutation_rate=1, dataset_api=None)
Mutates the architecture by changing one operation from the parent architecture, and then updates the naslib object and op_indices.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parent | Graph | The parent architecture graph. | required |
mutation_rate | int | The mutation rate. Defaults to 1. | 1 |
dataset_api | dict | The dataset API. | None |
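The single-operation mutation described above can be sketched on a plain tuple representation. This is an illustration only: it assumes the parent is given as a tuple of cells, each a tuple of `(input_node, op_index)` pairs, and the helper `mutate_one_op` with its `num_ops` and `rng` parameters is hypothetical. It ignores `mutation_rate` and does not update any graph object.

```python
import random


def mutate_one_op(parent, num_ops, rng=None):
    """Return a copy of `parent` with exactly one operation replaced."""
    rng = rng or random.Random()
    # Pick one edge of one cell uniformly at random.
    cell_idx = rng.randrange(len(parent))
    cell = parent[cell_idx]
    edge_idx = rng.randrange(len(cell))
    src, old_op = cell[edge_idx]
    # Choose a replacement operation that differs from the current one.
    new_op = rng.choice([op for op in range(num_ops) if op != old_op])
    new_cell = cell[:edge_idx] + ((src, new_op),) + cell[edge_idx + 1:]
    return parent[:cell_idx] + (new_cell,) + parent[cell_idx + 1:]


# Hypothetical parent: two cells with eight edges each.
parent_cell = ((0, 1), (1, 2), (0, 3), (2, 4), (1, 5), (3, 6), (0, 0), (4, 1))
parent = (parent_cell, parent_cell)
child = mutate_one_op(parent, num_ops=7, rng=random.Random(0))
```

Passing a seeded `random.Random` keeps the mutation reproducible, which is convenient when debugging an evolutionary search.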
prepare_discretization()
Prepares the graph for discretization.
In this search space, a node can have a maximum of two incoming edges. This method ensures that this condition is met, preparing the graph for further discretization.
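The two-incoming-edges constraint can be illustrated with a small standalone sketch. It assumes each edge has a scalar strength (in DARTS, typically the largest weight among the non-zero operations on that edge) and keeps the two strongest incoming edges per node. The function `keep_top_two_incoming` and its input format are hypothetical and do not reflect how this method operates on the graph object.

```python
from collections import defaultdict


def keep_top_two_incoming(edge_strength, max_in=2):
    """Select at most `max_in` strongest incoming edges for every node.

    `edge_strength` maps (source, target) to a float score.
    Returns the set of (source, target) edges that are kept.
    """
    incoming = defaultdict(list)
    for (src, tgt), strength in edge_strength.items():
        incoming[tgt].append((strength, src))

    kept = set()
    for tgt, candidates in incoming.items():
        # Sort by strength, strongest first, and keep the top `max_in`.
        for _, src in sorted(candidates, reverse=True)[:max_in]:
            kept.add((src, tgt))
    return kept


# Hypothetical edge strengths for nodes 2 and 3 of a cell.
strengths = {
    (0, 2): 0.3, (1, 2): 0.5,
    (0, 3): 0.1, (1, 3): 0.7, (2, 3): 0.4,
}
kept = keep_top_two_incoming(strengths)
print(sorted(kept))  # [(0, 2), (1, 2), (1, 3), (2, 3)]
```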
prepare_evaluation()
This method prepares the model for evaluation. In DARTS, the evaluation model has 32 channels after the stem and contains 3 normal cells at each stage.
query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)
Queries results from NasBench301. If the architecture was loaded from the NasBench301 training data, it can query the train loss or validation accuracy at a specific epoch. Otherwise, it can only query the validation accuracy at epoch 100 using NasBench301.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metric | Metric | The desired metric to be queried. | None |
dataset | str | The dataset to be used. Currently, only the 'cifar10' dataset is supported. | None |
path | str | The path to the saved model. | None |
epoch | int | The specific epoch to be queried. Defaults to -1. | -1 |
full_lc | bool | A flag to indicate if the full learning curve should be returned. Defaults to False. | False |
dataset_api | dict | The dataset API for querying the model. | None |
Returns:
Type | Description |
---|---|
Union[float, dict] | The queried results. |
Raises:
Type | Description |
---|---|
NotImplementedError | If dataset_api is None. |
AssertionError | If the dataset is not 'cifar10' or None. |
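The error conditions listed above can be mirrored in a small caller-side check. The helper below is hypothetical and not part of NASLib; it only reproduces the documented contract, and the order in which the two conditions are checked is an assumption.

```python
def check_query_args(dataset=None, dataset_api=None):
    """Raise the same exception types that query() is documented to raise."""
    if dataset_api is None:
        raise NotImplementedError("query() requires a dataset_api")
    assert dataset in ("cifar10", None), "only 'cifar10' is supported"


# A placeholder dict stands in for the real dataset API here.
check_query_args(dataset="cifar10", dataset_api={"placeholder": True})

try:
    check_query_args(dataset="cifar10", dataset_api=None)
except NotImplementedError as err:
    print("rejected:", err)
```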
sample_random_architecture(dataset_api=None, load_labeled=False)
Sample a random architecture and update the edges in the naslib object accordingly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | dict | The dataset API. Required if load_labeled is True. | None |
load_labeled | bool | Whether to load the architecture from the training data. | False |
sample_random_labeled_architecture()
Samples a random architecture from the labeled architectures.
set_compact(compact)
Set the compact representation of the architecture. If the model is instantiated and a compact form doesn't exist, it converts the compact representation to the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
compact | tuple | The compact form of the architecture. | required |
set_spec(compact, dataset_api=None)
Set the architecture specification, making it immutable.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
compact | tuple | The compact form of the architecture. | required |
dataset_api | dict | The dataset API. | None |