# Source code for smac.tae.execute_ta_run_hydra

from typing import Any, Dict, Mapping, Optional, Tuple, Type

from smac.configspace import Configuration
from smac.tae import StatusType
from smac.tae.execute_func import ExecuteTAFuncArray, ExecuteTAFuncDict
from smac.tae.execute_ta_run_aclib import ExecuteTARunAClib
from smac.tae.execute_ta_run_old import ExecuteTARunOld
from smac.tae.serial_runner import SerialRunner

# Module metadata (copyright holder, license, and current maintainer).
__copyright__ = "Copyright 2018, ML4AAD"
__license__ = "3-clause BSD"
__maintainer__ = "Marius Lindauer"


class ExecuteTARunHydra(SerialRunner):
    """Target algorithm runner for Hydra that returns min(cost, cost_portfolio).

    Every run is delegated to a wrapped runner. Its result is then merged with
    the oracle (portfolio) cost of the same instance, so the reported value is
    never worse than what the portfolio already achieves on that instance.

    Parameters
    ----------
    cost_oracle : Mapping[str, float]
        Cost of the oracle (portfolio) per instance.
    tae : Type[SerialRunner]
        Class of the target algorithm evaluator to wrap. Must be one of
        ``ExecuteTARunAClib``, ``ExecuteTARunOld``, ``ExecuteTAFuncDict`` or
        ``ExecuteTAFuncArray``.
    **kwargs : Any
        Forwarded unchanged both to ``SerialRunner.__init__`` and to the
        constructor of the wrapped ``tae``.

    Raises
    ------
    ValueError
        If ``tae`` is not one of the supported runner classes.
    """

    # Runner classes that may be wrapped. Membership is tested with ``in``,
    # which for plain classes compares by identity, exactly like ``is``.
    _SUPPORTED_TAES: Tuple[Type[SerialRunner], ...] = (
        ExecuteTARunAClib,
        ExecuteTARunOld,
        ExecuteTAFuncDict,
        ExecuteTAFuncArray,
    )

    def __init__(
        self,
        cost_oracle: Mapping[str, float],
        tae: Type[SerialRunner] = ExecuteTARunOld,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.cost_oracle = cost_oracle

        if tae not in self._SUPPORTED_TAES:
            # ValueError subclasses Exception, so callers that caught the
            # previous bare Exception keep working.
            raise ValueError(f"TAE not supported: {tae!r}")
        # The wrapped runner receives the same keyword arguments as this one.
        self.runner: SerialRunner = tae(**kwargs)

    def run(
        self,
        config: Configuration,
        instance: str,
        cutoff: Optional[float] = None,
        seed: int = 12345,
        budget: Optional[float] = None,
        instance_specific: str = "0",
    ) -> Tuple[StatusType, float, float, Dict]:
        """Run the wrapped target algorithm and merge its result with the oracle.

        Arguments and the meaning of the returned tuple follow
        ``smac.tae.execute_ta_run_old.ExecuteTARunOld.run``. After the wrapped
        runner returns, the result is combined with the oracle value for
        ``instance``:

        * if ``self.run_obj == "runtime"``, ``runtime`` becomes
          ``min(oracle, runtime)`` and ``cost`` is set to that same value;
        * otherwise ``cost`` becomes ``min(oracle, cost)``;
        * a ``TIMEOUT`` status is upgraded to ``SUCCESS`` when the oracle
          value is strictly below ``cutoff``.

        If ``instance`` has no oracle entry, an error is logged and the wrapped
        runner's result is returned unchanged.

        Returns
        -------
        Tuple[StatusType, float, float, Dict]
            ``(status, cost, runtime, additional_info)`` after merging.

        Raises
        ------
        ValueError
            If ``cutoff`` is ``None``.
        """
        if cutoff is None:
            raise ValueError("Cutoff of type None is not supported")

        status, cost, runtime, additional_info = self.runner.run(
            config=config,
            instance=instance,
            cutoff=cutoff,
            seed=seed,
            budget=budget,
            instance_specific=instance_specific,
        )

        if instance not in self.cost_oracle:
            self.logger.error("Oracle performance missing --- should not happen")
            return status, cost, runtime, additional_info

        oracle_perf = self.cost_oracle[instance]
        if self.run_obj == "runtime":
            # For a runtime objective the runtime itself is the cost.
            merged = min(oracle_perf, runtime)
            self.logger.debug("Portfolio perf: %f vs %f = %f", oracle_perf, runtime, merged)
            runtime = merged
            cost = runtime
        else:
            merged = min(oracle_perf, cost)
            self.logger.debug("Portfolio perf: %f vs %f = %f", oracle_perf, cost, merged)
            cost = merged

        # If the portfolio stays under the cutoff on this instance, the merged
        # result is not a timeout even when the wrapped run was.
        # NOTE(review): for non-runtime objectives oracle_perf is a cost rather
        # than a time, so comparing it with the time cutoff mixes units —
        # confirm this is intended.
        if oracle_perf < cutoff and status is StatusType.TIMEOUT:
            status = StatusType.SUCCESS

        return status, cost, runtime, additional_info