Experiment
- class confopt.train.Experiment(search_space: SearchSpaceType, dataset: DatasetType, seed: int, log_with_wandb: bool = False, debug_mode: bool = False, exp_name: str = 'test', dataset_domain: str | None = None, dataset_dir: str = 'datasets', api_dir: str = 'api')
Bases: object
The Experiment class is responsible for managing the training and evaluation of the supernet and discrete models. It initializes the necessary components, manages which saved training states to load, and handles the training process.
- Parameters:
search_space (SearchSpaceType) – The search space type used for the experiment.
dataset (DatasetType) – The dataset type used for the experiment.
seed (int) – The random seed for reproducibility of the runs.
log_with_wandb (bool) – Flag to enable logging with Weights & Biases.
debug_mode (bool) – Flag to enable debug mode, in which only 5 steps are run per epoch.
exp_name (str) – The name of the experiment.
dataset_domain (str | None) – The domain of the dataset used for the Taskonomy dataset. Valid values are ‘class_object’ and ‘class_scene’.
dataset_dir (str) – The directory where the dataset is stored.
api_dir (str) – The directory where the benchmark API is stored; used when benchmarks are enabled.
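The debug_mode description can be pictured as capping each epoch's loop at five batches. The sketch below illustrates that reading only; epoch_batches and DEBUG_STEPS_PER_EPOCH are hypothetical names, not part of confopt, and a plain range stands in for a real data loader.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

# Hypothetical constant: the step cap stated in the debug_mode description.
DEBUG_STEPS_PER_EPOCH = 5


def epoch_batches(loader: Iterable[T], debug_mode: bool) -> Iterator[T]:
    """Return an iterator over one epoch's batches (hypothetical helper).

    In debug mode only the first DEBUG_STEPS_PER_EPOCH batches are produced,
    so an epoch finishes after five steps.
    """
    if debug_mode:
        return islice(loader, DEBUG_STEPS_PER_EPOCH)
    return iter(loader)


if __name__ == "__main__":
    batches = range(100)  # stand-in for a DataLoader
    print(sum(1 for _ in epoch_batches(batches, debug_mode=True)))   # 5
    print(sum(1 for _ in epoch_batches(batches, debug_mode=False)))  # 100
```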
- cleanup_ddp() None
Kills the distributed data parallel (DDP) process.
- Parameters:
None
- Returns:
None
- get_discrete_model(searchspace_config: dict, model_to_load: str | int | None = None, use_supernet_checkpoint: bool = False, use_expr_search_space: bool = False, genotype_str: str | None = None) tuple[torch.nn.Module, str]
Returns a discrete model based on the given parameters.
- Parameters:
searchspace_config (dict) – Configuration for the search space.
model_to_load (str | int | None) – Specifies the training state to load. Can be "last", "best", or a specific epoch.
use_supernet_checkpoint (bool) – If True, initializes the model’s weights from a supernet checkpoint.
use_expr_search_space (bool) – If True, gets the discretized model from self.search_space.
genotype_str (str | None) – The genotype string to use for creating the discrete model.
- Returns:
tuple[torch.nn.Module, str] – A tuple containing the discrete model and its genotype string.
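model_to_load accepts "last", "best", an epoch number, or None. A minimal sketch of how such a value might be mapped to a checkpoint file is shown here; the function checkpoint_name and the model_<tag>.pth naming scheme are assumptions for illustration, not confopt's actual file layout.

```python
from typing import Optional, Union


def checkpoint_name(model_to_load: Optional[Union[str, int]]) -> Optional[str]:
    """Map a model_to_load value to a checkpoint file name.

    Hypothetical helper: the "model_<tag>.pth" naming is an assumption.
    """
    if model_to_load is None:
        return None  # nothing to load
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(model_to_load, bool):
        raise TypeError("model_to_load must be 'last', 'best', an int, or None")
    if isinstance(model_to_load, int):
        if model_to_load < 0:
            raise ValueError("epoch must be non-negative")
        return f"model_{model_to_load}.pth"
    if model_to_load in ("last", "best"):
        return f"model_{model_to_load}.pth"
    raise ValueError(f"unknown model_to_load value: {model_to_load!r}")


print(checkpoint_name("best"))  # model_best.pth
print(checkpoint_name(12))      # model_12.pth
```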
- get_discrete_model_from_genotype_str(search_space_str: str, genotype_str: str, searchspace_config: dict) Module
Returns a discrete model based on the given genotype string.
- Parameters:
search_space_str (str) – The search space type.
genotype_str (str) – The genotype string to use for creating the discrete model.
searchspace_config (dict) – Configuration for the search space.
- Raises:
ValueError – If the search space type is not recognized or if the dataset is not supported.
- Returns:
torch.nn.Module – The discrete model.
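The ValueError above implies a dispatch on search_space_str. A minimal sketch of that pattern, assuming a dictionary registry, follows; BUILDERS, build_darts, build_nb201, and the keys "darts" and "nb201" are hypothetical, only the search-space check is modeled (not the dataset check), and the builders return plain dicts instead of torch.nn.Module so the example runs without PyTorch.

```python
from typing import Callable, Dict


def build_darts(genotype_str: str, config: dict) -> dict:
    # Placeholder: a real builder would return a torch.nn.Module.
    return {"search_space": "darts", "genotype": genotype_str, "config": config}


def build_nb201(genotype_str: str, config: dict) -> dict:
    # Placeholder: a real builder would return a torch.nn.Module.
    return {"search_space": "nb201", "genotype": genotype_str, "config": config}


# Hypothetical registry from search-space name to model builder.
BUILDERS: Dict[str, Callable[[str, dict], dict]] = {
    "darts": build_darts,
    "nb201": build_nb201,
}


def model_from_genotype_str(
    search_space_str: str, genotype_str: str, searchspace_config: dict
) -> dict:
    """Dispatch to the builder registered for search_space_str."""
    builder = BUILDERS.get(search_space_str)
    if builder is None:
        raise ValueError(f"Unrecognized search space: {search_space_str!r}")
    return builder(genotype_str, searchspace_config)
```

A registry keeps the error path in one place: any name without a registered builder raises the same ValueError.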
- get_discrete_model_from_supernet() SearchSpace
Returns a discrete model from the experiment’s supernet (search space).
- Parameters:
None
- Raises:
Exception – If the experiment does not have a search space or the search space is not a supernet.
- Returns:
SearchSpace – A discrete model.
- get_genotype_str_from_checkpoint(model_to_load: str | int | None = None, use_supernet_checkpoint: bool = False) str
Returns the genotype string from the checkpoint.
- Parameters:
model_to_load (str | int | None) – Specifies the training state to load. Can be "last", "best", or a specific epoch.
use_supernet_checkpoint (bool) – If True, initializes the model’s weights from a supernet checkpoint.
- Raises:
ValueError – If model_to_load is not given.
- Returns:
str – The genotype string.
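To illustrate the documented contract (a ValueError when model_to_load is missing, otherwise the stored genotype string), here is a minimal sketch. The in-memory checkpoints mapping, its "genotype" key, and the function genotype_str_from_checkpoint are hypothetical stand-ins, not confopt's on-disk checkpoint format.

```python
from typing import Any, Mapping, Optional, Union

CheckpointKey = Union[str, int]


def genotype_str_from_checkpoint(
    checkpoints: Mapping[CheckpointKey, Mapping[str, Any]],
    model_to_load: Optional[CheckpointKey] = None,
) -> str:
    """Read the genotype string stored in a checkpoint.

    Hypothetical layout: checkpoints maps "last", "best", or an epoch
    number to a state dict that has a "genotype" entry.
    """
    if model_to_load is None:
        # Mirrors the documented ValueError when model_to_load is not given.
        raise ValueError("model_to_load must be given")
    state = checkpoints[model_to_load]
    return str(state["genotype"])


ckpts = {"best": {"genotype": "geno-best"}, 7: {"genotype": "geno-7"}}
print(genotype_str_from_checkpoint(ckpts, "best"))  # geno-best
```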
- init_ddp() None
Initializes the distributed data parallel (DDP) environment.
- Parameters:
None
- Returns:
None
- select_perturbation_based_arch(profile: BaseProfile, model_source: Literal['supernet', 'arch_selection'] = 'supernet', model_to_load: str | int | None = None, exp_runtime_to_load: str | None = None, log_with_wandb: bool = False, run_name: str = 'darts-pt', src_folder_path: str | None = None) PerturbationArchSelection
Creates a perturbation-based architecture selection and returns the selection object.
- Parameters:
profile (BaseProfile) – The profile containing the configuration for the experiment.
model_source (str) – The source of the model to load. Can be "supernet" or "arch_selection" (a previously saved PerturbationArchSelection object).
model_to_load (str | int | None) – The model to load. Can be "last", "best", or a specific epoch.
exp_runtime_to_load (str | None) – The runtime to load the model from.
log_with_wandb (bool) – Flag to log with wandb.
run_name (str) – The name of the run.
src_folder_path (str | None) – The source folder path of the experiment’s run.
- Raises:
AttributeError – If an illegal model source is provided.
AssertionError – If the model source is "arch_selection" and model_to_load is "best".
- Returns:
PerturbationArchSelection – The architecture selection object.
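The two Raises entries above amount to an argument check. A minimal sketch of that check follows; check_model_source and VALID_MODEL_SOURCES are hypothetical names, and only the documented conditions are modeled, not the rest of the method.

```python
from typing import Optional, Union

# The two values allowed by the Literal type of model_source.
VALID_MODEL_SOURCES = ("supernet", "arch_selection")


def check_model_source(
    model_source: str, model_to_load: Optional[Union[str, int]] = None
) -> None:
    """Validate arguments as the Raises section describes (sketch)."""
    if model_source not in VALID_MODEL_SOURCES:
        raise AttributeError(f"Illegal model source: {model_source!r}")
    if model_source == "arch_selection" and model_to_load == "best":
        # Raised explicitly so the check survives `python -O`.
        raise AssertionError(
            'model_to_load="best" is not supported for model_source="arch_selection"'
        )
```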
- train_discrete_model(profile: DiscreteProfile, model_to_load: str | int | None = None, exp_runtime_to_load: str | None = None, use_supernet_checkpoint: bool = False, use_expr_search_space: bool = False) DiscreteTrainer
Trains a discrete model using the given profile with options for loading specific training states.
- Parameters:
profile (DiscreteProfile) – Contains configurations for training the model, including hyperparameters and architecture details. The genotype can be set in the profile; otherwise, the default genotype is used.
model_to_load (str | int | None) – Specifies the training state to load. Acceptable string values are “last” or “best”, representing the most recent or the best-performing model checkpoint, respectively. If an integer is provided, it represents the epoch number from which training should be continued. If None, behavior is determined by other parameters.
exp_runtime_to_load (str | None) – The experiment runtime to load the model from. If None, the model will be loaded from the most recent runtime.
use_supernet_checkpoint (bool) – If True, initializes the model’s weights from a supernet checkpoint. If False, the model will use checkpoints from the discrete network instead.
use_expr_search_space (bool) – If True, gets the discretized model from self.search_space.
- Returns:
DiscreteTrainer – The trainer used to train the discrete model.
- Behavior Notes:
If none of these parameters is provided, the default genotype from the profile is used.
The default genotype in the profile refers to the best architecture found after 50 epochs using the DARTS optimizer on the CIFAR-10 dataset within the DARTS search space.
Setting use_supernet_checkpoint to True allows loading from the supernet, while False defaults to using checkpoints from the discrete network.
Example
>>> trainer = experiment.train_discrete_model(
...     profile=profile,
...     model_to_load="last",
...     exp_runtime_to_load=None,
...     use_supernet_checkpoint=True,
...     use_expr_search_space=False,
... )
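Passing exp_runtime_to_load=None means the model is loaded from the most recent runtime. A minimal sketch of that selection rule follows, assuming runtime folders are named by timestamp; resolve_runtime and RUNTIME_FORMAT are hypothetical, and confopt's actual folder naming may differ.

```python
from datetime import datetime
from typing import Iterable, Optional

# Hypothetical naming scheme for runtime folders.
RUNTIME_FORMAT = "%Y-%m-%d-%H-%M-%S"


def resolve_runtime(
    available: Iterable[str], exp_runtime_to_load: Optional[str] = None
) -> str:
    """Pick the runtime to load: the requested one, else the most recent."""
    runtimes = list(available)
    if exp_runtime_to_load is not None:
        if exp_runtime_to_load not in runtimes:
            raise ValueError(f"Unknown runtime: {exp_runtime_to_load!r}")
        return exp_runtime_to_load
    if not runtimes:
        raise ValueError("No runtimes available to load from")
    # Compare parsed timestamps rather than raw strings.
    return max(runtimes, key=lambda r: datetime.strptime(r, RUNTIME_FORMAT))
```

Parsing the timestamps keeps the choice correct even if a naming scheme does not sort lexicographically.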
- train_supernet(profile: BaseProfile, model_to_load: str | int | None = None, exp_runtime_to_load: str | None = None, use_benchmark: bool = False) ConfigurableTrainer
Trains a supernet using the given profile with options for loading previous runs.
- Parameters:
profile (BaseProfile) – Contains configurations for training the supernet, including component settings and architectural specifications.
model_to_load (str | int | None) – Specifies the training state to load the supernet from. Valid values are "last" or "best", representing the most recent or the best-performing model checkpoint, respectively. If an integer is provided, it represents the epoch number from which training should be continued. If None, the supernet is trained from scratch.
exp_runtime_to_load (str | None) – The particular experiment runtime to load the model from. If None, the model will be loaded from the last runtime.
use_benchmark (bool) – If True, uses a benchmark API for evaluation.
- Returns:
ConfigurableTrainer – The trainer used to train the supernet.
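The model_to_load options for train_supernet determine where training starts. A minimal sketch of that decision follows; resume_epoch and the checkpoint_epochs bookkeeping are hypothetical, and the assumption that training resumes at the epoch after the saved one is ours, not confirmed by the documentation.

```python
from typing import Mapping, Optional, Union


def resume_epoch(
    model_to_load: Optional[Union[str, int]],
    checkpoint_epochs: Mapping[str, int],
) -> int:
    """Return the 0-based epoch at which supernet training would start.

    checkpoint_epochs maps "last" and "best" to the epoch at which each
    checkpoint was saved (hypothetical bookkeeping).
    """
    if model_to_load is None:
        return 0  # no state to load: train from scratch
    if isinstance(model_to_load, bool):
        raise TypeError("model_to_load must be 'last', 'best', an int, or None")
    if isinstance(model_to_load, int):
        saved_epoch = model_to_load
    elif model_to_load in ("last", "best"):
        saved_epoch = checkpoint_epochs[model_to_load]
    else:
        raise ValueError(f"Unknown model_to_load value: {model_to_load!r}")
    # Assumption: training resumes at the epoch after the saved one.
    return saved_epoch + 1
```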