DrNAS

Bases: DARTSOptimizer

Implementation of the DrNAS optimizer, introduced in the paper "DrNAS: Dirichlet Neural Architecture Search" (ICLR 2021).

Note: Many methods are shared with the DARTS optimizer, so this class inherits directly from DARTSOptimizer instead of MetaOptimizer.

__init__(learning_rate=0.025, momentum=0.9, weight_decay=0.0003, grad_clip=5, unrolled=False, arch_learning_rate=0.0003, arch_weight_decay=0.001, epochs=50, op_optimizer='SGD', arch_optimizer='Adam', loss_criteria='CrossEntropyLoss', **kwargs)

Initialize a new instance of the DrNASOptimizer class.

Parameters:

learning_rate (float): Learning rate for operation weights. Default: 0.025.
momentum (float): Momentum for the optimizer. Default: 0.9.
weight_decay (float): Weight decay for operation weights. Default: 0.0003.
grad_clip (int): Gradient clipping threshold. Default: 5.
unrolled (bool): Whether to use unrolled optimization. Default: False.
arch_learning_rate (float): Learning rate for architecture weights. Default: 0.0003.
arch_weight_decay (float): Weight decay for architecture weights. Default: 0.001.
epochs (int): Total number of training epochs. Default: 50.
op_optimizer (str): The optimizer type for operation weights, e.g. 'SGD'. Default: 'SGD'.
arch_optimizer (str): The optimizer type for architecture weights, e.g. 'Adam'. Default: 'Adam'.
loss_criteria (str): Loss criterion, e.g. 'CrossEntropyLoss'. Default: 'CrossEntropyLoss'.
**kwargs: Additional keyword arguments. Default: {}.
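As a minimal sketch of how these defaults might be collected and overridden before constructing the optimizer, the dictionary below mirrors the documented signature. The variable names `drnas_kwargs` and `search_kwargs` are illustrative, and the import path of the class is not given in this reference, so the constructor call itself is only indicated in a comment.

```python
# Documented defaults of __init__, collected in one place.
drnas_kwargs = {
    "learning_rate": 0.025,        # operation weights
    "momentum": 0.9,
    "weight_decay": 0.0003,        # operation weights
    "grad_clip": 5,
    "unrolled": False,
    "arch_learning_rate": 0.0003,  # architecture weights
    "arch_weight_decay": 0.001,    # architecture weights
    "epochs": 50,
    "op_optimizer": "SGD",
    "arch_optimizer": "Adam",
    "loss_criteria": "CrossEntropyLoss",
}

# Override only what differs from the defaults (illustrative values).
search_kwargs = {**drnas_kwargs, "epochs": 100, "arch_learning_rate": 0.0006}

# The dictionary would then be unpacked into the constructor, e.g.
#   optimizer = DrNASOptimizer(**search_kwargs)
# (import omitted because the module path is not stated in this reference).
```

Keeping the defaults in a single dictionary makes it easy to log the exact configuration used for a search run.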

adapt_search_space(search_space, dataset, scope=None)

Adapt the search space for architecture search.

Parameters:

search_space: The initial search space. Required.
dataset: The dataset for training/validation. Required.
scope: Scope to update in the search space. Default: None.

get_final_architecture()

Retrieve the final, discretized architecture based on current architectural weights.

Returns:

Graph: The final architecture in graph representation.
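Discretization of this kind is commonly done by keeping, on each edge, the operation with the largest architecture weight. The sketch below illustrates that idea with plain dictionaries; the edge keys, the primitive names, and the `discretize` helper are hypothetical and are not taken from the library, whose actual selection rule may differ.

```python
# Hypothetical architecture weights: one score per candidate operation on each edge.
arch_weights = {
    ("node1", "node2"): [0.1, 0.7, 0.2],
    ("node2", "node3"): [0.5, 0.3, 0.2],
}
# Illustrative operation names, in the same order as the scores above.
primitives = ["skip_connect", "conv_3x3", "avg_pool_3x3"]


def discretize(arch_weights, primitives):
    """Keep, for every edge, the primitive with the highest weight."""
    final = {}
    for edge, weights in arch_weights.items():
        best = max(range(len(weights)), key=weights.__getitem__)
        final[edge] = primitives[best]
    return final


final_architecture = discretize(arch_weights, primitives)
```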

new_epoch(epoch)

Perform any operations needed at the start of a new epoch.

Parameters:

epoch (int): Current epoch number. Required.

remove_sampled_alphas(edge) staticmethod

Remove sampled architecture weights (alphas) from the edge's data.

Parameters:

edge: The edge in the computation graph from which the sampled architecture weights are to be removed. Required.

sample_alphas(edge) staticmethod

Sample architecture weights (alphas) using the Dirichlet distribution parameterized by beta.

Parameters:

edge: The edge in the computation graph on which the sampled architecture weights are set. Required.
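To make the sampling step concrete, here is a minimal standard-library sketch that draws architecture weights from a Dirichlet distribution by normalizing independent Gamma draws. Several details are assumptions rather than the library's implementation: the edge data is modeled as a plain dict, the keys `alpha` and `sampled_arch_weight` are hypothetical, and the concentration beta is obtained with an ELU-plus-one mapping, which is one common way to keep it strictly positive.

```python
import math
import random


def elu_plus_one(x):
    """Map an unconstrained value to a strictly positive concentration (assumed mapping)."""
    return x + 1.0 if x > 0 else math.exp(x)


def sample_dirichlet(concentration, rng):
    """Draw one Dirichlet sample by normalizing independent Gamma(c, 1) draws."""
    draws = [rng.gammavariate(c, 1.0) for c in concentration]
    total = sum(draws)
    return [d / total for d in draws]


# Hypothetical stand-in for an edge's data container.
edge_data = {"alpha": [0.0, 0.5, -0.3]}

rng = random.Random(0)  # seeded for reproducibility
beta = [elu_plus_one(a) for a in edge_data["alpha"]]
edge_data["sampled_arch_weight"] = sample_dirichlet(beta, rng)

# Counterpart of remove_sampled_alphas: drop the sample once it is no longer needed.
sampled = edge_data.pop("sampled_arch_weight")
```

Each sample lies on the probability simplex, so the weights are positive and sum to one, which is what allows them to be used directly as mixing coefficients over candidate operations.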

step(data_train, data_val)

Perform a single optimization step for both architecture and operation weights.

Parameters:

data_train (tuple): Training data as a tuple of inputs and labels. Required.
data_val (tuple): Validation data as a tuple of inputs and labels. Required.

Returns:

tuple: Logits for training data, logits for validation data, loss for training data, loss for validation data.
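The toy example below sketches what an alternating step of this shape can look like and how the four return values are unpacked. It is not the library's implementation: `toy_step`, the scalar model y = a * w * x, and the update order (architecture parameter on validation data first, then operation weight on training data, as in first-order DARTS-style training) are assumptions made for illustration.

```python
def toy_step(state, data_train, data_val, lr=0.1, arch_lr=0.1):
    """One alternating update for the toy model y = a * w * x with squared loss."""
    x_tr, y_tr = data_train
    x_val, y_val = data_val

    # 1) Update the architecture parameter 'a' on the validation batch.
    logits_val = state["a"] * state["w"] * x_val
    val_loss = (logits_val - y_val) ** 2
    state["a"] -= arch_lr * 2 * (logits_val - y_val) * state["w"] * x_val

    # 2) Update the operation weight 'w' on the training batch.
    logits_train = state["a"] * state["w"] * x_tr
    train_loss = (logits_train - y_tr) ** 2
    state["w"] -= lr * 2 * (logits_train - y_tr) * state["a"] * x_tr

    # Same ordering as documented for step().
    return logits_train, logits_val, train_loss, val_loss


state = {"a": 1.0, "w": 0.5}
history = []
for _ in range(20):
    history.append(toy_step(state, data_train=(1.0, 2.0), data_val=(1.0, 2.0)))

logits_train, logits_val, train_loss, val_loss = history[-1]
```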

update_ops(edge) staticmethod

Replace the primitive operations at the edge with the DrNAS-specific DrNASMixedOp.

Parameters:

edge: The edge in the computation graph where the operations are to be replaced. Required.
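A mixed operation in differentiable architecture search typically computes a weighted sum of the outputs of all candidate operations on an edge. The sketch below shows that computation on scalars; the `mixed_op` helper and the candidate operations are hypothetical, and the real DrNASMixedOp works on tensors and may differ in detail.

```python
def mixed_op(x, ops, weights):
    """Weighted sum of the outputs of all candidate operations."""
    return sum(w * op(x) for op, w in zip(ops, weights))


# Illustrative candidates: identity, a scaling op, and a "zero" op.
candidate_ops = [lambda x: x, lambda x: 2.0 * x, lambda x: 0.0]
# Illustrative mixing weights on the simplex, e.g. one Dirichlet sample.
weights = [0.2, 0.5, 0.3]

y = mixed_op(3.0, candidate_ops, weights)
```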