NAS-Bench-ASR
Bases: Graph
Interface to the tabular benchmark for nas-bench-asr architectures.
This class extends the Graph class to provide a structure for ASR neural network architectures. It includes methods to create macro graphs, cell blocks, and cells, as well as query methods for searching for optimal architectures.
Attributes:
Name | Type | Description |
---|---|---|
QUERYABLE | bool | Whether the architecture can be queried. |
OPTIMIZER_SCOPE | list of str | List of the names of cell stages. |
Note
Currently, building a NASLib object for nas-bench-asr architectures is not supported.
__init__()
Initialize the NasBenchASRSearchSpace object.
Set the properties to default values, which will be used to create the neural network architecture.
encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)
Encode the architecture based on the specified encoding type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
encoding_type | EncodingType | The type of encoding to use. Defaults to EncodingType.ADJACENCY_ONE_HOT. | EncodingType.ADJACENCY_ONE_HOT |
Returns:
Type | Description |
---|---|
object | The encoded architecture. |
get_compact()
Get the compact representation of the architecture.
Returns:
Type | Description |
---|---|
The compact representation of the architecture. |
Raises:
Type | Description |
---|---|
AssertionError | If the compact representation is not set. |
get_hash()
Get the hash of the architecture based on its compact representation.
Returns:
Type | Description |
---|---|
The hash of the architecture. |
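The contract documented for `get_compact`, `get_hash`, and `set_compact` can be illustrated with a minimal sketch. `ToyArch` is a hypothetical stand-in class, not NASLib code. The nested-list layout of the compact representation and the use of a flattened tuple as the hash are assumptions made for illustration; the actual implementation may differ.

```python
class ToyArch:
    """Hypothetical stand-in for the search-space object (not NASLib code)."""

    def __init__(self):
        self.compact = None  # unset until set_compact() is called

    def set_compact(self, compact):
        self.compact = compact

    def get_compact(self):
        # Mirrors the documented behaviour: fail if nothing has been set.
        assert self.compact is not None, "compact representation is not set"
        return self.compact

    def get_hash(self):
        # Assumption: the hash is the flattened compact representation
        # stored as a tuple, so it can be used as a dict key or set member.
        return tuple(x for row in self.get_compact() for x in row)


arch = ToyArch()
arch.set_compact([[0, 1], [2, 0, 1], [5, 1, 0, 0]])  # assumed layout
arch_hash = arch.get_hash()
```

A hashable value of this kind is convenient for de-duplicating architectures that a search has already visited.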
get_max_epochs()
Get the maximum number of epochs for training.
Returns:
Type | Description |
---|---|
int | The maximum number of epochs. |
get_nbhd(dataset_api=None)
Get all neighbors of the current architecture.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | optional | The dataset API instance for neighbor fetching. Defaults to None. | None |
Returns:
Type | Description |
---|---|
list | List of all neighbor architectures. |
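To make the idea of a neighborhood concrete, the sketch below enumerates every architecture that differs from a given one in exactly one position. It is a rough illustration, not the NASLib implementation. The nested-list layout (`[op_index, skip_0, skip_1, ...]` per node), the constant `NUM_OPS`, and the function `neighbours` are all assumptions introduced here.

```python
import copy

NUM_OPS = 6  # assumed number of candidate operations per node


def neighbours(compact):
    """Return all encodings that differ from `compact` in one position.

    Assumed layout: each row is [op_index, skip_0, skip_1, ...], where the
    op index ranges over NUM_OPS values and each skip flag is 0 or 1.
    """
    nbhd = []
    for i, row in enumerate(compact):
        for j, value in enumerate(row):
            n_choices = NUM_OPS if j == 0 else 2
            for alt in range(n_choices):
                if alt == value:
                    continue  # skip the current value, it is not a neighbor
                nb = copy.deepcopy(compact)
                nb[i][j] = alt
                nbhd.append(nb)
    return nbhd


current = [[0, 1], [2, 0, 1], [5, 1, 0, 0]]
nbhd = neighbours(current)
```

Under these assumptions each of the three op positions contributes five neighbors and each of the six skip flags contributes one.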
get_type()
Get the type of the search space.
Returns:
Type | Description |
---|---|
str | The type of the search space, in this case, 'asr'. |
mutate(parent, mutation_rate=1, dataset_api=None)
Mutate the architecture.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parent | NasBenchASRSearchSpace | The parent architecture. | required |
mutation_rate | int | The rate of mutation. Defaults to 1. | 1 |
dataset_api | DatasetAPI | The dataset API instance for the mutation. Defaults to None. | None |
Returns:
Type | Description |
---|---|
None | The architecture is mutated in-place. |
Note
This mutates the cell in one of two ways: by changing an edge or by changing an op.
Todo: mutate by adding/removing nodes. Todo: mutate the list of hidden nodes. Todo: edges between initial hidden nodes are not mutated.
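The two mutation types named in the note can be sketched as follows. This is a minimal illustration under stated assumptions, not the NASLib code: the nested-list encoding, the constant `NUM_OPS`, the equal probability of the two mutation types, and the helper `mutate_compact` are all hypothetical. Unlike the documented method, which mutates in place and returns `None`, the helper returns a new list so that the parent is easy to compare against.

```python
import copy
import random

NUM_OPS = 6  # assumed number of candidate operations per node


def mutate_compact(parent_compact, mutation_rate=1, rng=random):
    """Return a mutated copy of an (assumed) nested-list compact encoding."""
    child = copy.deepcopy(parent_compact)
    for _ in range(int(mutation_rate)):
        row = rng.randrange(len(child))
        if rng.random() < 0.5:
            # Change an op: pick a different operation index for this node.
            alternatives = [op for op in range(NUM_OPS) if op != child[row][0]]
            child[row][0] = rng.choice(alternatives)
        else:
            # Change an edge: flip one skip-connection flag for this node.
            col = rng.randrange(1, len(child[row]))
            child[row][col] = 1 - child[row][col]
    return child


parent = [[0, 1], [2, 0, 1], [5, 1, 0, 0]]  # assumed layout
child = mutate_compact(parent, mutation_rate=1, rng=random.Random(0))
```

With `mutation_rate=1` the child differs from the parent in exactly one position; larger rates apply that many successive single-position changes, some of which may revert earlier ones.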
query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)
Query results from the nas-bench-asr benchmark.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metric | Metric | The performance metric to query for. | None |
dataset | str | The dataset to query on. | None |
path | str | The file path to save the results. | None |
epoch | int | The epoch number at which to query the performance metric. | -1 |
full_lc | bool | Whether to return the full learning curve. | False |
dataset_api | dict | The dataset API to use for querying. | None |
Returns:
Type | Description |
---|---|
float or list | The value(s) of the queried metric. |
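The interplay between `epoch` and `full_lc` can be sketched with a plain list standing in for a stored learning curve. This is an illustration under assumptions, not the library's lookup code: it assumes that `epoch=-1` selects the final epoch and that `full_lc=True` returns the curve (truncated at `epoch` when one is given). The helper `query_curve` and the numbers in `val_curve` are made up for the example.

```python
def query_curve(curve, epoch=-1, full_lc=False):
    """Select a value or a slice from a per-epoch learning curve.

    Assumed semantics: epoch=-1 means the final epoch; full_lc=True
    returns the whole curve, or the first `epoch` entries if epoch != -1.
    """
    if full_lc:
        return list(curve) if epoch == -1 else list(curve[:epoch])
    return curve[epoch]


# Made-up per-epoch metric values, for illustration only.
val_curve = [0.9, 0.6, 0.4, 0.35]

final_value = query_curve(val_curve)                       # single float
whole_curve = query_curve(val_curve, full_lc=True)         # full list
partial = query_curve(val_curve, epoch=2, full_lc=True)    # first 2 epochs
```

This matches the documented return type of a float or a list, depending on `full_lc`.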
sample_random_architecture(dataset_api)
Sample a random architecture based on the dataset API.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_api | | The dataset API instance for the architecture sampling. | required |
Returns:
Type | Description |
---|---|
The compact representation of the sampled architecture. |
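Random sampling of a compact representation can be sketched as below. The layout (three nodes, each encoded as an op index followed by one more skip flag than the previous node), the constants `NUM_OPS` and `NUM_NODES`, and the helper `sample_compact` are assumptions for illustration; they are not taken from the library's source.

```python
import random

NUM_OPS = 6    # assumed number of candidate operations per node
NUM_NODES = 3  # assumed number of nodes in a cell


def sample_compact(rng=random):
    """Draw each op index and each skip flag uniformly at random."""
    return [
        [rng.randrange(NUM_OPS)] + [rng.randrange(2) for _ in range(i + 1)]
        for i in range(NUM_NODES)
    ]


sampled = sample_compact(rng=random.Random(0))
```

Passing a seeded `random.Random` instance, as in the usage line, makes the draw reproducible.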
set_compact(compact)
Set the compact representation of the architecture.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
compact | | The new compact representation of the architecture. | required |