DrNAS
Bases: DARTSOptimizer
Implementation of the DrNAS optimizer introduced in the paper DrNAS: Dirichlet Neural Architecture Search (ICLR 2021).
Note: Many functions are similar to those of the DARTS optimizer, so this class inherits directly from DARTSOptimizer instead of MetaOptimizer.
__init__(learning_rate=0.025, momentum=0.9, weight_decay=0.0003, grad_clip=5, unrolled=False, arch_learning_rate=0.0003, arch_weight_decay=0.001, epochs=50, op_optimizer='SGD', arch_optimizer='Adam', loss_criteria='CrossEntropyLoss', **kwargs)
Initialize a new instance of the DrNASOptimizer class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
learning_rate | float | Learning rate for operation weights. | 0.025 |
momentum | float | Momentum for the operation-weights optimizer. | 0.9 |
weight_decay | float | Weight decay for operation weights. | 0.0003 |
grad_clip | int | Gradient clipping threshold. | 5 |
unrolled | bool | Whether to use unrolled optimization. | False |
arch_learning_rate | float | Learning rate for architecture weights. | 0.0003 |
arch_weight_decay | float | Weight decay for architecture weights. | 0.001 |
epochs | int | Total number of training epochs. | 50 |
op_optimizer | str | The optimizer type for operation weights, e.g. 'SGD'. | 'SGD' |
arch_optimizer | str | The optimizer type for architecture weights, e.g. 'Adam'. | 'Adam' |
loss_criteria | str | Loss criterion, e.g. 'CrossEntropyLoss'. | 'CrossEntropyLoss' |
**kwargs | | Additional keyword arguments. | {} |
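The op_optimizer, arch_optimizer, and loss_criteria strings name standard PyTorch components. As a rough illustration of what the documented defaults correspond to (a sketch with toy parameters, not the class's actual construction code):

```python
import torch

# Toy stand-ins for operation weights and architecture weights.
op_weights = [torch.nn.Parameter(torch.randn(4, 4))]
arch_weights = [torch.nn.Parameter(torch.zeros(5))]

# 'SGD' for operation weights, with the documented defaults.
op_optimizer = torch.optim.SGD(
    op_weights, lr=0.025, momentum=0.9, weight_decay=0.0003
)

# 'Adam' for architecture weights, with the documented defaults.
arch_optimizer = torch.optim.Adam(
    arch_weights, lr=0.0003, weight_decay=0.001
)

# 'CrossEntropyLoss' as the loss criterion.
criterion = torch.nn.CrossEntropyLoss()
```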
adapt_search_space(search_space, dataset, scope=None)
Adapt the search space for architecture search.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
search_space | | The initial search space. | required |
dataset | | The dataset for training/validation. | required |
scope | | Scope to update in the search space. | None |
get_final_architecture()
Retrieve the final, discretized architecture based on current architectural weights.
Returns:
Name | Type | Description |
---|---|---|
| Graph | The final architecture in graph representation. |
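Discretization keeps, per edge, the operation whose architecture weight is largest. A minimal sketch of that idea (hypothetical operation names and beta values, not NASLib code):

```python
import numpy as np

# Hypothetical candidate operations on one edge.
ops = ["max_pool_3x3", "sep_conv_3x3", "skip_connect"]

# In DrNAS the Dirichlet concentration parameters (beta) play the role of
# architecture weights; their normalized values give the expected
# operation weights under the Dirichlet distribution.
beta = np.array([0.5, 2.0, 1.0])
expected_alphas = beta / beta.sum()

# Discretize: keep the operation with the largest expected weight.
chosen = ops[int(np.argmax(expected_alphas))]
print(chosen)  # → sep_conv_3x3
```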
new_epoch(epoch)
Perform any operations needed at the start of a new epoch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | Current epoch number. | required |
remove_sampled_alphas(edge)
staticmethod
Remove sampled architecture weights (alphas) from the edge's data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
edge | | The edge in the computation graph from which the sampled architecture weights are to be removed. | required |
sample_alphas(edge)
staticmethod
Sample architecture weights (alphas) using the Dirichlet distribution parameterized by beta.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
edge | | The edge in the computation graph where the sampled architecture weights are set. | required |
step(data_train, data_val)
Perform a single optimization step for both architecture and operation weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_train | tuple | Training data as a tuple of inputs and labels. | required |
data_val | tuple | Validation data as a tuple of inputs and labels. | required |
Returns:
Name | Type | Description |
---|---|---|
| tuple | Logits for training data, logits for validation data, loss for training data, loss for validation data. |
update_ops(edge)
staticmethod
Replace the primitive operations at the edge with the DrNAS-specific DrNASMixedOp.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
edge | | The edge in the computation graph where the operations are to be replaced. | required |