
GDAS

Bases: DARTSOptimizer

Implementation of the GDAS optimizer introduced in "Searching for a Robust Neural Architecture in Four GPU Hours" by Dong and Yang (2019). It inherits from DARTSOptimizer and adds GDAS-specific behavior. Architecture weights are sampled from a Gumbel-Softmax distribution, and its temperature tau is decayed from tau_max toward tau_min over training.

__init__(learning_rate=0.025, momentum=0.9, weight_decay=0.0003, grad_clip=5, unrolled=False, arch_learning_rate=0.0003, arch_weight_decay=0.001, op_optimizer='SGD', arch_optimizer='Adam', loss_criteria='CrossEntropyLoss', epochs=50, tau_min=0.1, tau_max=10.0, **kwargs)

Initialize a new instance of the GDASOptimizer class.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| epochs | int | Total number of training epochs. | 50 |
| tau_max | float | Initial (maximum) value of the Gumbel-Softmax temperature tau. | 10.0 |
| tau_min | float | Minimum value to which tau is decayed. | 0.1 |
| op_optimizer | str | Optimizer type for the operation weights, e.g. 'SGD'. | 'SGD' |
| arch_optimizer | str | Optimizer type for the architecture weights, e.g. 'Adam'. | 'Adam' |
| loss_criteria | str | Loss criterion, e.g. 'CrossEntropyLoss'. | 'CrossEntropyLoss' |
| grad_clip | float | Threshold used for gradient clipping. | 5 |
| `**kwargs` | | Additional keyword arguments. | {} |

adapt_search_space(search_space, dataset, scope=None)

Adapt the search space for GDAS architecture search.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| search_space | | The initial search space. | required |
| dataset | | The dataset for training/validation. | required |
| scope | | Scope to update in the search space. | None |

new_epoch(epoch)

Update the temperature tau stored on the edges at the beginning of each new epoch.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| epoch | int | Current epoch number. | required |
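The exact decay rule is not stated on this page. A minimal sketch, assuming a linear schedule (the form used in the GDAS reference code) that moves tau from `tau_max` at epoch 0 to `tau_min` at the last epoch. The function name `tau_schedule` is hypothetical and not part of the NASLib API:

```python
def tau_schedule(epoch: int, epochs: int = 50,
                 tau_max: float = 10.0, tau_min: float = 0.1) -> float:
    """Linearly anneal tau from tau_max (first epoch) to tau_min (last epoch).

    Assumption: a linear decay; verify the exact schedule in the NASLib source.
    Defaults mirror the documented GDASOptimizer defaults.
    """
    if epochs < 2:
        # A single-epoch run has nothing to anneal; use the final temperature.
        return tau_min
    return tau_max - (tau_max - tau_min) * epoch / (epochs - 1)


# Usage: temperature at the start, middle, and end of a 50-epoch search.
for e in (0, 24, 49):
    print(e, round(tau_schedule(e), 3))
```

A high tau early on keeps the Gumbel-Softmax probabilities soft, so gradients reach all candidate operations. A low tau late in the search makes them nearly one-hot.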

remove_sampled_alphas(edge) staticmethod

Remove sampled architecture weights (alphas) from the edge's data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| edge | | The edge in the computation graph where the sampled architecture weights are to be removed. | required |
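As an illustration only, the sketch below models an edge's data as a plain dict. The key names `sampled_arch_weight` and `argmax` and the function itself are hypothetical stand-ins, not NASLib's edge-data API. The property it shows is that removal is idempotent, so it is safe to apply to every edge after each step:

```python
# Hypothetical names of the per-step entries written by the sampling step.
SAMPLED_KEYS = ("sampled_arch_weight", "argmax")


def remove_sampled_alphas(edge_data: dict) -> None:
    """Drop per-step sampled weights so the edge keeps only persistent data."""
    for key in SAMPLED_KEYS:
        # pop with a default does nothing if the key was never set.
        edge_data.pop(key, None)


edge = {"alpha": [0.1, 0.2], "sampled_arch_weight": [0.0, 1.0], "argmax": 1}
remove_sampled_alphas(edge)
print(edge)  # only the learnable alpha entry remains
```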

sample_alphas(edge, tau) staticmethod

Sample architecture weights (alphas) using the Gumbel-Softmax distribution parameterized by tau.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| edge | | The edge in the computation graph where the sampled architecture weights are set. | required |
| tau | torch.Tensor | The tau parameter controlling the temperature of the Gumbel-Softmax distribution. | required |
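Below is a framework-free sketch of hard Gumbel-Softmax sampling for one edge. It assumes the edge holds one logit per candidate operation (`alpha`) and follows the GDAS recipe: add Gumbel noise to the log-softmax of alpha, divide by tau, and take the argmax. Plain Python lists replace `torch.Tensor`, so the straight-through gradient trick appears only in a comment. All names are hypothetical:

```python
import math
import random
from typing import List, Optional, Tuple


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def gumbel(rng: random.Random) -> float:
    """Draw one sample from the standard Gumbel distribution."""
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def sample_alphas(alpha: List[float], tau: float,
                  rng: Optional[random.Random] = None) -> Tuple[List[float], List[float], int]:
    """Return (one-hot weights, soft probabilities, index of the sampled op)."""
    rng = rng or random.Random()
    logits = [(a + gumbel(rng)) / tau for a in log_softmax(alpha)]
    probs = softmax(logits)
    argmax = max(range(len(probs)), key=probs.__getitem__)
    hard = [1.0 if i == argmax else 0.0 for i in range(len(probs))]
    # With autograd, the straight-through estimator would be
    # hard - probs.detach() + probs: the forward value equals `hard`,
    # while gradients flow back through the soft `probs`.
    return hard, probs, argmax


hard, probs, idx = sample_alphas([0.1, 0.5, -0.2], tau=1.0, rng=random.Random(0))
print(hard, idx)
```

Because the argmax of Gumbel-perturbed log-probabilities is an exact sample from softmax(alpha), tau does not change which operation gets selected. It only controls how sharp `probs` is, and therefore the gradients that reach alpha.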

step(data_train, data_val)

Perform a single optimization step for both architecture and operation weights.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_train | tuple | Training data as a tuple of inputs and labels. | required |
| data_val | tuple | Validation data as a tuple of inputs and labels. | required |

Returns:

| Type | Description |
|---|---|
| tuple | Logits for the training data, logits for the validation data, loss for the training data, and loss for the validation data. |
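The sketch below is a framework-free outline of the control flow of one step. It assumes a first-order, DARTS-style alternation (unrolled=False): the architecture weights are updated on the validation batch, then the operation weights on the training batch. Every hook argument (`sample`, `forward`, `loss_fn`, `arch_update`, `op_update`, `remove`) is a hypothetical placeholder for the corresponding graph and optimizer calls:

```python
from typing import Any, Callable, Tuple


def gdas_step(sample: Callable[[], None],
              forward: Callable[[Any], Any],
              loss_fn: Callable[[Any, Any], Any],
              arch_update: Callable[[Any], None],
              op_update: Callable[[Any], None],
              remove: Callable[[], None],
              data_train: Tuple[Any, Any],
              data_val: Tuple[Any, Any]) -> Tuple[Any, Any, Any, Any]:
    x_train, y_train = data_train
    x_val, y_val = data_val

    sample()                        # draw hard Gumbel-Softmax weights on every edge
    logits_val = forward(x_val)
    val_loss = loss_fn(logits_val, y_val)
    arch_update(val_loss)           # architecture step on the validation batch

    sample()                        # resample before the weight step (assumption)
    logits_train = forward(x_train)
    train_loss = loss_fn(logits_train, y_train)
    op_update(train_loss)           # operation-weight step on the training batch

    remove()                        # drop the sampled weights from the edges
    # Same order as the documented return value.
    return logits_train, logits_val, train_loss, val_loss
```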

update_ops(edge) staticmethod

Replace the primitive operations at the edge with the GDAS-specific GDASMixedOp.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| edge | | The edge in the computation graph where the operations are to be replaced. | required |
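Why a dedicated mixed operation? With one-hot sampled weights, only one candidate operation per edge needs to be evaluated. A minimal sketch of that forward rule, modeled on the GDAS reference implementation and using scalar inputs and hypothetical names (this is not NASLib's actual `GDASMixedOp` code):

```python
from typing import Callable, Sequence


def gdas_mixed_op(x: float,
                  primitives: Sequence[Callable[[float], float]],
                  weights: Sequence[float],
                  argmax: int) -> float:
    """Run only the sampled primitive; all other primitives are skipped."""
    total = 0.0
    for i, op in enumerate(primitives):
        if i == argmax:
            total += weights[i] * op(x)  # the single operation actually executed
        else:
            # Zero in the forward pass; with autograd this term still links
            # the output to the architecture weights of the unselected ops.
            total += weights[i]
    return total


ops = [lambda v: v + 1, lambda v: 2 * v, lambda v: 0.0]
print(gdas_mixed_op(3.0, ops, weights=[0.0, 1.0, 0.0], argmax=1))
```

Skipping the unselected operations is what makes GDAS cheaper per step than DARTS, which evaluates every candidate on every edge.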