DARTS
Bases: MetaOptimizer
Implementation of DARTS as described in Liu et al. (2019), "DARTS: Differentiable Architecture Search".
Attributes:
Name | Type | Description |
---|---|---|
learning_rate | float | Learning rate for the operation weights. |
momentum | float | Momentum factor. |
weight_decay | float | Weight decay (L2 regularization) for the operation weights. |
grad_clip | int | Gradient clipping threshold. |
unrolled | bool | Whether to use unrolled (second-order) backpropagation for the architecture gradient. |
arch_learning_rate | float | Learning rate for the architectural weights. |
arch_weight_decay | float | Weight decay for the architectural weights. |
op_optimizer | str | Name of the optimizer for the operation weights ('SGD', 'Adam', etc.). |
arch_optimizer | str | Name of the optimizer for the architectural weights ('SGD', 'Adam', etc.). |
loss | str | Name of the loss criterion ('CrossEntropyLoss', etc.). |
architectural_weights | torch.nn.ParameterList | List of architectural weights. |
device | torch.device | Device on which the model runs. |
search_space | obj | Search space for the architecture. |
graph | obj | Computation graph. |
scope | str | Scope in which the graph modifications are applied. |
dataset | str | Dataset used for the search. |
arch_optimizer | obj | Torch optimizer for the architectural weights. |
op_optimizer | obj | Torch optimizer for the operation weights. |
loss | obj | Torch loss function. |
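For reference, the bi-level problem behind the `unrolled` flag, as formulated in Liu et al. (2019), is restated below; $w$ denotes the operation weights, $\alpha$ the architectural weights, and $\xi$ the inner learning rate of the one-step approximation.

$$
\min_{\alpha} \mathcal{L}_{\mathrm{val}}\bigl(w^{*}(\alpha), \alpha\bigr)
\quad \text{s.t.} \quad
w^{*}(\alpha) = \arg\min_{w} \mathcal{L}_{\mathrm{train}}(w, \alpha)
$$

$$
\nabla_{\alpha} \mathcal{L}_{\mathrm{val}}\bigl(w^{*}(\alpha), \alpha\bigr)
\approx
\nabla_{\alpha} \mathcal{L}_{\mathrm{val}}\bigl(w - \xi \nabla_{w} \mathcal{L}_{\mathrm{train}}(w, \alpha), \alpha\bigr)
$$

With $\xi > 0$ this is the second-order (unrolled) approximation; setting $\xi = 0$ gives the first-order approximation $\nabla_{\alpha} \mathcal{L}_{\mathrm{val}}(w, \alpha)$. Reading `unrolled=False` as the first-order case is an assumption based on the paper's terminology.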
__init__(learning_rate=0.025, momentum=0.9, weight_decay=0.0003, grad_clip=5, unrolled=False, arch_learning_rate=0.0003, arch_weight_decay=0.001, op_optimizer='SGD', arch_optimizer='Adam', loss_criteria='CrossEntropyLoss', **kwargs)
Initialize a new instance of DARTSOptimizer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
learning_rate | float | The learning rate for optimizing operations. | 0.025 |
momentum | float | The momentum factor. | 0.9 |
weight_decay | float | Weight decay (L2 regularization). | 0.0003 |
grad_clip | int | Gradient clipping value. | 5 |
unrolled | bool | Whether to use unrolled backpropagation or not. | False |
arch_learning_rate | float | The learning rate for architecture. | 0.0003 |
arch_weight_decay | float | Weight decay for architecture. | 0.001 |
op_optimizer | str | Optimizer for operation weights ('SGD', 'Adam', etc.). | 'SGD' |
arch_optimizer | str | Optimizer for architecture weights ('SGD', 'Adam', etc.). | 'Adam' |
loss_criteria | str | Loss criterion ('CrossEntropyLoss', etc.). | 'CrossEntropyLoss' |
adapt_search_space(search_space, dataset, scope=None, **kwargs)
Adapt the search space for architecture optimization.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
search_space | Graph | The initial search space object. | required |
dataset | Dataset | Dataset to be used for training/validation. | required |
scope | str | The scope in which the graph modifications are applied. | None |
**kwargs | | Additional keyword arguments. | {} |
add_alphas(edge)
staticmethod
Adds architectural weights (alphas) to an edge of the computation graph.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
edge | obj | The edge in the computation graph to which alpha is added. | required |
Returns:
Type | Description |
---|---|
None | |
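A minimal sketch of what attaching alphas to a single edge amounts to, assuming one scalar weight per candidate operation initialized with small Gaussian noise (the `1e-3` scale follows the original DARTS code and is an assumption here). The dict-based `edge` and the name `add_alphas_sketch` are hypothetical stand-ins, not the library's edge API; in the real optimizer the alphas would be trainable `torch.nn.Parameter` objects.

```python
import random

def add_alphas_sketch(edge, scale=1e-3, rng=None):
    """Attach one architectural weight (alpha) per candidate operation on `edge`.

    `edge` is a plain dict with the candidate operations under "op"; it is a
    hypothetical stand-in for the library's graph edge object.
    """
    if rng is None:
        rng = random.Random()
    # Small Gaussian initialization (scale assumed from the original DARTS code),
    # so every operation starts with almost the same weight after the softmax
    edge["alpha"] = [scale * rng.gauss(0.0, 1.0) for _ in edge["op"]]
    return None  # like add_alphas, the edge is modified in place

edge = {"op": ["skip_connect", "sep_conv_3x3", "max_pool_3x3"]}
add_alphas_sketch(edge, rng=random.Random(0))
print(edge["alpha"])  # three small values close to 0
```

Starting all alphas near zero means the softmax over them is close to uniform, so no operation is favored before training begins.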
before_training()
Prepare the model for training. This moves the graph and architectural weights to the device memory.
get_checkpointables()
Get checkpointable elements of the model for saving or loading.
Returns:
Type | Description |
---|---|
dict | A dictionary containing all elements to be checkpointed. |
get_final_architecture()
Get the final, discretized architecture based on the current architectural weights.
Returns:
Type | Description |
---|---|
Graph | The final architecture as a graph object. |
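Discretization can be pictured as an argmax over each edge's alphas. The sketch below assumes exactly one operation is kept per edge; the DARTS paper additionally excludes the `none` operation and keeps only the strongest incoming edges per node, and whether this method does the same is not stated here. The `discretize` function and the dict layout of `edges` are hypothetical.

```python
def discretize(edges):
    """Keep, on every edge, only the operation with the largest alpha.

    `edges` maps an edge id to {"op": [...], "alpha": [...]}; this layout is a
    hypothetical stand-in for the search graph.
    """
    final = {}
    for edge_id, data in edges.items():
        # Index of the strongest operation; max() returns the first one on ties
        best = max(range(len(data["alpha"])), key=lambda i: data["alpha"][i])
        final[edge_id] = data["op"][best]
    return final

candidates = ["skip_connect", "sep_conv_3x3", "max_pool_3x3"]
edges = {
    (1, 2): {"op": candidates, "alpha": [0.1, 0.7, 0.2]},
    (2, 3): {"op": candidates, "alpha": [0.5, -0.3, 0.4]},
}
print(discretize(edges))  # {(1, 2): 'sep_conv_3x3', (2, 3): 'skip_connect'}
```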
get_model_size()
Get the size of the model in terms of parameters.
Returns:
Type | Description |
---|---|
float | The size of the model in MB. |
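Assuming "MB" follows the convention of the original DARTS code, where the parameter count is divided by 10^6 (millions of parameters rather than megabytes of memory), the computation reduces to the sketch below. The list of tensor shapes is a hypothetical input standing in for the model's parameters, and `count_parameters_in_mb` is an illustrative name.

```python
import math

def count_parameters_in_mb(shapes):
    """Total number of parameters, in millions, given each parameter tensor's shape."""
    total = sum(math.prod(shape) for shape in shapes)
    return total / 1e6

# Hypothetical shapes: a 3x3 conv (16 -> 32 channels) with bias, then a 32 -> 10 linear layer with bias
shapes = [(32, 16, 3, 3), (32,), (10, 32), (10,)]
print(count_parameters_in_mb(shapes))  # 0.00497
```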
get_op_optimizer()
Get the class of the operation optimizer.
Returns:
Type | Description |
---|---|
type | The class type of the operation optimizer. |
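Because `op_optimizer` is configured as a string but `get_op_optimizer()` returns a class, one plausible way to resolve the name is an attribute lookup on an optimizer module (for PyTorch, that would be `torch.optim`). The sketch assumes that approach and uses a `SimpleNamespace` with stand-in `SGD` and `Adam` classes so it runs without PyTorch; `resolve_optimizer` is a hypothetical helper.

```python
from types import SimpleNamespace

class SGD:  # stand-in for a real optimizer class such as torch.optim.SGD
    def __init__(self, params, lr, momentum=0.0, weight_decay=0.0):
        self.params, self.lr = list(params), lr
        self.momentum, self.weight_decay = momentum, weight_decay

class Adam:  # stand-in for torch.optim.Adam
    def __init__(self, params, lr, weight_decay=0.0):
        self.params, self.lr, self.weight_decay = list(params), lr, weight_decay

optim = SimpleNamespace(SGD=SGD, Adam=Adam)  # hypothetical optimizer "module"

def resolve_optimizer(name, namespace=optim):
    """Map a name such as 'SGD' to the optimizer class (not an instance)."""
    try:
        return getattr(namespace, name)
    except AttributeError:
        raise ValueError(f"unknown optimizer: {name!r}") from None

op_cls = resolve_optimizer("SGD")
op_opt = op_cls([0.0, 1.0], lr=0.025, momentum=0.9, weight_decay=3e-4)
print(op_cls.__name__, op_opt.lr)  # SGD 0.025
```

Returning the class rather than an instance lets the caller construct the optimizer once it knows which parameters the optimizer should own.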
new_epoch(epoch)
Log the architecture weights at the beginning of each new epoch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | The current epoch number. | required |
step(data_train, data_val)
Perform a single optimization step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_train | tuple | A tuple containing training input and labels. | required |
data_val | tuple | A tuple containing validation input and labels. | required |
Returns:
Type | Description |
---|---|
tuple | A tuple of (training logits, validation logits, training loss, validation loss). |
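The alternating update performed by one step can be sketched with the first-order approximation: the architecture gradient is taken on the validation batch at the current weights, without unrolling, and the weights are then updated on the training batch. Everything here is a hypothetical toy: a single trainable weight `w` gated by one architecture logit `a`, mean-squared-error loss, and plain SGD without momentum or weight decay. The step sizes are chosen for this toy, not taken from the documented defaults; the real optimizer works on PyTorch tensors with the configured optimizers.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(w, a, x):
    # Toy mixed op with two candidates, a linear op (w * x) and a zero op;
    # softmax([a, 0]) gives the linear op the weight sigmoid(a)
    return sigmoid(a) * w * x

def loss_and_grads(w, a, batch):
    """Mean squared error over (x, y) pairs, plus its gradients w.r.t. w and a."""
    g = sigmoid(a)
    preds, loss, dw, da = [], 0.0, 0.0, 0.0
    for x, y in batch:
        pred = forward(w, a, x)
        err = pred - y
        preds.append(pred)
        loss += err * err
        dw += 2 * err * g * x
        da += 2 * err * w * x * g * (1 - g)  # d sigmoid(a)/da = g * (1 - g)
    n = len(batch)
    return preds, loss / n, dw / n, da / n

def first_order_step(state, data_train, data_val, lr=0.1, arch_lr=0.1, grad_clip=5.0):
    """Update `a` on validation data, then `w` on training data (no unrolling)."""
    # 1) Architecture step: validation-loss gradient w.r.t. a at the current w
    val_preds, val_loss, _, da = loss_and_grads(state["w"], state["a"], data_val)
    state["a"] -= arch_lr * da
    # 2) Weight step: training-loss gradient w.r.t. w under the updated a
    train_preds, train_loss, dw, _ = loss_and_grads(state["w"], state["a"], data_train)
    dw = max(-grad_clip, min(grad_clip, dw))  # norm clipping of a single scalar
    state["w"] -= lr * dw
    # Same ordering as the documented return value
    return train_preds, val_preds, train_loss, val_loss

state = {"w": 0.5, "a": 0.0}
data_train = [(1.0, 2.0), (2.0, 4.0)]  # targets follow y = 2x
data_val = [(1.5, 3.0)]
history = [first_order_step(state, data_train, data_val)[2] for _ in range(200)]
print(f"{history[0]:.3f} -> {history[-1]:.2e}")
```

Updating the architecture before the weights mirrors the order of Algorithm 1 in the DARTS paper.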
test_statistics()
Retrieve test statistics based on the current architecture and dataset.
Returns:
Type | Description |
---|---|
float or None | The test accuracy if the graph is queryable; otherwise None. |
update_ops(edge)
staticmethod
Replaces the candidate operations on an edge with a MixedOp over those operations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
edge | obj | The edge in the computation graph whose operations are updated. | required |
Returns:
Type | Description |
---|---|
None | |
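In the DARTS relaxation, each edge computes a softmax-weighted sum of all candidate operations, $\bar{o}(x) = \sum_{o} \frac{\exp(\alpha_o)}{\sum_{o'} \exp(\alpha_{o'})}\, o(x)$. The sketch below shows that computation on scalars; `MixedOpSketch` and the lambda operations are hypothetical stand-ins for the library's `MixedOp` and its PyTorch modules, whose exact interface is not documented here.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

class MixedOpSketch:
    """Continuous relaxation of one edge: sum over o of softmax(alpha)_o * o(x)."""

    def __init__(self, ops):
        self.ops = ops  # candidate operations, each a callable x -> y

    def __call__(self, x, alphas):
        weights = softmax(alphas)
        return sum(w * op(x) for w, op in zip(weights, self.ops))

ops = [lambda x: x,        # skip connection
       lambda x: 0.0,      # "none" (zero) operation
       lambda x: 2.0 * x]  # stand-in for a parametric operation
mixed = MixedOpSketch(ops)
print(mixed(3.0, [0.0, 0.0, 0.0]))  # 3.0, since equal weights give (3 + 0 + 6) / 3
```

As one alpha grows relative to the others, the mixture approaches that single operation, which is what makes the argmax discretization at the end of the search reasonable.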