Skip to content

Mighty ES runner

mighty.mighty_runners.mighty_es_runner #

MightyESRunner #

MightyESRunner(
    cfg: DictConfig,
    env=None,
    base_eval_env: Callable = None,
    eval_default: int = None,
)

Bases: MightyRunner

Source code in mighty/mighty_runners/mighty_es_runner.py
def __init__(
    self,
    cfg: DictConfig,
    env=None,
    base_eval_env: Callable = None,
    eval_default: int = None,
) -> None:
    """Set up the evolution-strategy runner from a config.

    Besides forwarding all arguments to the parent runner, this reads the
    following config entries: ``search_targets``, ``es``, ``popsize``,
    ``iterations``, ``rl_train_agent``, the optional ``es_kwargs`` and,
    only when ``rl_train_agent`` is truthy, ``num_steps_per_iteration``.

    Args:
        cfg: Run configuration holding the entries listed above.
        env: Passed through unchanged to the parent constructor.
        base_eval_env: Passed through unchanged to the parent constructor.
        eval_default: Passed through unchanged to the parent constructor.
    """
    super().__init__(cfg, env, base_eval_env, eval_default)
    self.search_targets = cfg.search_targets
    self.search_params = "parameters" in self.search_targets

    # Snapshot of every searched value as it was at construction time, so it
    # can be put back later.
    self.original_values = {}

    # Each target occupies one search dimension, except "parameters", which
    # expands to one dimension per scalar model parameter.
    num_dims = len(self.search_targets)
    if self.search_params:
        self.total_n_params = sum(p.numel() for p in self.agent.parameters)
        num_dims += self.total_n_params - 1
        self.original_values["parameters"] = [
            p.clone() for p in self.agent.parameters
        ]

    # The remaining targets are plain attributes read from the agent.
    self.original_values.update(
        (name, getattr(self.agent, name))
        for name in self.search_targets
        if name != "parameters"
    )

    # Strategy class comes from cfg.es, with xNES passed as the default;
    # extra constructor arguments are optional.
    es_kwargs = cfg.es_kwargs if "es_kwargs" in cfg.keys() else {}
    es_cls = retrieve_class(cfg.es, default_cls=xNES)
    self.es = es_cls(popsize=cfg.popsize, num_dims=num_dims, **es_kwargs)

    # The PRNG key is built from the fixed seed 0.
    self.rng = jax.random.PRNGKey(0)
    self.fit_shaper = FitnessShaper(centered_rank=True, w_decay=0.0, maximize=True)

    self.iterations = cfg.iterations
    self.train_agent = cfg.rl_train_agent
    if self.train_agent:
        self.num_steps_per_iteration = cfg.num_steps_per_iteration

apply_individual #

apply_individual(individual) -> None

Apply an individual's values to both parameters and other search targets.

Source code in mighty/mighty_runners/mighty_es_runner.py
def apply_individual(self, individual) -> None:
    """Push one candidate solution onto the agent.

    Layout of ``individual``: when model parameters are searched, the first
    ``self.total_n_params`` entries are handed to ``apply_parameters``; after
    that comes one entry per remaining name in ``self.search_targets``, in
    the order those names are listed.

    Args:
        individual: Flat array-like candidate (anything ``np.asarray``
            accepts).
    """
    genome = np.asarray(individual)

    # The parameter block, if present, always sits at the front.
    offset = 0
    if self.search_params:
        offset = self.total_n_params
        self.apply_parameters(genome[:offset])

    # These attributes are rounded to an int and clamped to at least 1.
    integer_targets = ("_batch_size", "n_units")

    attribute_targets = (
        name for name in self.search_targets if name != "parameters"
    )
    for position, name in enumerate(attribute_targets, start=offset):
        value = genome[position]
        if name in integer_targets:
            value = max(1, int(round(value)))
        setattr(self.agent, name, value)

apply_parameters #

apply_parameters(individual_params) -> None

Apply parameter values to the agent's model parameters.

Source code in mighty/mighty_runners/mighty_es_runner.py
def apply_parameters(self, individual_params) -> None:
    """Write a flat vector of values into the agent's model parameters.

    The vector is consumed in the iteration order of
    ``self.agent.parameters``: each parameter takes as many consecutive
    entries as it has elements, reshaped to that parameter's shape. Every
    chunk is converted to the dtype and device of the parameter it replaces,
    so applying an individual never changes where or how a parameter is
    stored.

    Args:
        individual_params: Array-like of values (anything ``np.asarray``
            accepts); it is flattened before use. Values beyond what the
            parameters need are ignored.

    Raises:
        ValueError: If fewer values are supplied than the parameters need.
    """
    flat_values = torch.tensor(np.asarray(individual_params)).reshape(-1)

    # Materialise once: the collection is walked for the size check and again
    # for the assignment, which a one-shot iterator would not survive.
    params = list(self.agent.parameters)

    # Validate before touching any parameter so a bad vector cannot leave
    # the model half-updated.
    needed = sum(p.numel() for p in params)
    if flat_values.numel() < needed:
        raise ValueError(
            f"Expected at least {needed} parameter values, "
            f"got {flat_values.numel()}."
        )

    start = 0
    for p in params:
        end = start + p.numel()
        chunk = flat_values[start:end].reshape(p.shape)
        # copy=True gives the parameter its own storage instead of a view
        # into ``flat_values``.
        p.data = chunk.to(device=p.device, dtype=p.dtype, copy=True)
        start = end

restore_original_values #

restore_original_values() -> None

Restore the original values before applying a new individual.

Source code in mighty/mighty_runners/mighty_es_runner.py
def restore_original_values(self) -> None:
    """Reset every searched value to the snapshot in ``self.original_values``.

    Model parameters (when they are part of the search) receive a fresh
    clone of their stored tensor; every other search target is reassigned
    on the agent to its stored value.
    """
    if self.search_params:
        snapshots = self.original_values["parameters"]
        for live, snapshot in zip(self.agent.parameters, snapshots):
            # Clone so later writes to the live parameter cannot alter the
            # stored snapshot.
            live.data = snapshot.clone()

    for name in self.search_targets:
        if name == "parameters":
            continue  # handled by the branch above
        setattr(self.agent, name, self.original_values[name])